initial version

Signed-off-by: leonwanghui <leon.wanghui@huawei.com>
This commit is contained in:
zhunaipan 2020-03-27 15:34:32 +08:00 committed by leonwanghui
commit 576a73b42a
3116 changed files with 480920 additions and 0 deletions

152
.clang-format Normal file
View File

@ -0,0 +1,152 @@
---
Language: Cpp
# BasedOnStyle: Google
AccessModifierOffset: -1
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlines: Left
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: true
AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes
BinPackArguments: true
BinPackParameters: true
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
AfterExternBlock: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
SplitEmptyFunction: true
SplitEmptyRecord: true
SplitEmptyNamespace: true
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeInheritanceComma: false
BreakInheritanceList: BeforeColon
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakConstructorInitializers: BeforeColon
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 120
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 2
Cpp11BracedListStyle: true
DerivePointerAlignment: true
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
ForEachMacros:
# - foreach
- Q_FOREACH
- BOOST_FOREACH
IncludeBlocks: Preserve
IncludeCategories:
- Regex: '^<ext/.*\.h>'
Priority: 2
- Regex: '^<.*\.h>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IncludeIsMainRegex: '([-_](test|unittest))?$'
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 2
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBinPackProtocolList: Never
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
RawStringFormats:
- Language: Cpp
Delimiters:
- cc
- CC
- cpp
- Cpp
- CPP
- 'c++'
- 'C++'
CanonicalDelimiter: ''
BasedOnStyle: google
- Language: TextProto
Delimiters:
- pb
- PB
- proto
- PROTO
EnclosingFunctions:
- EqualsProto
- EquivToProto
- PARSE_PARTIAL_TEXT_PROTO
- PARSE_TEST_PROTO
- PARSE_TEXT_PROTO
- ParseTextOrDie
- ParseTextProtoOrDie
CanonicalDelimiter: ''
BasedOnStyle: google
ReflowComments: true
SortIncludes: true
SortUsingDeclarations: true
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Auto
StatementMacros:
- Q_UNUSED
- QT_REQUIRE_VERSION
TabWidth: 2
UseTab: Never
SortIncludes: false
...

73
.gitignore vendored Normal file
View File

@ -0,0 +1,73 @@
# MindSpore
build/
mindspore/lib
output
*.ir
# Cmake files
CMakeFiles/
cmake_install.cmake
CMakeCache.txt
Makefile
cmake-build-debug
# Dynamic libraries
*.so
*.so.*
*.dylib
# Static libraries
*.la
*.lai
*.a
*.lib
# Protocol buffers
*_pb2.py
*.pb.h
*.pb.cc
# Object files
*.o
# Editor
.vscode
.idea/
# Cquery
.cquery_cached_index/
compile_commands.json
# Ctags and cscope
tags
TAGS
CTAGS
GTAGS
GRTAGS
GSYMS
GPATH
cscope.*
# Python files
*__pycache__*
.pytest_cache
# Mac files
*.DS_Store
# Test results
test_temp_summary_event_file/
*.dot
*.dat
*.svg
*.perf
*.info
*.ckpt
*.shp
*.pkl
.clangd
mindspore/version.py
mindspore/default_config.py
mindspore/.commit_id
onnx.proto
mindspore/ccsrc/onnx.proto

15
.gitmodules vendored Normal file
View File

@ -0,0 +1,15 @@
[submodule "third_party/flatbuffers"]
path = third_party/flatbuffers
url = https://github.com/google/flatbuffers.git
[submodule "third_party/googletest"]
path = third_party/googletest
url = https://github.com/google/googletest.git
[submodule "third_party/incubator-tvm"]
path = third_party/incubator-tvm
url = https://github.com/apache/incubator-tvm.git
[submodule "third_party/protobuf"]
path = third_party/protobuf
url = https://github.com/protocolbuffers/protobuf.git
[submodule "graphengine"]
path = graphengine
url = https://gitee.com/mindspore/graphengine.git

55
CMakeLists.txt Normal file
View File

@ -0,0 +1,55 @@
cmake_minimum_required(VERSION 3.14)
project (MindSpore)
include(${CMAKE_SOURCE_DIR}/cmake/options.cmake)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/modules/")
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D'_LIBCPP_EXTERN_TEMPLATE(...)=' -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 -Werror -Wall -Wno-deprecated-declarations -fPIC")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(PYBIND11_CPP_STANDARD -std=c++17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OPTION_CXX_FLAGS}")
find_package(Threads)
include(${CMAKE_SOURCE_DIR}/cmake/mind_expression.cmake)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/flatbuffers/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/flatbuffers/include/flatbuffers)
include(${CMAKE_SOURCE_DIR}/cmake/dependency_utils.cmake)
find_package(Python3 COMPONENTS Interpreter Development)
if(Python3_FOUND)
set(PYTHON_INCLUDE_DIRS "${Python3_INCLUDE_DIRS}")
set(PYTHON_LIBRARIES "${Python3_LIBRARIES}")
else()
find_python_package(py_inc py_lib)
set(PYTHON_INCLUDE_DIRS "${py_inc}")
set(PYTHON_LIBRARIES "${py_lib}")
endif()
message("PYTHON_INCLUDE_DIRS = ${PYTHON_INCLUDE_DIRS}")
message("PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
include_directories(${PYTHON_INCLUDE_DIRS})
set(MS_CCSRC_PATH ${CMAKE_SOURCE_DIR}/mindspore/ccsrc)
set(MS_CCSRC_BUILD_PATH ${BUILD_PATH}/mindspore/mindspore/ccsrc)
if (ENABLE_GE)
link_directories(${CMAKE_SOURCE_DIR}/third_party/ge/lib)
else()
include(${CMAKE_SOURCE_DIR}/cmake/dependency_graphengine.cmake)
endif()
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc/external)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc/framework)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
add_subdirectory(mindspore/ccsrc)
if (ENABLE_TESTCASES)
add_subdirectory(tests)
endif()

115
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,115 @@
# MindSpore contributing guidelines
<!-- TOC -->
- [MindSpore contributing guidelines](#mindspore-contributing-guidelines)
- [Contributor License Agreement](#contributor-license-agreement)
- [Getting Started](#getting-started)
- [Contribution workflow](#contribution-workflow)
- [Code style](#code-style)
- [Fork-Pull development model](#fork-pull-development-model)
- [Report issues](#report-issues)
- [Propose PRs](#propose-prs)
<!-- /TOC -->
## Contributor License Agreement
It's required to sign CLA before your first code submission to MindSpore community.
For individual contributor, please refer to [ICLA online document](https://www.mindspore.cn/icla) for the detailed information.
## Getting Started
- Fork the repository on [Github](https://github.com/mindspore-ai/mindspore) or [Gitee](https://gitee.com/mindspore/mindspore).
- Read the [README.md](README.md) and [install page](https://www.mindspore.cn/install/en) for project information and build instructions.
## Contribution Workflow
### Code style
Please follow this style to make MindSpore easy to review, maintain and develop.
* Coding guidelines
The *Python* coding style suggested by [Python PEP 8 Coding Style](https://pep8.org/) and *C++* coding style suggested by [Google C++ Coding Guidelines](http://google.github.io/styleguide/cppguide.html) are used in MindSpore community.
* Unittest guidelines
The *Python* unittest style suggested by [pytest](http://www.pytest.org/en/latest/) and *C++* unittest style suggested by [Googletest Primer](https://github.com/google/googletest/blob/master/googletest/docs/primer.md) are used in MindSpore community.
### Fork-Pull development model
* Fork MindSpore repository
Before submitting code to MindSpore project, please make sure that this project have been forked to your own repository. It means that there will be parallel development between MindSpore repository and your own repository, so be careful to avoid the inconsistency between them.
* Clone the remote repository
If you want to download the code to the local machine, `git` is the best way:
```shell
# For GitHub
git clone https://github.com/{insert_your_forked_repo}/mindspore.git
git remote add upstream https://github.com/mindspore-ai/mindspore.git
# For Gitee
git clone https://gitee.com/{insert_your_forked_repo}/mindspore.git
git remote add upstream https://gitee.com/mindspore/mindspore.git
```
* Develop code locally
To avoid inconsistency between multiple branches, checking out to a new branch is `SUGGESTED`:
```shell
git checkout -b {new_branch_name} origin/master
```
Then you can change the code arbitrarily.
* Push the code to the remote repository
After updating the code, you should push the update in the formal way:
```shell
git add .
git status # Check the update status
git commit -m "Your commit title"
git commit -s --amend #Add the concrete description of your commit
git push origin {new_branch_name}
```
* Pull a request to MindSpore repository
In the last step, your need to pull a compare request between your new branch and MindSpore `master` branch. After finishing the pull request, the Jekins CI will be automatically set up for building test.
### Report issues
A great way to contribute to the project is to send a detailed report when you encounter an issue. We always appreciate a well-written, thorough bug report, and will thank you for it!
When reporting issues, refer to this format:
- What version of env (mindspore, os, python etc) are you using?
- Is this a BUG REPORT or FEATURE REQUEST?
- What happened?
- What you expected to happen?
- How to reproduce it?(as minimally and precisely as possible)
- Special notes for your reviewers?
**Issues advisory:**
- **If you find an unclosed issue, which is exactly what you are going to solve,** please put some comments on that issue to tell others you would be in charge of it.
- **If an issue is opened for a while,** it's recommended for contributors to precheck before working on solving that issue.
- **If you resolve an issue which is reported by yourself,** it's also required to let others know before closing that issue.
### Propose PRs
* Raise your idea as an *issue* on [GitHub](https://github.com/mindspore-ai/mindspore/issues) or [Gitee](https://gitee.com/mindspore/mindspore/issues)
* If it is a new feature that needs lots of design details, a design proposal should also be submitted.
* After reaching consensus in the issue discussions and design proposal reviews, complete the development on the forked repo and submit a PR.
* None of PRs is not permitted until it receives **2+ LGTM** from approvers. Please NOTICE that approver is NOT allowed to add *LGTM* on his own PR.
* After PR is sufficiently discussed, it will get merged, abondoned or rejected depending on the outcome of the discussion.
**PRs advisory:**
- Any irrelevant changes should be avoided.
- Make sure your commit history being ordered.
- Always keep your branch up with the master branch.
- For bug-fix PRs, make sure all related issues being linked.

201
LICENSE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

2
NOTICE Normal file
View File

@ -0,0 +1,2 @@
MindSpore
Copyright 2019-2020 Huawei Technologies Co., Ltd

124
README.md Normal file
View File

@ -0,0 +1,124 @@
![MindSpore Logo](docs/MindSpore-logo.png "MindSpore logo")
============================================================
- [What is MindSpore?](#what-is-MindSpore)
- [Automatic Differentiation](#automatic-differentiation)
- [Automatic Parallel](#automatic-parallel)
- [Installation](#installation)
- [Binaries](#binaries)
- [From Source](#from-source)
- [Quickstart](#quickstart)
- [Docs](#docs)
- [Community](#community)
- [Governance](#governance)
- [Communication](#communication)
- [Contributing](#contributing)
- [Release Notes](#release-notes)
- [License](#license)
## What Is MindSpore
MindSpore is a new open source deep learning training/inference framework that
could be used for mobile, edge and cloud scenarios. MindSpore is designed to
provide development experience with friendly design and efficient execution for
the data scientists and algorithmic engineers, native support for Ascend AI
processor, and software hardware co-optimization. At the meantime MindSpore as
a global AI open source community, aims to further advance the development and
enrichment of the AI software/hardware application ecosystem.
<img src="docs/MindSpore-architecture.png" alt="MindSpore Architecture" width="600"/>
For more details please check out our [Architecture Guide](https://www.mindspore.cn/docs/en/0.1.0-alpha/architecture.html).
### Automatic Differentiation
There are currently three automatic differentiation techniques in mainstream deep learning frameworks:
- **Conversion based on static compute graph**: Convert the network into a static data flow graph at compile time, then turn the chain rule into a data flow graph to implement automatic differentiation.
- **Conversion based on dynamic compute graph**: Record the operation trajectory of the network during forward execution in an operator overloaded manner, then apply the chain rule to the dynamically generated data flow graph to implement automatic differentiation.
- **Conversion based on source code**: This technology is evolving from the functional programming framework and performs automatic differential transformation on the intermediate expression (the expression form of the program during the compilation process) in the form of just-in-time compilation (JIT), supporting complex control flow scenarios, higher-order functions and closures.
TensorFlow adopted static calculation diagrams in the early days, whereas PyTorch used dynamic calculation diagrams. Static maps can utilize static compilation technology to optimize network performance, however, building a network or debugging it is very complicated. The use of dynamic graphics is very convenient, but it is difficult to achieve extreme optimization in performance.
But MindSpore finds another way, automatic differentiation based on source code conversion. On the one hand, it supports automatic differentiation of automatic control flow, so it is quite convenient to build models like PyTorch. On the other hand, MindSpore can perform static compilation optimization on neural networks to achieve great performance.
<img src="docs/Automatic-differentiation.png" alt="Automatic Differentiation" width="600"/>
The implementation of MindSpore automatic differentiation can be understood as the symbolic differentiation of the program itself. Because MindSpore IR is a functional intermediate expression, it has an intuitive correspondence with the composite function in basic algebra. The derivation formula of the composite function composed of arbitrary basic functions can be derived. Each primitive operation in MindSpore IR can correspond to the basic functions in basic algebra, which can build more complex flow control.
### Automatic Parallel
The goal of MindSpore automatic parallel is to build a training method that combines data parallelism, model parallelism, and hybrid parallelism. It can automatically select a least cost model splitting strategy to achieve automatic distributed parallel training.
<img src="docs/Automatic-parallel.png" alt="Automatic Parallel" width="600"/>
At present, MindSpore uses a fine-grained parallel strategy of splitting operators, that is, each operator in the figure is splited into a cluster to complete parallel operations. The splitting strategy during this period may be very complicated, but as a developer advocating Pythonic, you don't need to care about the underlying implementation, as long as the top-level API compute is efficient.
## Installation
### Binaries
MindSpore offers build options across multiple backends:
| Hardware Platform | Operating System | Status |
| :---------------- | :--------------- | :----- |
| Ascend910 | Ubuntu-x86 | ✔️ |
| | EulerOS-x86 | ✔️ |
| | EulerOS-aarch64 | ✔️ |
| GPU CUDA 9.2 | Ubuntu-x86 | ✔️ |
| GPU CUDA 10.1 | Ubuntu-x86 | ✔️ |
| CPU | Ubuntu-x86 | ✔️ |
For installation using pip, take `Ubuntu-x86` and `CPU` build version as an example:
1. Download whl from [MindSpore website](https://www.mindspore.cn/), and install the package.
```
pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/0.1.0-alpha/MindSpore/cpu/ubuntu-x86/mindspore-0.1.0-cp37-cp37m-linux_x86_64.whl
```
2. Run the following command to verify the install.
```
python -c 'import mindspore'
```
### From Source
[Install MindSpore](https://www.mindspore.cn/install/en).
## Quickstart
See the [Quick Start](https://www.mindspore.cn/tutorial/en/0.1.0-alpha/quick_start/quick_start.html)
to implement the image classification.
## Docs
More details about installation guide, tutorials and APIs, please see the
[User Documentation](https://gitee.com/mindspore/docs).
## Community
### Governance
Check out how MindSpore Open Governance [works](https://gitee.com/mindspore/community/blob/master/governance.md).
### Communication
- [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/enQtOTcwMTIxMDI3NjM0LTNkMWM2MzI5NjIyZWU5ZWQ5M2EwMTQ5MWNiYzMxOGM4OWFhZjI4M2E5OGI2YTg3ODU1ODE2Njg1MThiNWI3YmQ) - Communication platform for developers.
- IRC channel at `#mindspore` (only for meeting minutes logging purpose)
- Video Conferencing: meet.jit.si
- Mailing-list: https://mailweb.mindspore.cn/postorius/lists
## Contributing
Welcome contributions. See our [Contributor Wiki](CONTRIBUTING.md) for
more details.
## Release Notes
The release notes, see our [RELEASE](RELEASE.md).
## License
[Apache License 2.0](LICENSE)

73
RELEASE.md Normal file
View File

@ -0,0 +1,73 @@
# Release 0.1.0-alpha
## Main Features
### Ascend 910 Training and Inference Framework
* Recommended OS: Ubuntu 16.04 (or later) or EulerOS 2.5 or EulerOS 2.8
* Python version: 3.7.5
* Preset models
* ResNet-50: residual structure-based convolutional neural network (CNN) for image classification, which is widely used.
* AlexNet: classic CNN for image classification, achieving historical results in ImageNet LSVRC-2012.
* LeNet: classic CNN for image classification, which was proposed by Yann LeCun.
* VGG16: classic CNN for image classification, which was proposed by Oxford Visual Geometry Group.
* YoloV3: real-time object detection network.
* NEZHA: BERT-based Chinese pre-training network produced by Huawei Noah's Ark Laboratory.
* Execution modes
* Graph mode: provides graph optimization methods such as memory overcommitment, IR fusion, and buffer fusion to achieve optimal execution performance.
* PyNative mode: single-step execution mode, facilitating process debugging.
* Debugging capability and methods
* Save CheckPoints and Summary data during training.
* Support asynchronous printing.
* Dump the computing data.
* Support profiling analysis of the execution process performance.
* Distributed execution
* Support AllReduce, AllGather, and BroadCast collective communication.
* AllReduce data parallel: Each device obtains different training data, which accelerates the overall training process.
* Collective communication-based layerwise parallel: Models are divided and allocated to different devices to solve the problem of insufficient memory for large model processing and improve the training speed.
* Automatic parallel mode: The better data and model parallel mode can be predicted based on the cost model. It is recommended that this mode be used on ResNet series networks.
* Automatic differentiation
* Implement automatic differentiation based on Source to Source.
* Support distributed scenarios and automatic insertion of reverse communication operators.
* Data processing, augmentation, and save format
* Load common datasets such as ImageNet, MNIST, CIFAR-10, and CIFAR-100.
* Support common data loading pipeline operations, such as shuffle, repeat, batch, map, and sampler.
* Provide basic operator libraries to cover common CV scenarios.
* Support users to customize Python data augmentation operators through the Pyfunc mechanism.
* Support the access of user-defined datasets through the GeneratorDataset mechanism.
* Provide the MindSpore data format, data aggregation and storage, random access example, data partition, efficient parallel read, user-defined index, and dataset search.
* Convert user datasets to the MindSpore data format.
* After data processing and augmentation, provide training applications in feed and graph modes.
* FP32/16 mixed precision computation, supporting automatic and manual configuration
* Provide common operators such as nn, math, and array, which can be customized.
### Inference Deployment
* Deploy models in MindSpore format on the Ascend 310 platform for inference.
* Save models in ONNX format.
* Support saving models in LITE format and running models based on the lightweight inference framework.
* Recommended OS: Android 4.3 or later
* Supported network type: LeNet
* Provide the generalization operators generated by TVM and operators generated after specific networks are tuned.
### Other Hardware Support
* GPU platform training
* Recommended OS: Ubuntu 16.04
* CUDA version: 9.2 or 10.1
* CuDNN version: 7.6 or later
* Python version: 3.7.5
* NCCL version: 2.4.8-1
* OpenMPI version: 3.1.5
* Supported models: AlexNet, LeNet, and LSTM
* Supported datasets: MNIST and CIFAR-10
* Support data parallel.
* CPU platform training
* Recommended OS: Ubuntu 16.04
* Python version: 3.7.5
* Supported model: LeNet
* Supported dataset: MNIST
* Provide only the stand-alone operation version.
## Peripherals and Tools
* [MindSpore Official Website] (https://www.mindspore.cn/)
* [MindInsight Visualization Debugging and Optimization] (https://gitee.com/mindspore/mindinsight)
* [MindArmour Model Security Hardening Package] (https://gitee.com/mindspore/mindarmour)
* [GraphEngine Computational Graph Engine] (https://gitee.com/mindspore/graphengine)

14
SECURITY.md Normal file
View File

@ -0,0 +1,14 @@
# Security Risk Description
1. When MindSpore is used for AI model training, if the user-defined computational graph structure (for example, Python code for generating the MindSpore computational graph) is provided by an untrusted third party, malicious code may exist and will be loaded and executed to attack the system.
2. Model files are stored in binary mode. When MindSpore is used to optimize or infer AI models and the model files are loaded in deserialization mode, once malicious code is written into the model files, the code are loaded and executed, causing attacks on the system.
3. MindSpore performs only model training and inference based on the data provided by users. Users need to protect data security to avoid privacy leakage.
4. MindSpore is a distributed training platform. When MindSpore is used for distributed training, if an Ascend chip is used for training, a device provides a secure transmission protocol for gradient fusion. If GPUs or other clusters are used for training, identity authentication and secure transmission are not provided.
# Security Usage Suggestions
1. Run MindSpore in the sandbox.
2. Run MindSpore as a non-root user.
3. Ensure that the source of a computational graph structure is trustworthy. Do not write code irrelevant to model training in the network structure definition.
4. Ensure that the source of a network model is trustworthy or enter secure network model parameters to prevent model parameters from being tampered with.
5. Ensure that GPU distributed training is performed on an isolated cluster network.

File diff suppressed because it is too large Load Diff

19
autogen.sh Executable file
View File

@ -0,0 +1,19 @@
#!/bin/bash
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
git submodule update --init --recursive

468
build.sh Executable file
View File

@ -0,0 +1,468 @@
#!/bin/bash
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
set -e
BASEPATH=$(cd "$(dirname $0)"; pwd)
PROJECT_PATH="${BASEPATH}"
CUDA_PATH=""
CUDNN_PATH=""
export BUILD_PATH="${BASEPATH}/build/"
# print usage message
usage()
{
echo "Usage:"
echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-s] [-b ge|cpu] [-m infer|train] \\"
echo " [-a on|off] [-g on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\"
echo " [-P on|off] [-z] [-M on|off] [-V 9.2|10.1] [-I] [-K]"
echo ""
echo "Options:"
echo " -d Debug mode"
echo " -r Release mode, default mode"
echo " -v Display build command"
echo " -c Enable code coverage switch, default off"
echo " -t Run testcases switch, default on"
echo " -g Use glog to output log, default on"
echo " -h Print usage"
echo " -s Install or setup"
echo " -b Select other backend, available: \\"
echo " ge:graph engine, cpu"
echo " -m Select mode, available: infer, train, default is infer "
echo " -a Enable ASAN, default off"
echo " -p Enable pipeline profile, default off"
echo " -i Enable increment building, default off"
echo " -L Enable load ANF-IR as input of 'infer', default off"
echo " -R Enable the time_line record, default off"
echo " -j[n] Set the threads when building (Default: -j8)"
echo " -e Use gpu, d or cpu"
echo " -P Enable dump anf graph to file in ProtoBuffer format, default on"
echo " -Q Enable dump end to end, default off"
echo " -D Enable dumping of function graph ir, default on"
echo " -z Compile dataset & mindrecord, default off"
echo " -M Enable MPI and NCCL for GPU training, default off"
echo " -V Specify the minimum required cuda version, default CUDA 9.2"
echo " -I Compile predict, default off"
echo " -K Compile with AKG, default off"
}
# check value of input is 'on' or 'off'
# usage: check_on_off arg_value arg_name
check_on_off()
{
if [[ "X$1" != "Xon" && "X$1" != "Xoff" ]]; then
echo "Invalid value $1 for option -$2"
usage
exit 1
fi
}
# check and set options
checkopts()
{
# Init default values of build options
THREAD_NUM=8
DEBUG_MODE="off"
VERBOSE=""
ENABLE_COVERAGE="off"
RUN_TESTCASES="off"
EXECUTE_SETUP="off"
ENABLE_BACKEND=""
TRAIN_MODE="INFER"
ENABLE_ASAN="off"
ENABLE_PROFILE="off"
INC_BUILD="off"
ENABLE_LOAD_IR="off"
ENABLE_TIMELINE="off"
ENABLE_DUMP2PROTO="on"
ENABLE_DUMPE2E="off"
ENABLE_DUMP_IR="on"
COMPILE_MINDDATA="off"
ENABLE_MPI="off"
CUDA_VERSION="9.2"
COMPILE_PREDICT="off"
USE_GLOG="on"
PREDICT_PLATFORM=""
ENABLE_AKG="off"
# Process the options
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K' opt
do
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
case "${opt}" in
d)
DEBUG_MODE="on"
;;
r)
DEBUG_MODE="off"
;;
v)
VERBOSE="VERBOSE=1"
;;
j)
THREAD_NUM=$OPTARG
;;
c)
check_on_off $OPTARG c
ENABLE_COVERAGE="$OPTARG"
;;
t)
check_on_off $OPTARG t
RUN_TESTCASES="$OPTARG"
;;
g)
check_on_off $OPTARG g
USE_GLOG="$OPTARG"
;;
h)
usage
exit 0
;;
s)
EXECUTE_SETUP="on"
;;
b)
if [[ "X$OPTARG" != "Xge" && "X$OPTARG" != "Xcpu" ]]; then
echo "Invalid value ${OPTARG} for option -b"
usage
exit 1
fi
ENABLE_BACKEND=$(echo "$OPTARG" | tr '[a-z]' '[A-Z]')
if [[ "X$ENABLE_BACKEND" == "XGE" ]]; then
ENABLE_GE="on"
fi
if [[ "X$ENABLE_BACKEND" != "XCPU" ]]; then
ENABLE_CPU="on"
fi
;;
a)
check_on_off $OPTARG a
ENABLE_ASAN="$OPTARG"
;;
p)
check_on_off $OPTARG p
ENABLE_PROFILE="$OPTARG"
;;
i)
INC_BUILD="on"
;;
m)
if [[ "X$OPTARG" != "Xinfer" && "X$OPTARG" != "Xtrain" ]]; then
echo "Invalid value ${OPTARG} for option -m"
usage
exit 1
fi
TRAIN_MODE=$(echo "$OPTARG" | tr '[a-z]' '[A-Z]')
;;
L)
ENABLE_LOAD_IR="on"
echo "build with enable load anf ir"
;;
R)
ENABLE_TIMELINE="on"
echo "enable time_line record"
;;
e)
if [[ "X$OPTARG" == "Xgpu" ]]; then
ENABLE_GPU="on"
ENABLE_CPU="on"
elif [[ "X$OPTARG" == "Xd" ]]; then
ENABLE_D="on"
ENABLE_CPU="on"
elif [[ "X$OPTARG" == "Xcpu" ]]; then
ENABLE_CPU="on"
else
echo "Invalid value ${OPTARG} for option -e"
usage
exit 1
fi
;;
M)
check_on_off $OPTARG M
ENABLE_MPI="$OPTARG"
;;
V)
if [[ "X$OPTARG" != "X9.2" && "X$OPTARG" != "X10.1" ]]; then
echo "Invalid value ${OPTARG} for option -V"
usage
exit 1
fi
CUDA_VERSION="$OPTARG"
;;
P)
check_on_off $OPTARG p
ENABLE_DUMP2PROTO="$OPTARG"
echo "enable dump anf graph to proto file"
;;
Q)
check_on_off $OPTARG Q
ENABLE_DUMPE2E="$OPTARG"
echo "enable dump end to end"
;;
D)
check_on_off $OPTARG D
ENABLE_DUMP_IR="$OPTARG"
echo "enable dump function graph ir"
;;
z)
COMPILE_MINDDATA="on"
;;
I)
COMPILE_PREDICT="on"
if [[ "$OPTARG" == "arm64" ]]; then
PREDICT_PLATFORM="arm64"
elif [[ "$OPTARG" == "x86_64" ]]; then
PREDICT_PLATFORM="x86_64"
else
echo "-I parameter must be arm64 or x86_64"
exit 1
fi
;;
K)
ENABLE_AKG="on"
echo "enable compile with akg"
;;
*)
echo "Unknown option ${opt}!"
usage
exit 1
esac
done
}
checkopts "$@"
echo "---------------- mindspore: build start ----------------"
mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
git submodule update --init graphengine
build_exit()
{
echo "$@" >&2
stty echo
exit 1
}
# Create building path
build_mindspore()
{
echo "start build mindspore project."
mkdir -pv "${BUILD_PATH}/mindspore"
cd "${BUILD_PATH}/mindspore"
CMAKE_ARGS="-DDEBUG_MODE=$DEBUG_MODE -DBUILD_PATH=$BUILD_PATH"
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_LOAD_ANF_IR=$ENABLE_LOAD_IR"
if [[ "X$ENABLE_COVERAGE" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_COVERAGE=ON"
fi
if [[ "X$RUN_TESTCASES" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_TESTCASES=ON"
fi
if [[ -n "$ENABLE_BACKEND" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_${ENABLE_BACKEND}=ON"
fi
if [[ -n "$TRAIN_MODE" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_${TRAIN_MODE}=ON"
fi
if [[ "X$ENABLE_ASAN" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_ASAN=ON"
fi
if [[ "X$ENABLE_PROFILE" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_PROFILE=ON"
fi
if [[ "X$ENABLE_TIMELINE" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_TIMELINE=ON"
fi
if [[ "X$ENABLE_DUMP2PROTO" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DUMP_PROTO=ON"
fi
if [[ "X$ENABLE_DUMPE2E" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DUMP_E2E=ON"
fi
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DUMP_IR=${ENABLE_DUMP_IR^^}"
if [[ "X$ENABLE_MPI" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_MPI=ON"
fi
if [[ "X$ENABLE_D" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_D=ON"
fi
if [[ "X$ENABLE_GPU" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GPU=ON -DCUDA_PATH=$CUDA_PATH -DCUDNN_PATH=$CUDNN_PATH -DMS_REQUIRE_CUDA_VERSION=${CUDA_VERSION}"
fi
if [[ "X$ENABLE_CPU" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_CPU=ON"
fi
if [[ "X$COMPILE_MINDDATA" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_MINDDATA=ON"
fi
if [[ "X$USE_GLOG" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON"
fi
if [[ "X$ENABLE_AKG" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
fi
echo "${CMAKE_ARGS}"
if [[ "X$INC_BUILD" = "Xoff" ]]; then
cmake ${CMAKE_ARGS} ../..
fi
make ${VERBOSE} -j$THREAD_NUM
if [[ "X$EXECUTE_SETUP" = "Xon" ]]; then
make install
fi
echo "success to build mindspore project!"
}
build_predict()
{
git submodule update --init --recursive third_party/incubator-tvm
echo "start build predict project"
git submodule update --init --recursive third_party/flatbuffers
git submodule update --init --recursive third_party/googletest
git submodule update --init --recursive third_party/protobuf
rm -rf "${BASEPATH}/predict/build"
mkdir -pv "${BASEPATH}/predict/build"
rm -rf "${BASEPATH}/predict/output"
mkdir -pv "${BASEPATH}/predict/output"
if [[ "$PREDICT_PLATFORM" == "arm64" ]]; then
if [ "${ANDROID_NDK}" ]; then
echo -e "\e[31mANDROID_NDK_PATH=$ANDROID_NDK \e[0m"
else
echo -e "\e[31mplease set ANDROID_NDK_PATH in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r16b/ \e[0m"
exit 1
fi
fi
#build flatbuf
cd "${BASEPATH}/third_party/flatbuffers"
rm -rf build && mkdir -p build && cd build && cmake .. && make -j$THREAD_NUM
FLATC="${BASEPATH}"/third_party/flatbuffers/build/flatc
cd "${BASEPATH}"/predict/schema && mkdir -p "${BASEPATH}"/predict/schema/inner
find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b
find . -name "*.fbs" -print0 | xargs -0 "${FLATC}" -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o ${BASEPATH}/predict/schema/inner
# check LLVM_PATH
if [ "${LLVM_PATH}" == "" ]; then
echo "Please set LLVM_PATH in env for example export LLVM_PATH=/xxxx/bin/llvm-config"
exit
fi
#build tvm
tvm_open_source="${BASEPATH}/third_party/incubator-tvm"
tvm_kernel_build="${BASEPATH}/predict/module/tvm_kernel"
if [ ! -f "${tvm_kernel_build}"/incubator-tvm/build/libtvm.so ]; then
rm -fr "${tvm_kernel_build}"/incubator-tvm
cp -fr "${tvm_open_source}" "${tvm_kernel_build}"
mkdir -p "${tvm_kernel_build}"/incubator-tvm/build
patch -d "${tvm_kernel_build}"/incubator-tvm -p1 < "${BASEPATH}"/third_party/patch/predict/0001-RetBugFix-CustomRuntime_v06.patch
cp "${tvm_kernel_build}"/lite/src/codegen/llvm/lite_rtfunc_reset.cc "${tvm_kernel_build}"/incubator-tvm/src/codegen/llvm/
cp "${tvm_open_source}"/cmake/config.cmake "${tvm_kernel_build}"/incubator-tvm
if [ "${LLVM_PATH}" ]; then
sed -i "s#set(USE_LLVM .*)#set(USE_LLVM \"${LLVM_PATH}\")#g" "${tvm_kernel_build}"/incubator-tvm/config.cmake
else
echo "need set LLVM_PATH in env for example export LLVM_PATH=/xxxx/bin/llvm-config"
fi
cd "${tvm_kernel_build}"/incubator-tvm/build
cmake ..
make -j$THREAD_NUM
else
cd "${tvm_kernel_build}"/incubator-tvm/build
make -j$THREAD_NUM
fi
#gen op
predict_tvm_op_lib_path="${BASEPATH}/predict/module/tvm_kernel/build/lib_x86"
predict_platform="x86"
if [[ "$PREDICT_PLATFORM" == "arm64" ]]; then
predict_tvm_op_lib_path="${BASEPATH}/predict/module/tvm_kernel/build/lib_arm64"
predict_platform="arm64"
fi
need_get_libs=true
if [ -d "${predict_tvm_op_lib_path}" ]; then
file_list=$(ls "${predict_tvm_op_lib_path}")
if [ -n "${file_list}" ]; then
libstime=$(stat -c %Y "${predict_tvm_op_lib_path}"/* | sort -u | tail -n1)
pythontime=$(find "${BASEPATH}"/predict/module/tvm_kernel/lite/python/ -name "*.py" -exec stat -c %Y {} \; |
sort -u | tail -n1)
if [ "${libstime}" -ge "${pythontime}" ]; then
need_get_libs=false
else
rm -fr "${predict_tvm_op_lib_path}"
fi
fi
fi
if $need_get_libs; then
PYTHONPATH_OLD=${PYTHONPATH}
export PYTHONPATH="${tvm_kernel_build}/incubator-tvm/python:${tvm_kernel_build}/incubator-tvm/topi/python:${tvm_kernel_build}/incubator-tvm/nnvm/python:${tvm_kernel_build}/lite/python:"
cd "${BASEPATH}"/predict/module/tvm_kernel/lite/python/at_ops
python3 at_gen_strip.py ${predict_platform}
export PYTHONPATH=${PYTHONPATH_OLD}
fi
cd "${BASEPATH}/predict/build"
if [[ "$PREDICT_PLATFORM" == "arm64" ]]; then
cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
-DANDROID_NATIVE_API_LEVEL=android-19 -DANDROID_NDK="${ANDROID_NDK}" \
-DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" -DANDROID_STL="c++_shared" \
-DANDROID_ABI="arm64-v8a" -DENABLE_PREDICT_ARM64=ON -DANDROID_ALLOW_UNDEFINED_SYMBOLS=TRUE ..
elif [[ "$PREDICT_PLATFORM" == "x86_64" ]]; then
cmake ..
fi
make ${VERBOSE} -j$THREAD_NUM
if [[ "$PREDICT_PLATFORM" == "x86_64" ]]; then
cd "${BASEPATH}/predict/build/test" && ./run_tests.sh
fi
# copy securec include files
mkdir -p "${BASEPATH}/predict/output/include/securec/include"
cp "${BASEPATH}"/third_party/securec/include/* "${BASEPATH}"/predict/output/include/securec/include
cd "${BASEPATH}/predict/output/"
if [[ "$PREDICT_PLATFORM" == "x86_64" ]]; then
tar -cf MSPredict-0.1.0-linux_x86_64.tar.gz include/ lib/ --warning=no-file-changed
elif [[ "$PREDICT_PLATFORM" == "arm64" ]]; then
tar -cf MSPredict-0.1.0-linux_aarch64.tar.gz include/ lib/ --warning=no-file-changed
fi
echo "success to build predict project!"
}
if [[ "X$COMPILE_PREDICT" = "Xon" ]]; then
build_predict
echo "---------------- mindspore: build end ----------------"
exit
else
build_mindspore
fi
if [[ "X$INC_BUILD" = "Xoff" ]]; then
if [[ "X$ENABLE_GE" = "Xon" ]]; then
bash "${PROJECT_PATH}/package.sh" ge
elif [[ "X$ENABLE_GPU" = "Xon" ]]; then
bash "${PROJECT_PATH}/package.sh" ms gpu
elif [[ "X$ENABLE_D" = "Xon" ]] || [[ "X$ENABLE_CPU" = "Xon" ]]; then
bash "${PROJECT_PATH}/package.sh" ms
else
bash "${PROJECT_PATH}/package.sh" debug
fi
fi
cp -rf ${BUILD_PATH}/package/mindspore/lib ${BUILD_PATH}/../mindspore
cp -rf ${BUILD_PATH}/package/mindspore/*.so ${BUILD_PATH}/../mindspore
if [[ -d "${BUILD_PATH}/package/build" ]]; then
rm -rf "${BUILD_PATH}/package/build"
fi
echo "---------------- mindspore: build end ----------------"

View File

@ -0,0 +1,70 @@
message(STATUS "ENABLE_GE set to FALSE, compiling GraphEngine")
set(GE_SOURCE_DIR ${CMAKE_SOURCE_DIR}/graphengine)
message(STATUS "ge dir: ${GE_SOURCE_DIR}")
# download json headers, rather than whole repository
include(${GE_SOURCE_DIR}/cmake/ge_utils.cmake)
include(${GE_SOURCE_DIR}/cmake/external_libs/json.cmake)
include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake)
include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake)
# for CPU/GPU mode, find c_sec and slog from local prebuild
if (NOT ENABLE_D)
set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR})
find_library(c_sec libc_sec.so ${GE_PREBUILD_PATH})
find_library(slog libslog.so ${GE_PREBUILD_PATH})
elseif (DEFINED ENV{D_LINK_PATH})
set(GE_LIB_PATH $ENV{D_LINK_PATH})
set(GE_SYS_ARCH "")
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64")
# x86 ubuntu
set(GE_SYS_ARCH "x86_64")
elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64")
# arm euleros
set(GE_SYS_ARCH "aarch64")
else()
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
endif()
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
find_library(c_sec libc_sec.so ${GE_LIB_PATH})
find_library(slog libslog.so ${GE_LIB_PATH})
find_library(mmpa libmmpa.so ${GE_LIB_PATH})
find_library(runtime libruntime.so ${GE_LIB_PATH})
find_library(msprof libmsprof.so ${GE_LIB_PATH})
find_library(register libregister.so ${GE_LIB_PATH})
find_library(hccl libhccl.so ${GE_LIB_PATH})
find_library(cce libcce.so ${GE_LIB_PATH})
find_library(resource libresource.so ${GE_LIB_PATH})
else()
set(HIAI_INSTALLED_DIR /usr/local/HiAI)
set(HIAI_DRIVER_DIR ${HIAI_INSTALLED_DIR}/driver/lib64)
set(HIAI_RUNTIME_DIR ${HIAI_INSTALLED_DIR}/runtime/lib64)
find_library(c_sec libc_sec.so ${HIAI_DRIVER_DIR})
find_library(slog libslog.so ${HIAI_DRIVER_DIR})
find_library(mmpa libmmpa.so ${HIAI_DRIVER_DIR})
find_library(cce libcce.so ${HIAI_RUNTIME_DIR})
find_library(hccl libhccl.so ${HIAI_RUNTIME_DIR})
find_library(runtime libruntime.so ${HIAI_RUNTIME_DIR})
find_library(msprof libmsprof.so ${HIAI_RUNTIME_DIR})
find_library(register libregister.so ${HIAI_RUNTIME_DIR})
find_library(resource libresource.so ${HIAI_RUNTIME_DIR})
endif()
# compile libraries from following directories
# this cmake file is called only when NOT ENABLE_GE is set
set(_ge_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
# force __FILE__ to show relative path of file, from source directory
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst ${CMAKE_SOURCE_DIR}/,,$(abspath $<))\"' -Wno-builtin-macro-redefined")
add_subdirectory(${GE_SOURCE_DIR}/src/common/graph)
if(ENABLE_D)
add_subdirectory(${GE_SOURCE_DIR}/src/ge/common)
add_subdirectory(${GE_SOURCE_DIR}/src/ge/ge_runtime)
endif()
set(CMAKE_CXX_FLAGS ${_ge_tmp_CMAKE_CXX_FLAGS})

View File

@ -0,0 +1,36 @@
# googletest library
#
#
# GTest_LIBRARY
#
if (NOT TARGET gtest)
set(BUILD_TESTING OFF CACHE BOOL "Disable glog test")
set(_ms_tmp_CMAKE_POSITION_INDEPENDENT_CODE ${CMAKE_POSITION_INDEPENDENT_CODE})
set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(_ms_tmp_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
set(_ms_tmp_CMAKE_MACOSX_RPATH ${CMAKE_MACOSX_RPATH})
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(BUILD_SHARED_LIBS ON)
set(CMAKE_MACOSX_RPATH TRUE)
set(CMAKE_CXX_FLAGS "${SECURE_CXX_FLAGS}")
if (CMAKE_COMPILER_IS_GNUCXX AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "5.0" AND CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64" AND SYSTEM_TYPE MATCHES "euleros")
# -D_GLIBCXX_USE_CXX11_ABI=0 added for the ABI incompatible for libtsdclient.so
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/googletest ${CMAKE_BINARY_DIR}/googletest)
set(CMAKE_POSITION_INDEPENDENT_CODE ${_ms_tmp_CMAKE_POSITION_INDEPENDENT_CODE})
set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS})
set(BUILD_SHARED_LIBS ${_ms_tmp_BUILD_SHARED_LIBS})
set(CMAKE_MACOSX_RPATH ${_ms_tmp_CMAKE_MACOSX_RPATH})
endif()
include_directories(${CMAKE_CURRENT_LIST_DIR}/../third_party/googletest/googletest/include)
set(GTEST_LIBRARY gtest)

View File

@ -0,0 +1,100 @@
# Protobuf
#
#
# PROTOBUF_LIBRARY - Link this to use protobuf
#
if (NOT TARGET protobuf::libprotobuf)
set(protobuf_BUILD_TESTS OFF CACHE BOOL "Disable protobuf test")
set(protobuf_BUILD_SHARED_LIBS OFF CACHE BOOL "Gen shared library")
set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/protobuf/cmake ${CMAKE_BINARY_DIR}/protobuf)
set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS})
endif ()
include_directories(${CMAKE_CURRENT_LIST_DIR}/../third_party/protobuf/src)
set(PROTOBUF_LIBRARY protobuf::libprotobuf)
set(PROTOC_EXECUTABLE $<TARGET_FILE:protobuf::protoc>)
function(ms_protobuf_generate c_var h_var)
if(NOT ARGN)
message(SEND_ERROR "Error: ms_protobuf_generate() called without any proto files")
return()
endif()
set(${c_var})
set(${h_var})
foreach(file ${ARGN})
get_filename_component(abs_file ${file} ABSOLUTE)
get_filename_component(file_name ${file} NAME_WE)
get_filename_component(file_dir ${abs_file} PATH)
file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir})
list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc")
list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h")
add_custom_command(
OUTPUT "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc"
"${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h"
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/${rel_path}"
COMMAND protobuf::protoc -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
DEPENDS protobuf::protoc ${abs_file}
COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM )
endforeach()
set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE)
set(${c_var} ${${c_var}} PARENT_SCOPE)
set(${h_var} ${${h_var}} PARENT_SCOPE)
endfunction()
function(ms_protobuf_generate_py c_var h_var py_var)
if(NOT ARGN)
message(SEND_ERROR "Error: ms_protobuf_generate() called without any proto files")
return()
endif()
set(${c_var})
set(${h_var})
set(${py_var})
foreach(file ${ARGN})
get_filename_component(abs_file ${file} ABSOLUTE)
get_filename_component(file_name ${file} NAME_WE)
get_filename_component(file_dir ${abs_file} PATH)
file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir})
list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc")
list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h")
list(APPEND ${py_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py")
add_custom_command(
OUTPUT "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc"
"${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h"
"${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py"
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/${rel_path}"
COMMAND protobuf::protoc -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
COMMAND protobuf::protoc -I${file_dir} --python_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
COMMAND protobuf::protoc -I${file_dir} --python_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
COMMAND perl -pi -e "s/import (.+_pb2.*)/from . import \\1/" "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py"
COMMAND cp "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py" "${PROJECT_SOURCE_DIR}/mindspore/train/"
DEPENDS protobuf::protoc ${abs_file}
COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM )
endforeach()
set_source_files_properties(${${c_var}} ${${h_var}} ${${py_var}} PROPERTIES GENERATED TRUE)
set(${c_var} ${${c_var}} PARENT_SCOPE)
set(${h_var} ${${h_var}} PARENT_SCOPE)
set(${py_var} ${${py_var}} PARENT_SCOPE)
endfunction()

View File

@ -0,0 +1,19 @@
# securec library
#
#
# SECUREC_LIBRARY
#
if (NOT TARGET securec)
set(_ms_tmp_CMAKE_POSITION_INDEPENDENT_CODE ${CMAKE_POSITION_INDEPENDENT_CODE})
set(_ms_tmp_CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
set(CMAKE_C_FLAGS "${SECURE_CXX_FLAGS}")
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/securec ${CMAKE_BINARY_DIR}/securec)
set(CMAKE_POSITION_INDEPENDENT_CODE ${_ms_tmp_CMAKE_POSITION_INDEPENDENT_CODE})
set(CMAKE_C_FLAGS ${_ms_tmp_CMAKE_C_FLAGS})
endif()
include_directories(${CMAKE_CURRENT_LIST_DIR}/../third_party/securec/include)
set(SECUREC_LIBRARY securec)

View File

@ -0,0 +1,25 @@
# MS Utils
#
function(find_python_package out_inc out_lib)
# Use PYTHON_EXECUTABLE if it is defined, otherwise default to python
if ("${PYTHON_EXECUTABLE}" STREQUAL "")
set(PYTHON_EXECUTABLE "python3")
else()
set(PYTHON_EXECUTABLE "${PYTHON_EXECUTABLE}")
endif()
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())"
RESULT_VARIABLE result
OUTPUT_VARIABLE inc)
string(STRIP "${inc}" inc)
set(${out_inc} ${inc} PARENT_SCOPE)
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import distutils.sysconfig as sysconfig; import os; print(os.path.join(sysconfig.get_config_var('LIBDIR'), sysconfig.get_config_var('LDLIBRARY')))"
RESULT_VARIABLE result
OUTPUT_VARIABLE lib)
string(STRIP "${lib}" lib)
set(${out_lib} ${lib} PARENT_SCOPE)
endfunction()

View File

@ -0,0 +1,5 @@
mindspore_add_pkg(dlpack
VER 0.2
HEAD_ONLY ./
URL https://github.com/dmlc/dlpack/archive/0acb731e0e43d15deee27b66f10e4c5b4e667913.zip
MD5 6b8093f17ad4e830d3c63eb3171c4b45)

View File

@ -0,0 +1,5 @@
mindspore_add_pkg(dmlc_core
VER 0.3
HEAD_ONLY ./
URL https://github.com/dmlc/dmlc-core/archive/808f485387f9a03f78fa9f1159f387d0d91b7a28.zip
MD5 ea36f94c57752bf40fb02dfc362f1ed9)

View File

@ -0,0 +1,12 @@
set(Eigen3_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(Eigen3_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(Eigen3
VER 3.3.7
URL https://gitlab.com/libeigen/eigen/-/archive/3.3.7/eigen-3.3.7.tar.gz
MD5 9e30f67e8531477de4117506fe44669b
CMAKE_OPTION -DBUILD_TESTING=OFF)
find_package(Eigen3 3.3.7 REQUIRED ${MS_FIND_NO_DEFAULT_PATH})
include_directories(${Eigen3_INC})
include_directories(${EIGEN3_INCLUDE_DIR})
set_property(TARGET Eigen3::Eigen PROPERTY IMPORTED_GLOBAL TRUE)
add_library(mindspore::eigen ALIAS Eigen3::Eigen)

View File

@ -0,0 +1,52 @@
set(flatbuffers_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(flatbuffers_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(flatbuffers
VER 1.11.0
LIBS flatbuffers
EXE flatc
URL https://github.com/google/flatbuffers/archive/v1.11.0.tar.gz
MD5 02c64880acb89dbd57eebacfd67200d8
CMAKE_OPTION -DFLATBUFFERS_BUILD_TESTS=OFF )
add_library(mindspore::flatbuffers ALIAS flatbuffers::flatbuffers)
add_executable(mindspore::flatc ALIAS flatbuffers::flatc)
include_directories(${flatbuffers_INC})
function(ms_build_flatbuffers source_schema_files
source_schema_dirs
custom_target_name
generated_output_dir)
set(total_schema_dirs "")
set(total_generated_files "")
set(FLATC mindspore::flatc)
foreach (schema_dir ${source_schema_dirs})
set(total_schema_dirs -I ${schema_dir} ${total_schema_dirs})
endforeach()
foreach(schema ${source_schema_files})
get_filename_component(filename ${schema} NAME_WE)
if (NOT ${generated_output_dir} STREQUAL "")
set(generated_file ${generated_output_dir}/${filename}_generated.h)
add_custom_command(
OUTPUT ${generated_file}
COMMAND ${FLATC} --gen-mutable
--reflect-names --gen-object-api -o ${generated_output_dir}
${total_schema_dirs}
-c -b --reflect-types ${schema}
DEPENDS ${FLATC} ${schema}
WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
COMMENT "Running C++ flatbuffers compiler on ${schema}" VERBATIM)
list(APPEND total_generated_files ${generated_file})
endif()
endforeach()
add_custom_target(${custom_target_name} ALL
DEPENDS ${total_generated_files})
if (NOT ${generated_output_dir} STREQUAL "")
include_directories(${generated_output_dir})
set_property(TARGET ${custom_target_name}
PROPERTY GENERATED_OUTPUT_DIR
${generated_output_dir})
endif()
endfunction()

View File

@ -0,0 +1,10 @@
set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS}")
set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(glog
VER 0.4.0
LIBS glog
URL https://github.com/google/glog/archive/v0.4.0.tar.gz
MD5 0daea8785e6df922d7887755c3d100d0
CMAKE_OPTION -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON)
include_directories(${glog_INC})
add_library(mindspore::glog ALIAS glog::glog)

View File

@ -0,0 +1,13 @@
set(gtest_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(gtest_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(gtest
VER 1.8.0
LIBS gtest
URL https://github.com/google/googletest/archive/release-1.8.0.tar.gz
MD5 16877098823401d1bf2ed7891d7dce36
CMAKE_OPTION -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON
-DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON)
include_directories(${gtest_INC})
add_library(mindspore::gtest ALIAS gtest::gtest)
file(COPY ${gtest_LIBPATH}/libgtest.so DESTINATION ${CMAKE_BINARY_DIR}/googletest/googlemock/gtest)
file(COPY ${gtest_LIBPATH}/libgtest_main.so DESTINATION ${CMAKE_BINARY_DIR}/googletest/googlemock/gtest)

View File

@ -0,0 +1,13 @@
set(jpeg_turbo_USE_STATIC_LIBS ON)
set(jpeg_turbo_CFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -D_FORTIFY_SOURCE=2 -O2")
set(jpeg_turbo_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
mindspore_add_pkg(jpeg_turbo
VER 2.0.4
LIBS jpeg
URL https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.4.tar.gz
MD5 44c43e4a9fb352f47090804529317c88
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DCMAKE_SKIP_RPATH=TRUE
)
include_directories(${jpeg_turbo_INC})
add_library(mindspore::jpeg_turbo ALIAS jpeg_turbo::jpeg)

View File

@ -0,0 +1,9 @@
set(nlohmann_json_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(nlohmann_json_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(nlohmann_json
VER 3.6.1
HEAD_ONLY ./
URL https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip
MD5 0dc903888211db3a0f170304cd9f3a89)
include_directories(${nlohmann_json_INC})
add_library(mindspore::json ALIAS nlohmann_json)

View File

@ -0,0 +1,16 @@
set(tiff_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -Wno-unused-result \
-Wno-unused-but-set-variable -fPIC -D_FORTIFY_SOURCE=2 -O2")
set(tiff_CFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -Wno-unused-result \
-Wno-unused-but-set-variable -fPIC -D_FORTIFY_SOURCE=2 -O2")
set(tiff_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
mindspore_add_pkg(tiff
VER 4.1.0
LIBS tiff
URL https://gitlab.com/libtiff/libtiff/-/archive/v4.1.0/libtiff-v4.1.0.tar.gz
MD5 21de8d35c1b21ac82663fa9f56d3350d
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -Djbig=OFF -Dlzma=OFF -Djpeg12=OFF -Dzstd=OFF -Dpixarlog=OFF
-Dold-jpeg=OFF -Dwebp=OFF -DBUILD_SHARED_LIBS=OFF)
message("tiff include = ${tiff_INC}")
message("tiff lib = ${tiff_LIB}")

View File

@ -0,0 +1,11 @@
set(mkl_dnn_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(mkl_dnn_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(mkl_dnn
VER 1.1.1
LIBS dnnl mkldnn
URL https://github.com/intel/mkl-dnn/archive/v1.1.1.tar.gz
MD5 d6a422b00459600bdc22242590953f38
CMAKE_OPTION -DDNNL_ARCH_OPT_FLAGS='' -DDNNL_CPU_RUNTIME='SEQ' -DDNNL_BUILD_EXAMPLES=OFF -DDNNL_BUILD_TESTS=OFF)
include_directories(${mkl_dnn_INC})
add_library(mindspore::dnnl ALIAS mkl_dnn::dnnl)
add_library(mindspore::mkldnn ALIAS mkl_dnn::mkldnn)

View File

@ -0,0 +1,12 @@
set(nccl_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(nccl
VER 2.4.8-1
LIBS nccl
URL https://github.com/NVIDIA/nccl/archive/v2.4.8-1.tar.gz
MD5 f14b37d6af1c79db5f57cb029a753727
BUILD_OPTION src.build NVCC_GENCODE="-gencode=arch=compute_70,code=sm_70"
INSTALL_INCS build/include/*
INSTALL_LIBS build/lib/*)
include_directories(${nccl_INC})
add_library(mindspore::nccl ALIAS nccl::nccl)

View File

@ -0,0 +1,11 @@
set(ompi_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(ompi
VER 3.1.5
LIBS mpi
URL https://github.com/open-mpi/ompi/archive/v3.1.5.tar.gz
MD5 f7f220b26532c11a2efbc0bb73af3282
PRE_CONFIGURE_COMMAND ./autogen.pl
CONFIGURE_COMMAND ./configure)
include_directories(${ompi_INC})
add_library(mindspore::ompi ALIAS ompi::mpi)

View File

@ -0,0 +1,5 @@
mindspore_add_pkg(ms_onnx
VER 1.6.0
HEAD_ONLY ./
URL https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz
MD5 512f2779d6215d4a36f366b6b9acdf1e)

View File

@ -0,0 +1,31 @@
set(opencv_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2")
set(opencv_CFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2")
set(opencv_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
mindspore_add_pkg(opencv
VER 4.2.0
LIBS opencv_core opencv_imgcodecs opencv_imgproc
URL https://github.com/opencv/opencv/archive/4.2.0.tar.gz
MD5 e8cb208ce2723481408b604b480183b6
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DWITH_PROTOBUF=OFF -DWITH_WEBP=OFF -DWITH_IPP=OFF -DWITH_ADE=OFF
-DBUILD_ZLIB=ON
-DBUILD_JPEG=ON
-DBUILD_PNG=ON
-DBUILD_OPENEXR=ON
-DBUILD_TESTS=OFF
-DBUILD_PERF_TESTS=OFF
-DBUILD_opencv_apps=OFF
-DCMAKE_SKIP_RPATH=TRUE
-DBUILD_opencv_python3=OFF
-DWITH_FFMPEG=OFF
-DWITH_TIFF=ON
-DBUILD_TIFF=OFF
-DWITH_JASPER=OFF
-DBUILD_JASPER=OFF
-DTIFF_INCLUDE_DIR=${tiff_INC}
-DTIFF_LIBRARY=${tiff_LIB})
include_directories(${opencv_INC}/opencv4)
add_library(mindspore::opencv_core ALIAS opencv::opencv_core)
add_library(mindspore::opencv_imgcodecs ALIAS opencv::opencv_imgcodecs)
add_library(mindspore::opencv_imgproc ALIAS opencv::opencv_imgproc)

View File

@ -0,0 +1,96 @@
mindspore_add_pkg(protobuf
VER 3.8.0
HEAD_ONLY ./
URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz
MD5 3d9e32700639618a4d2d342c99d4507a)
set(protobuf_BUILD_TESTS OFF CACHE BOOL "Disable protobuf test")
set(protobuf_BUILD_SHARED_LIBS OFF CACHE BOOL "Gen shared library")
set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
add_subdirectory(${protobuf_DIRPATH}/cmake ${protobuf_DIRPATH}/build)
set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS})
set(PROTOBUF_LIBRARY protobuf::libprotobuf)
include_directories(${protobuf_DIRPATH}/src)
add_library(mindspore::protobuf ALIAS libprotobuf)
function(ms_protobuf_generate c_var h_var)
if(NOT ARGN)
message(SEND_ERROR "Error: ms_protobuf_generate() called without any proto files")
return()
endif()
set(${c_var})
set(${h_var})
foreach(file ${ARGN})
get_filename_component(abs_file ${file} ABSOLUTE)
get_filename_component(file_name ${file} NAME_WE)
get_filename_component(file_dir ${abs_file} PATH)
file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir})
list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc")
list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h")
add_custom_command(
OUTPUT "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc"
"${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h"
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/${rel_path}"
COMMAND protobuf::protoc -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
DEPENDS protobuf::protoc ${abs_file}
COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM )
endforeach()
set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE)
set(${c_var} ${${c_var}} PARENT_SCOPE)
set(${h_var} ${${h_var}} PARENT_SCOPE)
endfunction()
function(ms_protobuf_generate_py c_var h_var py_var)
if(NOT ARGN)
message(SEND_ERROR "Error: ms_protobuf_generate() called without any proto files")
return()
endif()
set(${c_var})
set(${h_var})
set(${py_var})
foreach(file ${ARGN})
get_filename_component(abs_file ${file} ABSOLUTE)
get_filename_component(file_name ${file} NAME_WE)
get_filename_component(file_dir ${abs_file} PATH)
file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir})
list(APPEND ${c_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc")
list(APPEND ${h_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h")
list(APPEND ${py_var} "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py")
add_custom_command(
OUTPUT "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.cc"
"${CMAKE_BINARY_DIR}/${rel_path}/${file_name}.pb.h"
"${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py"
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/${rel_path}"
COMMAND protobuf::protoc -I${file_dir} --cpp_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
COMMAND protobuf::protoc -I${file_dir} --python_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
COMMAND protobuf::protoc -I${file_dir} --python_out=${CMAKE_BINARY_DIR}/${rel_path} ${abs_file}
COMMAND perl -pi -e "s/import (.+_pb2.*)/from . import \\1/" "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py"
COMMAND cp "${CMAKE_BINARY_DIR}/${rel_path}/${file_name}_pb2.py" "${PROJECT_SOURCE_DIR}/mindspore/train/"
DEPENDS protobuf::protoc ${abs_file}
COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM )
endforeach()
set_source_files_properties(${${c_var}} ${${h_var}} ${${py_var}} PROPERTIES GENERATED TRUE)
set(${c_var} ${${c_var}} PARENT_SCOPE)
set(${h_var} ${${h_var}} PARENT_SCOPE)
set(${py_var} ${${py_var}} PARENT_SCOPE)
endfunction()

View File

@ -0,0 +1,12 @@
set(pybind11_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(pybind11_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(pybind11
VER 2.4.3
URL https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz
MD5 62254c40f89925bb894be421fe4cdef2
CMAKE_OPTION -DPYBIND11_TEST=OFF -DPYBIND11_LTO_CXX_FLAGS=FALSE
)
include_directories(${pybind11_INC})
find_package(pybind11 REQUIRED)
set_property(TARGET pybind11::module PROPERTY IMPORTED_GLOBAL TRUE)
add_library(mindspore::pybind11_module ALIAS pybind11::module)

View File

@ -0,0 +1,6 @@
mindspore_add_pkg(rang
VER 3.1.0
HEAD_ONLY ./
URL https://github.com/agauniyal/rang/archive/cabe04d6d6b05356fa8f9741704924788f0dd762.zip
MD5 0c5c9b251fea9ee7ce32f188655be0ea)

View File

@ -0,0 +1,15 @@
set(sqlite_USE_STATIC_LIBS ON)
set(sqlite_CXXFLAGS)
set(sqlite_CFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -D_FORTIFY_SOURCE=2 -O2")
set(sqlite_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
mindspore_add_pkg(sqlite
VER 3.31.1
LIBS sqlite3
URL https://github.com/sqlite/sqlite/archive/version-3.31.1.tar.gz
MD5 5f4e7b4016c15f4fb5855615279819da
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.patch001
CONFIGURE_COMMAND ./configure --enable-shared=no --disable-tcl --disable-editline --enable-json1)
include_directories(${sqlite_INC})
add_library(mindspore::sqlite ALIAS sqlite::sqlite3)

View File

@ -0,0 +1,8 @@
set(incubator_tvm_gpu_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(incubator_tvm_gpu_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(incubator_tvm_gpu
VER 0.6.0
HEAD_ONLY ./
URL https://github.com/apache/incubator-tvm/archive/v0.6.0.tar.gz
MD5 9cbbd32545a776023acabbba270449fe)

View File

@ -0,0 +1,10 @@
set(incubator_tvm_predict_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(incubator_tvm_predict_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(incubator_tvm_predict
VER 0.6.0
HEAD_ONLY ./
URL https://github.com/apache/incubator-tvm/release/download/v0.6.0/apache-tvm-src-v0.6.0-incubating.tar.gz
MD5 2d77a005f0046d937b99c67de82f6438
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/predict/0001-RetBugFix-CustomRuntime_v06.patch)
include_directories(${incubator_tvm_predict_INC})
add_library(mindspore::incubator_tvm_predict ALIAS incubator_tvm_predict)

View File

@ -0,0 +1,59 @@
set(SECURE_CXX_FLAGS "")
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
set(SECURE_CXX_FLAGS "-fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack")
endif()
set(_ms_tmp_CMAKE_CXX_FLAGS_F ${CMAKE_CXX_FLAGS})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
include(cmake/utils.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/eigen.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/pybind11.cmake)
MESSAGE("go to link flatbuffers")
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/flatbuffers.cmake)
if(USE_GLOG)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/glog.cmake)
endif()
find_package(Python3)
include_directories(${Python3_INCLUDE_DIRS})
include_directories(${CMAKE_SOURCE_DIR}/third_party)
if (ENABLE_CPU)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/mkl_dnn.cmake)
endif()
if (ENABLE_GPU)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/dlpack.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/dmlc_core.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/rang.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/tvm_gpu.cmake)
endif()
if (ENABLE_MPI)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/nccl.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ompi.cmake)
endif()
if (ENABLE_GE)
include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include)
include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include/external)
include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include/external/graph)
else()
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc)
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/ops)
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/external)
include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/external/graph)
endif()
if (ENABLE_MINDDATA)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/jpeg_turbo.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/libtiff.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/opencv.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/sqlite.cmake)
endif()
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/onnx.cmake)
set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS_F})

109
cmake/options.cmake Normal file
View File

@ -0,0 +1,109 @@
option(ENABLE_D "Enable d" OFF)
option(ENABLE_GPU "Enable gpu" OFF)
option(ENABLE_CPU "Enable cpu" OFF)
option(ENABLE_GE "Enable graph engine as backend to execute" OFF)
option(ENABLE_MINDDATA "Enable minddata compile" OFF)
option(ENABLE_TRAIN "Enable ge train, default off(only infer)" OFF)
option(ENABLE_TESTCASES "Run testcases switch, default off" OFF)
option(DEBUG_MODE "Debug mode, default off" OFF)
option(ENABLE_ASAN "Enable Google Sanitizer to find memory bugs")
option(ENABLE_LOAD_ANF_IR "Enable load ANF-IR as input of 'infer' stage of pipeline" OFF)
option(ENABLE_COVERAGE "Enable code coverage report" OFF)
option(USE_GLOG "Use glog to output log" OFF)
option(ENABLE_PROFILE "Enable pipeline profile, default off" OFF)
option(ENABLE_TIMELINE "Enable time line record" OFF)
option(ENABLE_DUMP_PROTO "Enable dump anf graph to file in ProtoBuffer format, default on" ON)
option(ENABLE_DUMP_E2E "Enable dump e2e file, default on" OFF)
option(ENABLE_DUMP_IR "Enable dump funciton graph ir, default on" ON)
option(ENABLE_MPI "enable mpi" OFF)
option(ENABLE_AKG "enable akg" OFF)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack")
endif()
if (CMAKE_SYSTEM_NAME MATCHES "Darwin")
set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -Wsign-compare")
endif()
if (ENABLE_COVERAGE)
set(COVERAGE_COMPILER_FLAGS "-g --coverage -fprofile-arcs -ftest-coverage")
set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} ${COVERAGE_COMPILER_FLAGS}")
endif()
if (ENABLE_ASAN)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer -static-libasan -fsanitize=undefined")
else()
set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -fsanitize=address -fno-omit-frame-pointer -static-libsan -fsanitize=undefined")
endif()
endif()
if (DEBUG_MODE)
set(CMAKE_BUILD_TYPE "Debug")
else()
add_compile_definitions(MEM_REUSE_DEBUG)
set(CMAKE_BUILD_TYPE "Release")
endif()
if ((CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") OR (CMAKE_BUILD_TYPE STREQUAL Release))
set(PYBIND11_LTO_CXX_FLAGS FALSE)
endif()
if (NOT BUILD_PATH)
set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build")
endif()
if (ENABLE_GE OR ENABLE_D)
set(ENABLE_TDTQUE ON)
endif()
if (ENABLE_GPU)
set(ENABLE_GPUQUE ON)
endif()
if (ENABLE_GE)
add_compile_definitions(ENABLE_GE)
add_compile_definitions(CUSTOM_OP)
endif()
if(ENABLE_TRAIN)
add_compile_definitions(ENABLE_TRAIN=1)
else()
add_compile_definitions(ENABLE_TRAIN=0)
endif()
if(USE_GLOG)
add_compile_definitions(USE_GLOG)
endif()
if (ENABLE_PROFILE)
add_compile_definitions(ENABLE_PROFILE)
endif()
if (ENABLE_TIMELINE)
add_compile_definitions(ENABLE_TIMELINE)
endif()
if (ENABLE_LOAD_ANF_IR)
add_compile_definitions(ENABLE_LOAD_ANF_IR)
endif()
if (ENABLE_TESTCASES OR (NOT ENABLE_D AND NOT ENABLE_GE))
add_compile_definitions(NO_DLIB=1)
endif()
if(ENABLE_DUMP_IR)
add_compile_definitions(ENABLE_DUMP_IR)
endif(ENABLE_DUMP_IR)
if(ENABLE_MINDDATA)
add_compile_definitions(ENABLE_MINDDATA)
if (ENABLE_TDTQUE)
add_compile_definitions(ENABLE_TDTQUE)
endif()
endif()
if(ENABLE_DUMP_E2E)
add_compile_definitions(ENABLE_DUMP_E2E)
endif()

320
cmake/utils.cmake Normal file
View File

@ -0,0 +1,320 @@
include(FetchContent)
set(FETCHCONTENT_QUIET OFF)
function(mindspore_add_submodule_obj des_submodule_objs sub_dir submodule_name_obj)
add_subdirectory(${sub_dir})
if(NOT TARGET ${submodule_name_obj})
message(FATAL_ERROR "Can not find submodule '${submodule_name_obj}'. in ${CMAKE_CURRENT_LIST_FILE}")
endif()
if("$<TARGET_OBJECTS:${submodule_name_obj}>" IN_LIST ${des_submodule_objs})
message(FATAL_ERROR "submodule '${submodule_name_obj}' added more than once. in ${CMAKE_CURRENT_LIST_FILE}")
endif()
set(${des_submodule_objs} ${${des_submodule_objs}} $<TARGET_OBJECTS:${submodule_name_obj}> PARENT_SCOPE)
endfunction()
get_filename_component(_MS_LIB_CACHE ~/.mslib REALPATH)
if (NOT EXISTS ${_MS_LIB_CACHE})
file(MAKE_DIRECTORY ${_MS_LIB_CACHE})
endif ()
# set(FETCHCONTENT_BASE_DIR ${_MS_LIB_CACHE})
# set(CMAKE_PREFIX_PATH ${_MS_LIB_CACHE})
if (DEFINED ENV{MSLIBS_SERVER})
set(LOCAL_LIBS_SERVER $ENV{MSLIBS_SERVER})
message("LOCAL_LIBS_SERVER: ${LOCAL_LIBS_SERVER}")
endif ()
if(LOCAL_LIBS_SERVER)
if (NOT ENV{no_proxy})
set(ENV{no_proxy} "${LOCAL_LIBS_SERVER}")
else()
string(FIND $ENV{no_proxy} ${LOCAL_LIBS_SERVER} IP_POS)
if (${IP_POS} EQUAL -1)
set(ENV{no_proxy} "$ENV{no_proxy},${LOCAL_LIBS_SERVER}")
endif ()
endif ()
endif()
function(__download_pkg pkg_name pkg_url pkg_md5)
if(LOCAL_LIBS_SERVER)
get_filename_component(_URL_FILE_NAME ${pkg_url} NAME)
set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${_URL_FILE_NAME}" ${pkg_url})
endif()
FetchContent_Declare(
${pkg_name}
URL ${pkg_url}
URL_HASH MD5=${pkg_md5}
)
FetchContent_GetProperties(${pkg_name})
message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}")
if(NOT ${pkg_name}_POPULATED)
FetchContent_Populate(${pkg_name})
set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE)
endif()
endfunction()
function(__find_pkg_then_add_target pkg_name pkg_exe)
unset(${pkg_name}_LIBS)
message("_FIND:${${pkg_name}_BASE_DIR}")
if(pkg_exe)
find_program(${pkg_exe}_EXE ${pkg_exe} PATHS ${${pkg_name}_BASE_DIR}/bin NO_DEFAULT_PATH)
if(NOT ${pkg_exe}_EXE)
return()
endif()
add_executable(${pkg_name}::${pkg_exe} IMPORTED GLOBAL)
set_target_properties(${pkg_name}::${pkg_exe} PROPERTIES
IMPORTED_LOCATION ${${pkg_exe}_EXE}
)
message("found ${${pkg_exe}_EXE}")
endif()
foreach(_LIB_NAME ${ARGN})
set(_LIB_SEARCH_NAME ${_LIB_NAME})
set(_LIB_TYPE SHARED)
if (${pkg_name}_USE_STATIC_LIBS)
set(_LIB_SEARCH_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}")
set(_LIB_TYPE STATIC)
endif ()
set(${_LIB_NAME}_LIB ${_LIB_NAME}_LIB-NOTFOUND)
find_library(${_LIB_NAME}_LIB ${_LIB_SEARCH_NAME} PATHS ${${pkg_name}_BASE_DIR}/lib NO_DEFAULT_PATH)
if(NOT ${_LIB_NAME}_LIB)
return()
endif()
add_library(${pkg_name}::${_LIB_NAME} ${_LIB_TYPE} IMPORTED GLOBAL)
set_target_properties(${pkg_name}::${_LIB_NAME} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${${pkg_name}_BASE_DIR}/include"
IMPORTED_LOCATION ${${_LIB_NAME}_LIB}
)
list(APPEND ${pkg_name}_LIBS ${pkg_name}::${_LIB_NAME})
message("found ${${_LIB_NAME}_LIB}")
STRING( REGEX REPLACE "(.+)/(.+)" "\\1" LIBPATH ${${_LIB_NAME}_LIB})
set(${pkg_name}_LIBPATH ${LIBPATH} CACHE STRING INTERNAL)
endforeach(_LIB_NAME)
set(${pkg_name}_LIBS ${${pkg_name}_LIBS} PARENT_SCOPE)
endfunction()
function(__exec_cmd)
set(options )
set(oneValueArgs WORKING_DIRECTORY)
set(multiValueArgs COMMAND)
cmake_parse_arguments(EXEC "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} )
execute_process(COMMAND ${EXEC_COMMAND}
WORKING_DIRECTORY ${EXEC_WORKING_DIRECTORY}
RESULT_VARIABLE RESULT)
if(NOT RESULT EQUAL "0")
message(FATAL_ERROR "error! when ${EXEC_COMMAND} in ${EXEC_WORKING_DIRECTORY}")
endif()
endfunction()
function(__check_patches pkg_patches)
# check patches
if (PKG_PATCHES)
file(TOUCH ${_MS_LIB_CACHE}/${pkg_name}_patch.md5)
file(READ ${_MS_LIB_CACHE}/${pkg_name}_patch.md5 ${pkg_name}_PATCHES_MD5)
message("patches md5:${${pkg_name}_PATCHES_MD5}")
set(${pkg_name}_PATCHES_NEW_MD5 )
foreach(_PATCH ${PKG_PATCHES})
file(MD5 ${_PATCH} _PF_MD5)
set(${pkg_name}_PATCHES_NEW_MD5 "${${pkg_name}_PATCHES_NEW_MD5},${_PF_MD5}")
endforeach(_PATCH)
if (NOT ${pkg_name}_PATCHES_MD5 STREQUAL ${pkg_name}_PATCHES_NEW_MD5)
set(${pkg_name}_PATCHES ${PKG_PATCHES})
file(REMOVE_RECURSE "${_MS_LIB_CACHE}/${pkg_name}-subbuild")
file(WRITE ${_MS_LIB_CACHE}/${pkg_name}_patch.md5 ${${pkg_name}_PATCHES_NEW_MD5})
message("patches changed : ${${pkg_name}_PATCHES_NEW_MD5}")
endif ()
endif ()
endfunction()
set(MS_FIND_NO_DEFAULT_PATH NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH NO_SYSTEM_ENVIRONMENT_PATH
NO_CMAKE_BUILDS_PATH NO_CMAKE_PACKAGE_REGISTRY NO_CMAKE_SYSTEM_PATH
NO_CMAKE_SYSTEM_PACKAGE_REGISTRY)
set(MS_FIND_NO_DEFAULT_PATH ${MS_FIND_NO_DEFAULT_PATH} PARENT_SCOPE)
function(mindspore_add_pkg pkg_name )
set(options )
set(oneValueArgs URL MD5 VER EXE DIR HEAD_ONLY)
set(multiValueArgs CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION INSTALL_INCS INSTALL_LIBS PATCHES)
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} )
set(__FIND_PKG_NAME ${pkg_name})
string(TOLOWER ${pkg_name} pkg_name)
message("pkg name:${__FIND_PKG_NAME},${pkg_name}")
set(${pkg_name}_PATCHES_HASH )
foreach(_PATCH ${PKG_PATCHES})
file(MD5 ${_PATCH} _PF_MD5)
set(${pkg_name}_PATCHES_HASH "${${pkg_name}_PATCHES_HASH},${_PF_MD5}")
endforeach(_PATCH)
# check options
set(${pkg_name}_CONFIG_TXT
"${CMAKE_CXX_COMPILER_VERSION}-${CMAKE_C_COMPILER_VERSION}
${ARGN} - ${${pkg_name}_USE_STATIC_LIBS}- ${${pkg_name}_PATCHES_HASH}
${${pkg_name}_CXXFLAGS}--${${pkg_name}_CFLAGS}--${${pkg_name}_LDFLAGS}")
string(REPLACE ";" "-" ${pkg_name}_CONFIG_TXT ${${pkg_name}_CONFIG_TXT})
string(MD5 ${pkg_name}_CONFIG_HASH ${${pkg_name}_CONFIG_TXT})
message("${pkg_name} config hash: ${${pkg_name}_CONFIG_HASH}")
set(${pkg_name}_BASE_DIR ${_MS_LIB_CACHE}/${pkg_name}_${${pkg_name}_CONFIG_HASH})
set(${pkg_name}_DIRPATH ${${pkg_name}_BASE_DIR} CACHE STRING INTERNAL)
if(EXISTS ${${pkg_name}_BASE_DIR}/options.txt AND PKG_HEAD_ONLY)
set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY} PARENT_SCOPE)
add_library(${pkg_name} INTERFACE)
target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC})
return()
endif ()
if(NOT PKG_EXE)
set(PKG_EXE 0)
endif()
set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR})
set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR} PARENT_SCOPE)
if (PKG_LIBS)
__find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIBS})
if(${pkg_name}_LIBS)
set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE)
message("Found libs: ${${pkg_name}_LIBS}")
return()
endif()
elseif(NOT PKG_HEAD_ONLY)
find_package(${__FIND_PKG_NAME} ${PKG_VER} ${MS_FIND_NO_DEFAULT_PATH})
if (${__FIND_PKG_NAME}_FOUND)
set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE)
message("Found pkg: ${__FIND_PKG_NAME}")
return()
endif ()
endif ()
if (NOT PKG_DIR)
__download_pkg(${pkg_name} ${PKG_URL} ${PKG_MD5})
else()
set(${pkg_name}_SOURCE_DIR ${PKG_DIR})
endif ()
file(WRITE ${${pkg_name}_BASE_DIR}/options.txt ${${pkg_name}_CONFIG_TXT})
message("${pkg_name}_SOURCE_DIR : ${${pkg_name}_SOURCE_DIR}")
foreach(_PATCH_FILE ${PKG_PATCHES})
message("patching ${${pkg_name}_SOURCE_DIR} -p1 < ${_PATCH_FILE}")
execute_process(COMMAND patch -p1 INPUT_FILE ${_PATCH_FILE}
WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}
RESULT_VARIABLE Result)
if(NOT Result EQUAL "0")
message(FATAL_ERROR "Failed patch: ${_PATCH_FILE}")
endif()
endforeach(_PATCH_FILE)
file(LOCK ${${pkg_name}_BASE_DIR} DIRECTORY GUARD FUNCTION RESULT_VARIABLE ${pkg_name}_LOCK_RET TIMEOUT 600)
if(NOT ${pkg_name}_LOCK_RET EQUAL "0")
message(FATAL_ERROR "error! when try lock ${${pkg_name}_BASE_DIR} : ${${pkg_name}_LOCK_RET}")
endif()
if(${pkg_name}_SOURCE_DIR)
if (PKG_HEAD_ONLY)
file(GLOB ${pkg_name}_SOURCE_SUBDIRS ${${pkg_name}_SOURCE_DIR}/*)
file(COPY ${${pkg_name}_SOURCE_SUBDIRS} DESTINATION ${${pkg_name}_BASE_DIR})
set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY} PARENT_SCOPE)
add_library(${pkg_name} INTERFACE)
target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC})
elseif (PKG_CMAKE_OPTION)
# in cmake
file(MAKE_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build)
if (${pkg_name}_CFLAGS)
set(${pkg_name}_CMAKE_CFLAGS "-DCMAKE_C_FLAGS=${${pkg_name}_CFLAGS}")
endif ()
if (${pkg_name}_CXXFLAGS)
set(${pkg_name}_CMAKE_CXXFLAGS "-DCMAKE_CXX_FLAGS=${${pkg_name}_CXXFLAGS}")
endif ()
if (${pkg_name}_LDFLAGS)
if (${pkg_name}_USE_STATIC_LIBS)
#set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_STATIC_LINKER_FLAGS=${${pkg_name}_LDFLAGS}")
else()
set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_SHARED_LINKER_FLAGS=${${pkg_name}_LDFLAGS}")
endif ()
endif ()
__exec_cmd(COMMAND ${CMAKE_COMMAND} ${PKG_CMAKE_OPTION} -G ${CMAKE_GENERATOR}
${${pkg_name}_CMAKE_CFLAGS} ${${pkg_name}_CMAKE_CXXFLAGS} ${${pkg_name}_CMAKE_LDFLAGS}
-DCMAKE_INSTALL_PREFIX=${${pkg_name}_BASE_DIR} ..
WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build)
__exec_cmd(COMMAND ${CMAKE_COMMAND} --build . --target install -- -j8
WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build)
else()
if (${pkg_name}_CFLAGS)
set(${pkg_name}_MAKE_CFLAGS "CFLAGS=${${pkg_name}_CFLAGS}")
endif ()
if (${pkg_name}_CXXFLAGS)
set(${pkg_name}_MAKE_CXXFLAGS "CXXFLAGS=${${pkg_name}_CXXFLAGS}")
endif ()
if (${pkg_name}_LDFLAGS)
set(${pkg_name}_MAKE_LDFLAGS "LDFLAGS=${${pkg_name}_LDFLAGS}")
endif ()
# in configure && make
if (PKG_PRE_CONFIGURE_COMMAND)
__exec_cmd(COMMAND ${PKG_PRE_CONFIGURE_COMMAND}
WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})
endif ()
if (PKG_CONFIGURE_COMMAND)
__exec_cmd(COMMAND ${PKG_CONFIGURE_COMMAND}
${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS}
--prefix=${${pkg_name}_BASE_DIR}
WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})
endif ()
set(${pkg_name}_BUILD_OPTION ${PKG_BUILD_OPTION})
if (NOT PKG_CONFIGURE_COMMAND)
set(${pkg_name}_BUILD_OPTION ${${pkg_name}_BUILD_OPTION}
${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS})
endif ()
# build
__exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} ${${pkg_name}_BUILD_OPTION} -j8
WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})
if (PKG_INSTALL_INCS OR PKG_INSTALL_LIBS)
file(GLOB ${pkg_name}_INSTALL_INCS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_INCS})
file(GLOB ${pkg_name}_INSTALL_LIBS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_LIBS})
file(COPY ${${pkg_name}_INSTALL_INCS} DESTINATION ${${pkg_name}_BASE_DIR}/include)
file(COPY ${${pkg_name}_INSTALL_LIBS} DESTINATION ${${pkg_name}_BASE_DIR}/lib)
else()
__exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} install WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})
endif ()
endif ()
endif()
if (PKG_LIBS)
__find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIBS})
set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE)
if(NOT ${pkg_name}_LIBS)
message(FATAL_ERROR "Can not find pkg: ${pkg_name}")
endif()
else()
find_package(${__FIND_PKG_NAME} ${PKG_VER} QUIET ${MS_FIND_NO_DEFAULT_PATH})
if (${__FIND_PKG_NAME}_FOUND)
set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE)
message("Found pkg: ${${__FIND_PKG_NAME}_LIBRARIES}")
return()
endif ()
endif ()
endfunction()

View File

@ -0,0 +1,22 @@
{
"DumpSettings": {
"enable": false,
"trans_flag": false,
"path": "/tmp/net/",
"net_name": "ResNet50",
"mode": 0,
"iteration": 0,
"kernels": ["TensorAdd"]
},
"DumpSettingsSpec": {
"enable": "true: dump enable false: dump disable",
"trans_flag": "true: trans to host format,false: not trans format",
"path": "the dump file folder",
"net_name": "net name eg:ResNet50",
"mode": "0: dump all kernels 1: dump kernels in kernels list",
"iteration": "0: all iteration others: specified iteration ",
"kernels": "kernel name list need to be dump"
},
"other": {}
}

View File

@ -0,0 +1,22 @@
{
"DumpSettings": {
"enable": false,
"trans_flag": false,
"path": "/tmp/hccllog/0",
"net_name": "ResNet50",
"mode": 0,
"iteration": 0,
"kernels": ["AllReduce","BiasAddGrad","Conv2DBackpropFilter","SparseSoftmaxCrossEntropyWithLogits"]
},
"DumpSettingsSpec": {
"enable": "true: dump enable false: dump disable",
"trans_flag": "true: trans to host format,false: not trans format",
"path": "the dump file folder",
"net_name": "net name eg:ResNet50",
"mode": "0: dump all kernels 1: dump kernels in kernels list",
"iteration": "0: all iteration others: specified iteration ",
"kernels": "kernel name list need to be dump"
},
"other": {}
}

View File

@ -0,0 +1,22 @@
{
"DumpSettings": {
"enable": false,
"trans_flag": false,
"path": "/tmp/hccllog/1",
"net_name": "ResNet50",
"mode": 0,
"iteration": 0,
"kernels": ["AllReduce","BiasAddGrad","Conv2DBackpropFilter","SparseSoftmaxCrossEntropyWithLogits"]
},
"DumpSettingsSpec": {
"enable": "true: dump enable false: dump disable",
"trans_flag": "true: trans to host format,false: not trans format",
"path": "the dump file folder",
"net_name": "net name eg:ResNet50",
"mode": "0: dump all kernels 1: dump kernels in kernels list",
"iteration": "0: all iteration others: specified iteration ",
"kernels": "kernel name list need to be dump"
},
"other": {}
}

View File

@ -0,0 +1,44 @@
{
"board_id": "0x3000",
"chip_info": "910",
"deploy_mode": "lab",
"group_count": "1",
"group_list": [
{
"device_num": "2",
"server_num": "1",
"group_name": "",
"instance_count": "2",
"instance_list": [
{
"devices": [
{
"device_id": "0",
"device_ip": "[device_ip]"
}
],
"rank_id": "0",
"server_id": "[sever_id]"
},
{
"devices": [
{
"device_id": "1",
"device_ip": "[device_ip]"
}
],
"rank_id": "1",
"server_id": "[sever_id]"
}
]
}
],
"para_plane_nic_location": "device",
"para_plane_nic_name": [
"eth0",
"eth1"
],
"para_plane_nic_num": "2",
"status": "completed"
}

View File

@ -0,0 +1,175 @@
{
"board_id": "0x0000",
"chip_info": "910",
"deploy_mode": "lab",
"group_count": "1",
"group_list": [{
"device_num": "16",
"server_num": "2",
"group_name": "",
"instance_count": "16",
"instance_list": [{
"devices": [{
"device_id": "0",
"device_ip": "[A_device_ip_0]"
}],
"rank_id": "0",
"server_id": "[server_id_A]"
},
{
"devices": [{
"device_id": "1",
"device_ip": "[A_device_ip_1]"
}],
"rank_id": "1",
"server_id": "[server_id_A]"
},
{
"devices": [{
"device_id": "2",
"device_ip": "[A_device_ip_2]"
}],
"rank_id": "2",
"server_id": "[server_id_A]"
},
{
"devices": [{
"device_id": "3",
"device_ip": "[A_device_ip_3]"
}],
"rank_id": "3",
"server_id": "[server_id_A]"
},
{
"devices": [{
"device_id": "4",
"device_ip": "[A_device_ip_4]"
}],
"rank_id": "4",
"server_id": "[server_id_A]"
},
{
"devices": [{
"device_id": "5",
"device_ip": "[A_device_ip_5]"
}],
"rank_id": "5",
"server_id": "[server_id_A]"
},
{
"devices": [{
"device_id": "6",
"device_ip": "[A_device_ip_6]"
}],
"rank_id": "6",
"server_id": "[server_id_A]"
},
{
"devices": [{
"device_id": "7",
"device_ip": "[A_device_ip_7]"
}],
"rank_id": "7",
"server_id": "[server_id_A]"
},
{
"devices": [{
"device_id": "0",
"device_ip": "[B_device_ip_0]"
}],
"rank_id": "8",
"server_id": "[server_id_B]"
},
{
"devices": [{
"device_id": "1",
"device_ip": "[B_device_ip_1]"
}],
"rank_id": "9",
"server_id": "[server_id_B]"
},
{
"devices": [{
"device_id": "2",
"device_ip": "[B_device_ip_2]"
}],
"rank_id": "10",
"server_id": "[server_id_B]"
},
{
"devices": [{
"device_id": "3",
"device_ip": "[B_device_ip_3]"
}],
"rank_id": "11",
"server_id": "[server_id_B]"
},
{
"devices": [{
"device_id": "4",
"device_ip": "[B_device_ip_4]"
}],
"rank_id": "12",
"server_id": "[server_id_B]"
},
{
"devices": [{
"device_id": "5",
"device_ip": "[B_device_ip_5]"
}],
"rank_id": "13",
"server_id": "[server_id_B]"
},
{
"devices": [{
"device_id": "6",
"device_ip": "[B_device_ip_6]"
}],
"rank_id": "14",
"server_id": "[server_id_B]"
},
{
"devices": [{
"device_id": "7",
"device_ip": "[B_device_ip_7]"
}],
"rank_id": "15",
"server_id": "[server_id_B]"
}
]
}],
"para_plane_nic_location": "device",
"para_plane_nic_name": [
"eth0",
"eth1",
"eth2",
"eth3",
"eth4",
"eth5",
"eth6",
"eth7"
],
"para_plane_nic_num": "8",
"status": "completed",
"hccl_config_json_spec": {
"board_id": "board id, current support x0000 or 0x3000",
"chip_info": "chip info, current is 910",
"deploy_mode": "current use lab",
"group_count": "number of groups used",
"group_list": "detailed group information",
"device_num": "number of devices used, the value is the nth power of 2",
"server_num": "number of multiple machines, single machine is 1",
"group_name": "default is hccl_world_group or specified",
"instance_count": "number of instance used, generally equal to device_num",
"instance_list": "detailed instance information",
"device_id": "designated davinic device id to use, values start from 0, but no more than single machine total device num.if server_num greater than 1, the id can be restart from 0",
"device_ip": "ip corresponding to device_id",
"rank_id": "the first device must be 0 and then increase in order",
"server_id": "can be specified as the machine's ip address",
"para_plane_nic_location": "current use device",
"para_plane_nic_name": "network card corresponding to device ip",
"para_plane_nic_num": "number of network cards used",
"status": "current use completed"
}
}

View File

@ -0,0 +1,111 @@
{
"board_id": "0x0000",
"chip_info": "910",
"deploy_mode": "lab",
"group_count": "1",
"group_list": [{
"device_num": "8",
"server_num": "1",
"group_name": "",
"instance_count": "8",
"instance_list": [{
"devices": [{
"device_id": "0",
"device_ip": "[device_ip_0]"
}],
"rank_id": "0",
"server_id": "[server_id]"
},
{
"devices": [{
"device_id": "1",
"device_ip": "[device_ip_1]"
}],
"rank_id": "1",
"server_id": "[server_id]"
},
{
"devices": [{
"device_id": "2",
"device_ip": "[device_ip_2]"
}],
"rank_id": "2",
"server_id": "[server_id]"
},
{
"devices": [{
"device_id": "3",
"device_ip": "[device_ip_3]"
}],
"rank_id": "3",
"server_id": "[server_id]"
},
{
"devices": [{
"device_id": "4",
"device_ip": "[device_ip_4]"
}],
"rank_id": "4",
"server_id": "[server_id]"
},
{
"devices": [{
"device_id": "5",
"device_ip": "[device_ip_5]"
}],
"rank_id": "5",
"server_id": "[server_id]"
},
{
"devices": [{
"device_id": "6",
"device_ip": "[device_ip_6]"
}],
"rank_id": "6",
"server_id": "[server_id]"
},
{
"devices": [{
"device_id": "7",
"device_ip": "[device_ip_7]"
}],
"rank_id": "7",
"server_id": "[server_id]"
}
]
}],
"para_plane_nic_location": "device",
"para_plane_nic_name": [
"eth0",
"eth1",
"eth2",
"eth3",
"eth4",
"eth5",
"eth6",
"eth7"
],
"para_plane_nic_num": "8",
"status": "completed",
"hccl_config_json_spec": {
"board_id": "board id, current support x0000 or 0x3000",
"chip_info": "chip info, current is 910",
"deploy_mode": "current use lab",
"group_count": "number of groups used",
"group_list": "detailed group information",
"device_num": "number of devices used, the value is the nth power of 2",
"server_num": "single machine is 1",
"group_name": "default is hccl_world_group or specified",
"instance_count": "number of instance used, generally equal to device_num",
"instance_list": "detailed instance information",
"device_id": "designated davinic device id to use, values start from 0, but no more than single machine total device num",
"device_ip": "ip corresponding to device_id",
"rank_id": "the first device must be 0 and then increase in order",
"server_id": "can be specified as the machine's ip address",
"para_plane_nic_location": "current use device",
"para_plane_nic_name": "network card corresponding to device ip",
"para_plane_nic_num": "number of network cards used",
"status": "current use completed"
}
}

120
dbg_dump_parser.sh Executable file
View File

@ -0,0 +1,120 @@
#!/bin/bash
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
set -x
set -e
export SAVE_GRAPHS=YES
# print usage message
function usage()
{
echo "Usage:"
echo "bash $0 [-g] [-d] [-a] [-h] [-f file]"
echo "e.g. $0 -f 3_specialize.dat"
echo ""
echo "Options:"
echo " -g Generate ir file for debug"
echo " -d Debug dumped ir"
echo " -a Execute all steps, default"
echo " -f File to be parse"
echo " -h Print usage"
}
# check and set options
function checkopts()
{
# init variable
MODE_GEN=0
MODE_DBG=1
MODE_ALL=2
FILE_NAME="3_optimize.dat"
mode="${MODE_ALL}" # default execute all steps
# Process the options
while getopts 'gdaf:h' opt
do
case "${opt}" in
g)
mode="${MODE_GEN}"
;;
d)
mode="${MODE_DBG}"
;;
a)
mode="${MODE_ALL}"
;;
f)
FILE_NAME="$OPTARG"
if ! [ -f "${FILE_NAME}" ]; then
echo "File $FILE_NAME does not exist"
usage
exit 1
fi
;;
h)
usage
exit 0
;;
*)
echo "Unknown option ${opt}!"
usage
exit 1
esac
done
}
# init variable
# check options
checkopts "$@"
cd build/mindspore/
make -j8
cp -v mindspore/ccsrc/_c_expression.cpython-*.so ../../mindspore/
cd -
UT_NAME="./tests/ut/python/model/test_lenet.py::test_lenet5_train_sens"
#UT_NAME="./tests/python/ops/test_math_ops.py::test_matmul_grad"
#UT_NAME="./tests/python/exec/resnet_example.py::test_compile"
#UT_NAME="./tests/perf_test/test_bert_train.py::test_bert_train"
if [[ "${mode}" == "${MODE_GEN}" || "${mode}" == "${MODE_ALL}" ]]; then
rm -rf pkl_objs
mkdir -p pkl_objs
echo "MS_IR_PATH=$(pwd)/pkl_objs pytest -s ${UT_NAME}"
MS_IR_PATH=$(pwd)/pkl_objs/ pytest -s "${UT_NAME}"
#pytest -s $UT_NAME
# 1_resolve.dat
# 3_specialize.dat
# 4_simplify_data_structures.dat
# 5_opt.dat
# 6_opt2.dat
# 7_opt_ge_adaptor_special.dat
# 8_cconv.dat
# 9_validate.dat
cp "${FILE_NAME}" anf_ir_file.dbg
rm -rf pkl_objs.dbg
cp -rf pkl_objs pkl_objs.dbg
fi
if [[ "${mode}" == "${MODE_DBG}" || "${mode}" == "${MODE_ALL}" ]]; then
echo "MS_IR_FILE=$(pwd)/anf_ir_file.dbg MS_IR_PATH=$(pwd)/pkl_objs.dbg/ pytest -s ${UT_NAME}"
MS_IR_FILE=$(pwd)/anf_ir_file.dbg MS_IR_PATH=$(pwd)/pkl_objs.dbg/ pytest -s "${UT_NAME}"
fi

Binary file not shown.

After

Width:  |  Height:  |  Size: 76 KiB

BIN
docs/Automatic-parallel.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 121 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

BIN
docs/MindSpore-logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

3
docs/README.md Normal file
View File

@ -0,0 +1,3 @@
# MindSpore Documentation
The MindSpore documentation is in the [MindSpore Docs](https://gitee.com/mindspore/docs) repository.

View File

@ -0,0 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
bert_train_cfg = edict({
'epoch_size': 10,
'num_warmup_steps': 0,
'start_learning_rate': 1e-4,
'end_learning_rate': 0.0,
'decay_steps': 1000,
'power': 10.0,
'save_checkpoint_steps': 2000,
'keep_checkpoint_max': 10,
'checkpoint_prefix': "checkpoint_bert",
'DATA_DIR': "/your/path/examples.tfrecord",
'SCHEMA_DIR': "/your/path/datasetSchema.json"
})
bert_net_cfg = BertConfig(
batch_size=16,
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16,
)

View File

@ -0,0 +1,98 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language model currently based on BERT developed by Huawei.
1. Prepare data
Following the data preparation as in BERT, run command as below to get dataset for training:
python ./create_pretraining_data.py \
--input_file=./sample_text.txt \
--output_file=./examples.tfrecord \
--vocab_file=./your/path/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
2. Pretrain
First, prepare the distributed training environment, then adjust configurations in config.py, finally run train.py.
"""
import os
import pytest
import numpy as np
from numpy import allclose
from config import bert_train_cfg, bert_net_cfg
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore._c_dataengine as deMap
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell
from mindspore.nn.optim import Lamb
from mindspore import log as logger
_current_dir = os.path.dirname(os.path.realpath(__file__))
def create_train_dataset(batch_size):
"""create train dataset"""
# apply repeat operations
repeat_count = bert_train_cfg.epoch_size
ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
"next_sentence_labels", "masked_lm_positions",
"masked_lm_ids", "masked_lm_weights"])
type_cast_op = deMap.TypeCastOp("int32")
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count)
return ds
def weight_variable(shape):
"""weight variable"""
np.random.seed(1)
ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
return Tensor(ones)
def train_bert():
"""train bert"""
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
ds = create_train_dataset(bert_net_cfg.batch_size)
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps,
start_learning_rate=bert_train_cfg.start_learning_rate, end_learning_rate=bert_train_cfg.end_learning_rate,
power=bert_train_cfg.power, warmup_steps=bert_train_cfg.num_warmup_steps, decay_filter=lambda x: False)
netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
netwithgrads.set_train(True)
model = Model(netwithgrads)
config_ck = CheckpointConfig(save_checkpoint_steps=bert_train_cfg.save_checkpoint_steps,
keep_checkpoint_max=bert_train_cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=bert_train_cfg.checkpoint_prefix, config=config_ck)
model.train(ds.get_repeat_count(), ds, callbacks=[LossMonitor(), ckpoint_cb], dataset_sink_mode=False)
if __name__ == '__main__':
train_bert()

View File

@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict as edict
alexnet_cfg = edict({
'num_classes': 10,
'learning_rate': 0.002,
'momentum': 0.9,
'epoch_size': 1,
'batch_size': 32,
'buffer_size': 1000,
'image_height': 227,
'image_width': 227,
'save_checkpoint_steps': 1562,
'keep_checkpoint_max': 10,
})

View File

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
AlexNet example tutorial
Usage:
python alexnet.py
with --device_target=GPU: After 20 epoch training, the accuracy is up to 80%
"""
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV
from config import alexnet_cfg as cfg
from mindspore import context
from mindspore.common import dtype as mstype
def create_dataset(data_path, batch_size=32, repeat_size=1):
"""
create dataset for train or test
"""
cifar_ds = ds.Cifar10Dataset(data_path)
rescale = 1.0 / 255.0
shift = 0.0
resize_op = CV.Resize((cfg.image_height, cfg.image_width))
rescale_op = CV.Rescale(rescale, shift)
normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4])
random_horizontal_op = CV.RandomHorizontalFlip()
channel_swap_op = CV.HWC2CHW()
typecast_op = C.TypeCast(mstype.int32)
cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op)
cifar_ds = cifar_ds.shuffle(buffer_size=cfg.buffer_size)
cifar_ds = cifar_ds.repeat(repeat_size)
cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
return cifar_ds

View File

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
AlexNet example tutorial
Usage:
python alexnet.py
with --device_target=GPU: After 20 epoch training, the accuracy is up to 80%
"""
import argparse
import mindspore.nn as nn
from config import alexnet_cfg as cfg
from mindspore.model_zoo.alexnet import AlexNet
from dataset import create_dataset
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()}) # test
print("============== Starting Testing ==============")
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
ds_eval = create_dataset(args.data_path)
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== Accuracy:{} ==============".format(acc))

View File

@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
AlexNet example tutorial
Usage:
python alexnet.py
with --device_target=GPU: After 20 epoch training, the accuracy is up to 80%
"""
import argparse
import mindspore.nn as nn
from config import alexnet_cfg as cfg
from mindspore.model_zoo.alexnet import AlexNet
from dataset import create_dataset
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()}) # test
print("============== Starting Training ==============")
ds_train = create_dataset(args.data_path,
cfg.batch_size,
repeat_size)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck)
model.train(cfg.epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode)

View File

@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict as edict
mnist_cfg = edict({
'num_classes': 10,
'lr': 0.01,
'momentum': 0.9,
'epoch_size': 1,
'batch_size': 32,
'buffer_size': 1000,
'image_height': 32,
'image_width': 32,
'save_checkpoint_steps': 1875,
'keep_checkpoint_max': 10,
})

View File

@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Produce the dataset
"""
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
from mindspore.common import dtype as mstype
def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
"""
create dataset for train or test
"""
# define dataset
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds

View File

@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train and test lenet example ########################
test lenet according to model file:
python main.py --data_path /YourDataPath --ckpt_path Your.ckpt
"""
import os
import argparse
import mindspore.nn as nn
from dataset import create_dataset
from config import mnist_cfg as cfg
from mindspore.model_zoo.lenet import LeNet5
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide\
path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Testing ==============")
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
ds_eval = create_dataset(os.path.join(args.data_path, "test"), 32, 1)
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== Accuracy:{} ==============".format(acc))

View File

@ -0,0 +1,57 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train lenet example ########################
train lenet and get network model files(.ckpt) :
python main.py --data_path /YourDataPath
"""
import os
import argparse
import mindspore.nn as nn
from config import mnist_cfg as cfg
from dataset import create_dataset
from mindspore.model_zoo.lenet import LeNet5
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
ds_train = create_dataset(os.path.join(args.data_path, "train"), batch_size=cfg.batch_size,
repeat_size=repeat_size)
print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode)

View File

@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config = ed({
"class_num": 10,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 90,
"buffer_size": 100,
"image_height": 224,
"image_width": 224,
"save_checkpoint": True,
"save_checkpoint_steps": 195,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"lr_init": 0.01,
"lr_end": 0.00001,
"lr_max": 0.1,
"warmup_epochs": 5,
"lr_decay_mode": "poly"
})

View File

@ -0,0 +1,83 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from config import config
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
Returns:
dataset
"""
device_num = int(os.getenv("DEVICE_NUM"))
rank_id = int(os.getenv("RANK_ID"))
if device_num == 1:
ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True)
else:
ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True,
num_shards=device_num, shard_id=rank_id)
resize_height = config.image_height
resize_width = config.image_width
rescale = 1.0 / 255.0
shift = 0.0
# define map operations
random_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4))
random_horizontal_flip_op = C.RandomHorizontalFlip(rank_id / (rank_id + 1))
resize_op = C.Resize((resize_height, resize_width))
rescale_op = C.Rescale(rescale, shift)
normalize_op = C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
change_swap_op = C.HWC2CHW()
trans = []
if do_train:
trans += [random_crop_op, random_horizontal_flip_op]
trans += [resize_op, rescale_op, normalize_op, change_swap_op]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="label", operations=type_cast_op)
ds = ds.map(input_columns="image", operations=trans)
# apply shuffle operations
ds = ds.shuffle(buffer_size=config.buffer_size)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds

View File

@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
eval.
"""
import os
import argparse
import random
import numpy as np
from dataset import create_dataset
from config import config
from mindspore import context
from mindspore.model_zoo.resnet import resnet50
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
from mindspore.communication.management import init
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=False, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=True, help='Do eval or not.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__':
if args_opt.do_eval:
context.set_context(enable_hccl=False)
else:
if args_opt.run_distribute:
context.set_context(enable_hccl=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
init()
else:
context.set_context(enable_hccl=False)
epoch_size = config.epoch_size
net = resnet50(class_num=config.class_num)
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
if args_opt.do_eval:
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
model = Model(net, loss_fn=loss, metrics={'acc'})
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)

View File

@ -0,0 +1,77 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import numpy as np
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
"""
generate learning rate array
Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
lr_decay_mode(string): learning rate decay mode, including steps, poly or default
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
if lr_decay_mode == 'steps':
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr = lr_max
elif i < decay_epoch_index[1]:
lr = lr_max * 0.1
elif i < decay_epoch_index[2]:
lr = lr_max * 0.01
else:
lr = lr_max * 0.001
lr_each_step.append(lr)
elif lr_decay_mode == 'poly':
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr = float(lr_max) * base * base
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
else:
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate

View File

@ -0,0 +1,53 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]"
exit 1
fi
if [ ! -f $1 ]
then
echo "error: DMINDSPORE_HCCL_CONFIG_PATH=$1 is not a file"
exit 1
fi
if [ ! -d $2 ]
then
echo "error: DATASET_PATH=$2 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$1
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp *.py ./train_parallel$i
cp *.sh ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 &> log &
cd ..
done

View File

@ -0,0 +1,52 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_infer.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
if [ ! -d $1 ]
then
echo "error: DATASET_PATH=$1 is not a directory"
exit 1
fi
if [ ! -f $2 ]
then
echo "error: CHECKPOINT_PATH=$2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "infer" ];
then
rm -rf ./infer
fi
mkdir ./infer
cp *.py ./infer
cp *.sh ./infer
cd ./infer || exit
env > env.log
echo "start infering for device $DEVICE_ID"
python eval.py --do_eval=True --dataset_path=$1 --checkpoint_path=$2 &> log &
cd ..

View File

@ -0,0 +1,45 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH]"
exit 1
fi
if [ ! -d $1 ]
then
echo "error: DATASET_PATH=$1 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp *.py ./train
cp *.sh ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --do_train=True --dataset_path=$1 &> log &
cd ..

View File

@ -0,0 +1,96 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import os
import argparse
import random
import numpy as np
from dataset import create_dataset
from lr_generator import get_lr
from config import config
from mindspore import context
from mindspore import Tensor
from mindspore.model_zoo.resnet import resnet50
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
import mindspore.dataset.engine as de
from mindspore.communication.management import init
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__':
if args_opt.do_eval:
context.set_context(enable_hccl=False)
else:
if args_opt.run_distribute:
context.set_context(enable_hccl=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
init()
else:
context.set_context(enable_hccl=False)
epoch_size = config.epoch_size
net = resnet50(class_num=config.class_num)
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
if args_opt.do_train:
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
repeat_num=epoch_size, batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
lr = Tensor(get_lr(global_step=0, lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size,
lr_decay_mode='poly'))
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
config.weight_decay, config.loss_scale)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck)
cb += [ckpt_cb]
model.train(epoch_size, dataset, callbacks=cb)

View File

@ -0,0 +1,30 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as edict
cifar_cfg = edict({
"num_classes": 10,
"lr_init": 0.05,
"batch_size": 64,
"epoch_size": 70,
"momentum": 0.9,
"weight_decay": 5e-4,
"buffer_size": 10,
"image_height": 224,
"image_width": 224,
"keep_checkpoint_max": 10,
})

View File

@ -0,0 +1,63 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.common.dtype as mstype
from config import cifar_cfg as cfg
def create_dataset(data_home, repeat_num=1, training=True):
ds.config.set_seed(1)
data_dir = data_home + "/cifar-10-batches-bin"
if not training:
data_dir = data_home + "/cifar-10-verify-bin"
data_set = ds.Cifar10Dataset(data_dir)
resize_height = cfg.image_height
resize_width = cfg.image_width
rescale = 1.0 / 255.0
shift = 0.0
# define map operations
random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = vision.RandomHorizontalFlip()
resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR
rescale_op = vision.Rescale(rescale, shift)
normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
changeswap_op = vision.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
c_trans = []
if training:
c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, rescale_op, normalize_op,
changeswap_op]
# apply map operations on images
data_set = data_set.map(input_columns="label", operations=type_cast_op)
data_set = data_set.map(input_columns="image", operations=c_trans)
# apply repeat operations
data_set = data_set.repeat(repeat_num)
# apply shuffle operations
data_set = data_set.shuffle(buffer_size=10)
# apply batch operations
data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)
return data_set

View File

@ -0,0 +1,54 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############test vgg16 example on cifar10#################
python eval.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
"""
import mindspore.nn as nn
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore import context
import argparse
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import dataset
from mindspore.model_zoo.vgg import vgg16
from config import cifar_cfg as cfg
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Cifar10 classification')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved')
parser.add_argument('--checkpoint_path', type=str, default=None, help='checkpoint file path.')
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
if args_opt.device_target != 'CPU' and args_opt.device_id:
context.set_context(device_id=args_opt.device_id)
context.set_context(enable_task_sink=True, enable_loop_sink=True)
context.set_context(enable_mem_reuse=True, enable_hccl=False)
net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
weight_decay=cfg.weight_decay)
model = Model(net, loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False), optimizer=opt, metrics={'acc'})
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
dataset = dataset.create_dataset(args_opt.data_path, 1, training=False)
res = model.eval(dataset)
print("result: ", res)

View File

@ -0,0 +1,80 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train vgg16 example on cifar10########################
python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore import context
import numpy as np
import argparse
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
import dataset
from mindspore.model_zoo.vgg import vgg16
from config import cifar_cfg as cfg
import random
random.seed(1)
np.random.seed(1)
def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr = lr_max
elif i < decay_epoch_index[1]:
lr = lr_max * 0.1
elif i < decay_epoch_index[2]:
lr = lr_max * 0.01
else:
lr = lr_max * 0.001
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Cifar10 classification')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved')
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
if args_opt.device_target != 'CPU' and args_opt.device_id:
context.set_context(device_id=args_opt.device_id)
context.set_context(enable_task_sink = True, enable_loop_sink = True)
context.set_context(enable_mem_reuse=True, enable_hccl=False)
net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=50000 // cfg.batch_size)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
model = Model(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False), optimizer=opt, metrics={'acc'})
dataset = dataset.create_dataset(args_opt.data_path, cfg.epoch_size)
batch_num = dataset.get_dataset_size()
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="train_vgg_cifar10", directory="./", config=config_ck)
loss_cb = LossMonitor()
model.train(cfg.epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])

View File

@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for YOLOv3 models."""
class ConfigYOLOV3ResNet18:
"""
Config parameters for YOLOv3.
Examples:
ConfigYoloV3ResNet18.
"""
img_shape = [352, 640]
feature_shape = [32, 3, 352, 640]
num_classes = 80
backbone_input_shape = [64, 64, 128, 256]
backbone_shape = [64, 128, 256, 512]
backbone_layers = [2, 2, 2, 2]
backbone_stride = [1, 2, 2, 2]
ignore_threshold = 0.5
anchor_scales = [(10, 13),
(16, 30),
(33, 23),
(30, 61),
(62, 45),
(59, 119),
(116, 90),
(156, 198),
(163, 326)]
out_channel = int(len(anchor_scales) / 3 * (num_classes + 5))

View File

@ -0,0 +1,389 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOv3 dataset"""
from __future__ import division
import abc
import io
import os
import math
import json
import numpy as np
from PIL import Image
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.py_transforms as P
from config import ConfigYOLOV3ResNet18
iter_cnt = 0
_NUM_BOXES = 50
def preprocess_fn(image, box, is_training):
"""Preprocess function for dataset."""
config_anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 163, 326]
anchors = np.array([float(x) for x in config_anchors]).reshape(-1, 2)
do_hsv = False
max_boxes = 20
num_classes = ConfigYOLOV3ResNet18.num_classes
def _rand(a=0., b=1.):
return np.random.rand() * (b - a) + a
def _preprocess_true_boxes(true_boxes, anchors, in_shape=None):
"""Get true boxes."""
num_layers = anchors.shape[0] // 3
anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
true_boxes = np.array(true_boxes, dtype='float32')
# input_shape = np.array([in_shape, in_shape], dtype='int32')
input_shape = np.array(in_shape, dtype='int32')
boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]
grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]),
5 + num_classes), dtype='float32') for l in range(num_layers)]
anchors = np.expand_dims(anchors, 0)
anchors_max = anchors / 2.
anchors_min = -anchors_max
valid_mask = boxes_wh[..., 0] >= 1
wh = boxes_wh[valid_mask]
if len(wh) >= 1:
wh = np.expand_dims(wh, -2)
boxes_max = wh / 2.
boxes_min = -boxes_max
intersect_min = np.maximum(boxes_min, anchors_min)
intersect_max = np.minimum(boxes_max, anchors_max)
intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
box_area = wh[..., 0] * wh[..., 1]
anchor_area = anchors[..., 0] * anchors[..., 1]
iou = intersect_area / (box_area + anchor_area - intersect_area)
best_anchor = np.argmax(iou, axis=-1)
for t, n in enumerate(best_anchor):
for l in range(num_layers):
if n in anchor_mask[l]:
i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32')
j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32')
k = anchor_mask[l].index(n)
c = true_boxes[t, 4].astype('int32')
y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4]
y_true[l][j, i, k, 4] = 1.
y_true[l][j, i, k, 5 + c] = 1.
pad_gt_box0 = np.zeros(shape=[50, 4], dtype=np.float32)
pad_gt_box1 = np.zeros(shape=[50, 4], dtype=np.float32)
pad_gt_box2 = np.zeros(shape=[50, 4], dtype=np.float32)
mask0 = np.reshape(y_true[0][..., 4:5], [-1])
gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4])
gt_box0 = gt_box0[mask0 == 1]
pad_gt_box0[:gt_box0.shape[0]] = gt_box0
mask1 = np.reshape(y_true[1][..., 4:5], [-1])
gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4])
gt_box1 = gt_box1[mask1 == 1]
pad_gt_box1[:gt_box1.shape[0]] = gt_box1
mask2 = np.reshape(y_true[2][..., 4:5], [-1])
gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4])
gt_box2 = gt_box2[mask2 == 1]
pad_gt_box2[:gt_box2.shape[0]] = gt_box2
return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2
def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)):
"""Data augmentation function."""
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
iw, ih = image.size
ori_image_shape = np.array([ih, iw], np.int32)
h, w = image_size
if not is_training:
image = image.resize((w, h), Image.BICUBIC)
image_data = np.array(image) / 255.
if len(image_data.shape) == 2:
image_data = np.expand_dims(image_data, axis=-1)
image_data = np.concatenate([image_data, image_data, image_data], axis=-1)
image_data = image_data.astype(np.float32)
# correct boxes
box_data = np.zeros((max_boxes, 5))
if len(box) >= 1:
np.random.shuffle(box)
if len(box) > max_boxes:
box = box[:max_boxes]
# xmin ymin xmax ymax
box[:, [0, 2]] = box[:, [0, 2]] * float(w) / float(iw)
box[:, [1, 3]] = box[:, [1, 3]] * float(h) / float(ih)
box_data[:len(box)] = box
else:
image_data, box_data = None, None
# preprocess bounding boxes
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
_preprocess_true_boxes(box_data, anchors, image_size)
return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
ori_image_shape, gt_box1, gt_box2, gt_box3
flip = _rand() < .5
# correct boxes
box_data = np.zeros((max_boxes, 5))
while True:
# Prevent the situation that all boxes are eliminated
new_ar = float(w) / float(h) * _rand(1 - jitter, 1 + jitter) / \
_rand(1 - jitter, 1 + jitter)
scale = _rand(0.25, 2)
if new_ar < 1:
nh = int(scale * h)
nw = int(nh * new_ar)
else:
nw = int(scale * w)
nh = int(nw / new_ar)
dx = int(_rand(0, w - nw))
dy = int(_rand(0, h - nh))
if len(box) >= 1:
t_box = box.copy()
np.random.shuffle(t_box)
t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(iw) + dx
t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(ih) + dy
if flip:
t_box[:, [0, 2]] = w - t_box[:, [2, 0]]
t_box[:, 0:2][t_box[:, 0:2] < 0] = 0
t_box[:, 2][t_box[:, 2] > w] = w
t_box[:, 3][t_box[:, 3] > h] = h
box_w = t_box[:, 2] - t_box[:, 0]
box_h = t_box[:, 3] - t_box[:, 1]
t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] # discard invalid box
if len(t_box) >= 1:
box = t_box
break
box_data[:len(box)] = box
# resize image
image = image.resize((nw, nh), Image.BICUBIC)
# place image
new_image = Image.new('RGB', (w, h), (128, 128, 128))
new_image.paste(image, (dx, dy))
image = new_image
# flip image or not
if flip:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
# convert image to gray or not
gray = _rand() < .25
if gray:
image = image.convert('L').convert('RGB')
# when the channels of image is 1
image = np.array(image)
if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1)
image = np.concatenate([image, image, image], axis=-1)
# distort image
hue = _rand(-hue, hue)
sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat)
val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val)
image_data = image / 255.
if do_hsv:
x = rgb_to_hsv(image_data)
x[..., 0] += hue
x[..., 0][x[..., 0] > 1] -= 1
x[..., 0][x[..., 0] < 0] += 1
x[..., 1] *= sat
x[..., 2] *= val
x[x > 1] = 1
x[x < 0] = 0
image_data = hsv_to_rgb(x) # numpy array, 0 to 1
image_data = image_data.astype(np.float32)
# preprocess bounding boxes
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
_preprocess_true_boxes(box_data, anchors, image_size)
return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
ori_image_shape, gt_box1, gt_box2, gt_box3
images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3
def anno_parser(annos_str):
"""Annotation parser."""
annos = []
for anno_str in annos_str:
anno = list(map(int, anno_str.strip().split(',')))
annos.append(anno)
return annos
def expand_path(path):
"""Get file list from path."""
files = []
if os.path.isdir(path):
for file in os.listdir(path):
if os.path.isfile(os.path.join(path, file)):
files.append(file)
else:
raise RuntimeError("Path given is not valid.")
return files
def read_image(img_path):
"""Read image with PIL."""
with open(img_path, "rb") as f:
img = f.read()
data = io.BytesIO(img)
img = Image.open(data)
return np.array(img)
class BaseDataset():
"""BaseDataset for GeneratorDataset iterator."""
def __init__(self, image_dir, anno_path):
self.image_dir = image_dir
self.anno_path = anno_path
self.cur_index = 0
self.samples = []
self.image_anno_dict = {}
self._load_samples()
def __getitem__(self, item):
sample = self.samples[item]
return self._next_data(sample, self.image_dir, self.image_anno_dict)
def __len__(self):
return len(self.samples)
@staticmethod
def _next_data(sample, image_dir, image_anno_dict):
"""Get next data."""
image = read_image(os.path.join(image_dir, sample))
annos = image_anno_dict[sample]
return [np.array(image), np.array(annos)]
@abc.abstractmethod
def _load_samples(self):
"""Base load samples."""
class YoloDataset(BaseDataset):
"""YoloDataset for GeneratorDataset iterator."""
def _load_samples(self):
"""Load samples."""
image_files_raw = expand_path(self.image_dir)
self.samples = self._filter_valid_data(self.anno_path, image_files_raw)
self.dataset_size = len(self.samples)
if self.dataset_size == 0:
raise RuntimeError("Valid dataset is none!")
def _filter_valid_data(self, anno_path, image_files_raw):
"""Filter valid data."""
image_files = []
anno_dict = {}
print("Start filter valid data.")
with open(anno_path, "rb") as f:
lines = f.readlines()
for line in lines:
line_str = line.decode("utf-8")
line_split = str(line_str).split(' ')
anno_dict[line_split[0].split("/")[-1]] = line_split[1:]
anno_set = set(anno_dict.keys())
image_set = set(image_files_raw)
for image_file in (anno_set & image_set):
image_files.append(image_file)
self.image_anno_dict[image_file] = anno_parser(anno_dict[image_file])
image_files.sort()
print("Filter valid data done!")
return image_files
class DistributedSampler():
"""DistributedSampler for YOLOv3"""
def __init__(self, dataset_size, batch_size, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
num_replicas = 1
if rank is None:
rank = 0
self.dataset_size = dataset_size
self.num_replicas = num_replicas
self.rank = rank % num_replicas
self.epoch = 0
self.num_samples = max(batch_size, int(math.ceil(dataset_size * 1.0 / self.num_replicas)))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
indices = indices.tolist()
else:
indices = list(range(self.dataset_size))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
def create_yolo_dataset(image_dir, anno_path, batch_size=32, repeat_num=10, device_num=1, rank=0,
is_training=True, num_parallel_workers=8):
"""Creatr YOLOv3 dataset with GeneratorDataset."""
yolo_dataset = YoloDataset(image_dir=image_dir, anno_path=anno_path)
distributed_sampler = DistributedSampler(yolo_dataset.dataset_size, batch_size, device_num, rank)
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
ds.set_dataset_size(len(distributed_sampler))
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
hwc_to_chw = P.HWC2CHW()
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
operations=compose_map_func, num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
ds = ds.shuffle(buffer_size=256)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
return ds

View File

@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the scipt as: "
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: sh run_distribute_train.sh 8 100 ./dataset/coco/train2017 ./dataset/train.txt ./hccl.json"
echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt"
export RANK_SIZE=$1
EPOCH_SIZE=$2
IMAGE_DIR=$3
ANNO_PATH=$4
export MINDSPORE_HCCL_CONFIG_PATH=$5
for((i=0;i<RANK_SIZE;i++))
do
export DEVICE_ID=$i
rm -rf LOG$i
mkdir ./LOG$i
cp *.py ./LOG$i
cd ./LOG$i || exit
export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
python ../train.py \
--distribute=1 \
--device_num=$RANK_SIZE \
--device_id=$DEVICE_ID \
--image_dir=$IMAGE_DIR \
--epoch_size=$EPOCH_SIZE \
--anno_path=$ANNO_PATH > log.txt 2>&1 &
cd ../
done

View File

@ -0,0 +1,21 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the scipt as: "
echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE IMAGE_DIR ANNO_PATH"
echo "for example: sh run_standalone_train.sh 0 50 ./dataset/coco/train2017 ./dataset/train.txt"
python train.py --device_id=$1 --epoch_size=$2 --image_dir=$3 --anno_path=$4

View File

@ -0,0 +1,115 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train YOLOv3 example ########################
train YOLOv3 and get network model files(.ckpt) :
python train.py --image_dir dataset/coco/coco/train2017 --anno_path dataset/coco/train_coco.txt
"""
import argparse
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common.initializer import initializer
from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
from dataset import create_yolo_dataset
from config import ConfigYOLOV3ResNet18
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
"""Set learning rate"""
lr_each_step = []
lr = learning_rate
for i in range(global_step):
if steps:
lr_each_step.append(lr * (decay_rate ** (i // decay_step)))
else:
lr_each_step.append(lr * (decay_rate ** (i / decay_step)))
lr_each_step = np.array(lr_each_step).astype(np.float32)
lr_each_step = lr_each_step[start_step:]
return lr_each_step
def init_net_param(net, init='ones'):
"""Init the parameters in net."""
params = net.trainable_params()
for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
p.set_parameter_data(initializer(init, p.data.shape(), p.data.dtype()))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="YOLOv3")
parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--mode", type=str, default="graph", help="Run graph mode or feed mode, default is graph")
parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
parser.add_argument("--image_dir", type=str, required=True, help="Dataset image dir.")
parser.add_argument("--anno_path", type=str, required=True, help="Dataset anno path.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True)
if args_opt.distribute:
device_num = args_opt.device_num
context.reset_auto_parallel_context()
context.set_context(enable_hccl=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
device_num=device_num)
init()
rank = args_opt.device_id
else:
context.set_context(enable_hccl=False)
rank = 0
device_num = 1
loss_scale = float(args_opt.loss_scale)
dataset = create_yolo_dataset(args_opt.image_dir, args_opt.anno_path, repeat_num=args_opt.epoch_size,
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size()
net = yolov3_resnet18(ConfigYOLOV3ResNet18())
net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
init_net_param(net, "XavierUniform")
# checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config)
if args_opt.checkpoint_path != "":
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
lr = Tensor(get_lr(learning_rate=0.001, start_step=0, global_step=args_opt.epoch_size * dataset_size,
decay_step=1000, decay_rate=0.95))
opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
net = TrainingWrapper(net, opt, loss_scale)
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
model = Model(net)
dataset_sink_mode = False
if args_opt.mode == "graph":
dataset_sink_mode = True
print("Start train YOLOv3.")
model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)

1
graphengine Submodule

@ -0,0 +1 @@
Subproject commit ff80316dab61b06ea15f8621d045b7699d2f0c79

28
mindspore/__init__.py Executable file
View File

@ -0,0 +1,28 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore package."""
from . import common, train
from .common import *
from .ops import _op_impl
from .train import *
from .log import *
from .version import __version__
__all__ = []
__all__.extend(__version__)
__all__.extend(common.__all__)
__all__.extend(train.__all__)
__all__.extend(log.__all__)

542
mindspore/_checkparam.py Normal file
View File

@ -0,0 +1,542 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Check parameters."""
import re
from enum import Enum
from itertools import repeat
from collections import Iterable
import numpy as np
from mindspore import log as logger
from .common import dtype as mstype
# Named string regular expression
_name_re = r"^\w+[0-9a-zA-Z\_\.]*$"
class Rel(Enum):
"""Numerical relationship between variables, logical relationship enumeration definition of range."""
# scalar compare
EQ = 1 # ==
NE = 2 # !=
LT = 3 # <
LE = 4 # <=
GT = 5 # >
GE = 6 # >=
# scalar range check
INC_NEITHER = 7 # (), include neither
INC_LEFT = 8 # [), include left
INC_RIGHT = 9 # (], include right
INC_BOTH = 10 # [], include both
# collection in, not in
IN = 11
NOT_IN = 12
@staticmethod
def get_strs(rel):
"""Get value from rel_strs."""
return rel_strs.get(rel, "")
@staticmethod
def get_fns(rel):
"""Get value from rel_fns."""
return rel_fns.get(rel, lambda *args: False)
rel_fns = {
# scalar compare
Rel.EQ: lambda x, y: x == y,
Rel.NE: lambda x, y: x != y,
Rel.LT: lambda x, y: x < y,
Rel.LE: lambda x, y: x <= y,
Rel.GT: lambda x, y: x > y,
Rel.GE: lambda x, y: x >= y,
# scalar range check
Rel.INC_NEITHER: lambda x, lower, upper: (lower < x < upper),
Rel.INC_LEFT: lambda x, lower, upper: (lower <= x < upper),
Rel.INC_RIGHT: lambda x, lower, upper: (lower < x <= upper),
Rel.INC_BOTH: lambda x, lower, upper: (lower <= x <= upper),
# collection in, not in
Rel.IN: lambda x, y: x in y,
Rel.NOT_IN: lambda x, y: x not in y,
}
rel_strs = {
# scalar compare
Rel.EQ: "equal to {}",
Rel.NE: "not equal to {}",
Rel.LT: "less than {}",
Rel.LE: "less or equal to {}",
Rel.GT: "greater than {}",
Rel.GE: "greater or equal to {}",
# scalar range check
Rel.INC_NEITHER: "({}, {})",
Rel.INC_LEFT: "[{}, {})",
Rel.INC_RIGHT: "({}, {}]",
Rel.INC_BOTH: "[{}, {}]",
# collection in, not in
Rel.IN: "in {}",
Rel.NOT_IN: "not in {}",
}
class ParamValidator:
"""Parameter validator."""
@staticmethod
def equal(arg_name, arg_value, cond_str, cond):
"""Judging valid value."""
if not cond:
raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
@staticmethod
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
"""This method is only used for check int values, since when compare float values,
we need consider float error."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
@staticmethod
def check_integer(arg_name, arg_value, value, rel):
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_shape_length(arg_name, arg_value, value, rel):
"""Shape length judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
return arg_value
@staticmethod
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
"""This method is only used for check int values,
since when compare float values, we need consider float error."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_isinstance(arg_name, arg_value, classes):
"""Check arg isintance of classes"""
if not isinstance(arg_value, classes):
raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.')
return arg_value
@staticmethod
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
"""Is it necessary to consider error when comparing float values."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_subclass(arg_name, type_, template_type, with_type_of=True):
"""Check whether some type is sublcass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
@staticmethod
def check_args_tensor(args):
"""Check whether args are all tensor."""
if not isinstance(args, dict):
raise TypeError("The args should be a dict.")
for arg, value in args.items():
ParamValidator.check_subclass(arg, value, mstype.tensor)
@staticmethod
def check_type(arg_name, arg_value, valid_types):
"""Type checking."""
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise ValueError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
# `check_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()
@staticmethod
def check_typename(arg_name, arg_type, valid_types):
"""Does it contain the _name_ attribute."""
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
if isinstance(arg_type, type(mstype.tensor)):
arg_type = arg_type.element_type()
if arg_type in valid_types:
return arg_type
type_names = [get_typename(t) for t in valid_types]
if len(valid_types) == 1:
raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
@staticmethod
def check_string(arg_name, arg_value, valid_values):
"""String type judgment."""
if isinstance(arg_value, str) and arg_value in valid_values:
return arg_value
if len(valid_values) == 1:
raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
f' but got {arg_value}.')
raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
f' but got {arg_value}.')
@staticmethod
def check_type_same(args, valid_values):
"""Determine whether the types are the same."""
name = list(args.keys())[0]
value = list(args.values())[0]
if isinstance(value, type(mstype.tensor)):
value = value.element_type()
for arg_name, arg_value in args.items():
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
if arg_value not in valid_values:
raise TypeError(f'The `{arg_name}` should be in {valid_values},'
f' but `{arg_name}` is {arg_value}.')
if arg_value != value:
raise TypeError(f'`{arg_name}` should be same as `{name}`,'
f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
@staticmethod
def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
"""Determine whether the types of two variables are the same."""
if arg1_type != arg2_type:
raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')
@staticmethod
def check_value_on_integer(arg_name, arg_value, value, rel):
"""Judging integer type."""
rel_fn = Rel.get_fns(rel)
type_match = isinstance(arg_value, int)
if type_match and (not rel_fn(arg_value, value)):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_param_equal(param1_name, param1_value, param2_name, param2_value):
"""Judging the equality of parameters."""
if param1_value != param2_value:
raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
f" but got `{param1_name}` = {param1_value},"
f" `{param2_name}` = {param2_value}.")
@staticmethod
def check_const_input(arg_name, arg_value):
"""Check valid value."""
if arg_value is None:
raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
@staticmethod
def check_float_positive(arg_name, arg_value):
"""Float type judgment."""
if isinstance(arg_value, float):
if arg_value > 0:
return arg_value
raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
raise TypeError(f"`{arg_name}` must be float!")
@staticmethod
def check_pad_value_by_mode(op_name, pad_mode, padding):
"""Validate value of padding according to pad_mode"""
if pad_mode != 'pad' and padding != 0:
raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
return padding
@staticmethod
def check_empty_shape_input(arg_name, arg_value):
"""Check zeros value."""
if 0 in arg_value:
raise ValueError(f"Input `{arg_name}` cannot be empty.")
@staticmethod
def check_scalar_shape_input(arg_name, arg_value):
"""Check scalar shape input."""
if arg_value != []:
raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
def check_int(input_param):
"""Int type judgment."""
if isinstance(input_param, int) and not isinstance(input_param, bool):
return input_param
raise TypeError("Input type must be int!")
def check_int_positive(input_param):
"""Int type judgment."""
if isinstance(input_param, bool):
raise TypeError("Input type must be int cannot be bool!")
if isinstance(input_param, int):
if input_param > 0:
return input_param
raise ValueError("The input_param must be positive, but got input_param {}.".format(input_param))
raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))
def check_int_non_negative(input_param):
"""Non_negative type judgment."""
if isinstance(input_param, bool):
raise TypeError("Input type must be int cannot be bool!")
if isinstance(input_param, int):
if input_param >= 0:
return input_param
raise ValueError("The input_param must be non_negative, but got input_param {}.".format(input_param))
raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))
def check_int_zero_one(input_param):
"""Judge whether it is 0 or 1."""
if input_param in (0, 1):
return input_param
raise ValueError("The data must be 0 or 1.")
def check_bool(input_param):
"""Bool type judgment."""
if isinstance(input_param, bool):
return input_param
raise TypeError("Input type must be bool!")
def check_input_format(input_param):
"""Judge input format."""
if input_param == "NCHW":
return input_param
raise ValueError("The data format must be NCHW.")
def check_padding(padding):
"""Check padding."""
if padding >= 0:
return padding
raise ValueError("The padding must be at least 0,"" but got padding {}.".format(padding))
def check_padmode(mode):
"""Check padmode."""
if mode in ("same", "valid", "pad"):
return mode
raise ValueError("The pad mode must be same or valid or pad,"" but got mode {}.".format(mode))
def check_tensor_supported_type(dtype):
"""Check tensor dtype."""
if dtype in (mstype.int32, mstype.float32):
return dtype
raise ValueError("The dtype must be mstype.int32 or mstype.float32, but got mstype {}.".format(dtype))
def _expand_tuple(n_dimensions):
"""To expand a number to tuple."""
def convert(m):
if not isinstance(m, tuple):
if isinstance(m, int):
return tuple(repeat(m, n_dimensions))
raise TypeError("Input type must be int or tuple.")
if not len(m) is n_dimensions:
raise TypeError("Input dimension is incorrect.")
for i in m:
if not isinstance(i, int):
raise TypeError("Incorrect type inside of a tuple!")
return m
return convert
def check_input_data(*data, data_class):
"""Input data check."""
for item in data:
if isinstance(item, (list, tuple)):
for v in item:
check_input_data(v, data_class=data_class)
else:
if not isinstance(item, data_class):
raise ValueError(f'Please provide as model inputs'
f' either a single'
f' or a list of {data_class.__name__},'
f' but got part data type is {str(type(item))}.')
if item.size() == 0:
msg = "Please provide non-empty data."
logger.error(msg)
raise ValueError(msg)
def check_output_data(data):
"""Output data check."""
if not data:
raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')
def check_axis_type_int(axis):
"""Check axis type."""
if not isinstance(axis, int):
raise TypeError('Wrong type for axis, should be int.')
def check_axis_range(axis, rank):
"""Check axis range."""
if not -rank <= axis < rank:
raise ValueError('The axis should be in range [{}, {}),'' but got {}.'.format(-rank, rank, axis))
def check_attr_int(attr_name, attr):
"""Check int type."""
if not isinstance(attr, int):
raise TypeError("The attr {} should be int, but got {}.".format(attr_name, type(attr)))
def check_t_in_range(t):
"""Check input range."""
if t not in (mstype.float16, mstype.float32, mstype.float64, mstype.int32, mstype.int64):
raise ValueError("The param T should be (float16, float32, float64, int32, int64).")
once = _expand_tuple(1)
twice = _expand_tuple(2)
triple = _expand_tuple(3)
valid_data_types = (int, float, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
np.float32, np.float64, bool, np.bool_)
def check_type(arg_name, arg_value, valid_types):
"""Check value type."""
# if input type is Tensor ,get element type
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
# First, check if arg_value has argvalid_types
if isinstance(arg_value, tuple(valid_types)):
return type(arg_value).__name__
# Second, wrap arg_value with numpy array so that it can be checked through numpy api
if isinstance(arg_value, (list, tuple)):
arg_value = np.array(arg_value)
# Thirdly, check the data type by numpy's dtype api
valid = False
if isinstance(arg_value, np.ndarray):
valid = arg_value.dtype in valid_data_types
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
# `check_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
valid = False
if not valid:
type_names = [t.__name__ for t in valid_types]
if len(valid_types) == 1:
raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
f' but got {type(arg_value).__name__}.')
raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
f' but got {type(arg_value).__name__}.')
return type(arg_value).__name__
def check_typename(arg_name, arg_type, valid_types):
"""Check type name."""
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
if isinstance(arg_type, type(mstype.tensor)):
arg_type = arg_type.element_type()
if arg_type in valid_types:
return arg_type
if isinstance(arg_type, tuple(valid_types)):
return arg_type
type_names = [get_typename(t) for t in valid_types]
if len(valid_types) == 1:
raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
def check_shape(arg_name, arg_value):
"""Check shape."""
# First, check if shape is a tuple
if not isinstance(arg_value, tuple):
raise TypeError(f'The type of `{arg_name}` should be one of {tuple.__name__},'
f' but got {type(arg_value).__name__}.')
# Second, wrap arg_value with numpy array so that it can be checked through numpy api
arg_value = np.array(arg_value)
# shape can not be ()
if arg_value.size == 0:
raise ValueError('Shape can not be empty.')
# shape's dimension should be 1
if arg_value.ndim != 1:
raise ValueError('Shape of tensor should be 1-dim vector, but got {}-dim.'.format(arg_value.ndim))
# Thirdly, check each element's type of the shape
valid_types = (int, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64)
for dim_size in arg_value:
if not isinstance(dim_size, valid_types) or dim_size <= 0:
raise ValueError('Every dimension size of the tensor shape should be a positive integer,'
' but got {}.'.format(dim_size))
def _check_str_by_regular(target, reg=None, flag=re.ASCII):
if reg is None:
reg = _name_re
if re.match(reg, target, flag) is None:
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
return True

View File

@ -0,0 +1,22 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Extension functions.
Python functions that will be called in the c++ parts of MindSpore.
"""
from .utils import cell_attr_register
__all__ = ["cell_attr_register"]

View File

@ -0,0 +1,161 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""builtin_operations"""
import functools
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
def scalar_add(x, y):
"""Implement `scalar_add`."""
return x + y
def scalar_mul(x, y):
"""Implement `scalar_mul`."""
return x * y
def scalar_mod(x, y):
"""Implement `scalar_mul`."""
return x % y
def scalar_sub(x, y):
"""Implement `scalar_sub`."""
return x - y
def scalar_usub(x):
"""Implement `scalar_usub`."""
return -x
def tuple_getitem(x, index):
"""Implement `tuple_getitem`."""
if isinstance(x, Tensor):
x = x.asnumpy()
y = x[index]
return Tensor(y)
return x[index]
def scalar_gt(x, y):
"""Implement `scalar_gt`."""
return x > y
def scalar_ne(x, y):
"""Implement `scalar_ne`."""
return x != y
def scalar_eq(x, y):
"""Implement `scalar_eq`."""
return x == y
def scalar_le(x, y):
"""Implement `scalar_le`."""
return x <= y
def scalar_lt(x, y):
"""Implement `scalar_lt`."""
return x < y
def identity(x):
"""Implement `identity`."""
return x
def zeros_like_tensor(x):
"""Implement `zeros_like_tensor`."""
x = x.asnumpy()
value = Tensor(np.zeros(x.shape))
return value
def switch(c, x, y):
"""Implement `switch`."""
return x if c else y
def list_getitem(data, item):
"""Implement `list_getitem`."""
return data[item]
def bool_not(x):
"""Implement `bool_not`."""
return not x
def bool_and(x, y):
"""Implement `bool_and`."""
return x and y
def bool_or(x, y):
"""Implement `bool_or`."""
return x or y
def make_list(*xs):
"""Implement `make_list`."""
return list(xs)
def list_len(x):
"""Implement `list_len`."""
return len(x)
# only used in PyNative modes
def partial(*args):
"""Implement `partial`."""
func = args[0].__call__
partial_func = functools.partial(func, *args[1:])
return partial_func
# only used in PyNative modes
def depend(value, expr):
return value
def scalar_cast(x, t):
"""Implement scalar_cast."""
np_type = dtype_to_nptype(t)
value = np_type(x)
cast_value = np.ndarray.item(value)
return cast_value
def typeof(x):
"""Implement typeof."""
return get_py_obj_dtype(x)
def tuple_to_array(x):
"""Implement `tuple_to_array`."""
return Tensor(np.array(x))
def stop_gradient(x):
"""Implement `stop_gradient`."""
return x

View File

@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Extension functions.
Python functions that will be called in the c++ parts of MindSpore.
"""

View File

@ -0,0 +1,107 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing multi process compile with json"""
import json
import math
import os
import subprocess
import sys
from multiprocessing import Pool
def _compiletask(platform, *jsons):
"""
compile func called in single process
Parameters:
platform: str. AKG platform or TBE platform
*jsons: str. json str contain kernel info, suitable for json compile
api
"""
if platform == "AKG":
p = __import__("akg", globals(), locals(), ['ms'], 0)
func = getattr(p.ms, "compilewithjson")
for json_item in jsons:
res = func(json_item)
if not res:
raise ValueError("Compile error")
if platform == "TBE":
tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "tbe_compiler", "compiler.py")
for json_item in jsons:
res = subprocess.run([sys.executable, tbe_compiler], input=json_item, text=True)
if res.returncode != 0:
raise ValueError("Tbe compile error")
def compilekernelparallel(jsons, process, waitime):
"""
compile kernel use multi processes
Parameters:
jsons: list. json str list contain kernel info
process: int. processes num
waittime: int. max time the function blocked
"""
if not isinstance(jsons, list):
raise ValueError("jsons must be a list")
if not isinstance(process, int):
raise ValueError("process must be a num")
if not isinstance(waitime, int):
raise ValueError("waittime must be a num")
jsons_akg = []
jsons_tbe = []
for json_ in jsons:
j = json.loads(json_)
if j["platform"] == "TBE":
jsons_tbe.append(json_)
continue
if j["platform"] == "AKG":
jsons_akg.append(json_)
continue
raise RuntimeError(
"not support this platform {0}".format(j["platform"]))
if jsons_akg:
process_akg = math.floor(len(jsons)/len(jsons_akg)*process)
else:
process_akg = 0
if process_akg == 0 and jsons_akg:
process_akg = 1
process_tbe = process-process_akg
if process_tbe == 0 and jsons_tbe:
process_tbe = 1
raise RuntimeWarning("we add a process for compile more operator")
args = [[] for _ in range(process_akg+process_tbe)]
args_lens = len(args)
for p in range(args_lens):
if p < process_tbe:
args[p].append("TBE")
else:
args[p].append("AKG")
jsons_tbe_lens = len(jsons_tbe)
for p in range(jsons_tbe_lens):
args[p % process_tbe].append(jsons_tbe[p])
jsons_akg_lens = len(jsons_akg)
for p in range(jsons_akg_lens):
args[process-p % process_akg-1].append(jsons_akg[p])
for p in range(args_lens):
args[p] = tuple(args[p])
with Pool(processes=process) as pool:
res = pool.starmap_async(_compiletask, args)
res.get(timeout=waitime)
return True

View File

@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,138 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""tbe common"""
import json
import os
class TBEException(Exception):
"""tbe exception class"""
def __init__(self, err_msg):
super().__init__(self)
self.__error_msg = err_msg
def __str__(self):
return self.__error_msg
def get_ddk_version():
"""get ddk version"""
ddk_version = os.environ.get("DDK_VERSION")
if ddk_version is None:
default_ddk_info_file = '/usr/local/HiAI/runtime/ddk_info'
backup_ddk_info_file = '/usr/local/Ascend/fwkacllib/ddk_info'
if os.path.exists(default_ddk_info_file):
with open(default_ddk_info_file, "r") as fp:
ddk_version = json.load(fp)["VERSION"]
elif os.path.exists(backup_ddk_info_file):
with open(backup_ddk_info_file, "r") as fp:
ddk_version = json.load(fp)["VERSION"]
else:
ddk_version = "1.60.T17.B830"
return ddk_version
def get_build_in_impl_path():
"""get build-in tbe implement path"""
tbe_impl_path = os.environ.get("TBE_IMPL_PATH")
if tbe_impl_path is None:
default_install_path = '/usr/local/HiAI/runtime/ops/op_impl/built-in/ai_core/tbe/impl/'
backup_install_path = '/usr/local/Ascend/Ascend/opp/op_impl/built-in/ai_core/tbe/impl/'
if os.path.exists(default_install_path):
tbe_impl_path = default_install_path
elif os.path.exists(backup_install_path):
tbe_impl_path = backup_install_path
if not tbe_impl_path:
raise ValueError("Can not find the env TBE_IMPL_PATH")
return tbe_impl_path
def _check_arg_info(item):
"""
Check parameter Validity.
Args:
item (dict): A dict, to be checked.
Raises:
Exception: If specific keyword is not found.
"""
if 'shape' not in item:
raise ValueError("Json string Errors, key:shape not found.")
if 'ori_shape' not in item:
raise ValueError("Json string Errors, key:ori_shape not found.")
if 'format' not in item or not item['format']:
raise ValueError("Json string Errors, key:format not found.")
if 'ori_format' not in item or not item['ori_format']:
raise ValueError("Json string Errors, key:ori_format not found.")
if 'dtype' not in item or not item['dtype']:
raise ValueError("Json string Errors, key:dtype not found.")
def get_args(op_info, arg_type):
"""
Parse args.
Args:
op_info (dict): Op info dict.
arg_type (str): arg, to be parsed.
Raises:
Exception: If specific keyword is not found.
"""
if arg_type not in op_info:
raise ValueError("Json string Errors, key:{} not found.".format(arg_type))
args = []
if not op_info[arg_type]:
return args
if arg_type in ['inputs', 'outputs']:
for item in op_info[arg_type]:
arg = []
for info in item:
if 'valid' not in info:
raise ValueError("Json string Errors, key:valid not found.")
if info['valid']:
_check_arg_info(info)
del info['valid']
del info['name']
if len(item) > 1:
arg.append(info)
else:
args.append(info)
else:
if len(item) > 1:
arg.append(None)
else:
args.append(None)
if len(item) > 1:
args.append(arg)
elif arg_type == 'attrs':
for item in op_info[arg_type]:
if 'value' not in item:
raise ValueError("Json string Errors, attr key:value not found.")
if item["name"] != "isRef":
args.append(item['value'])
return args
def check_kernel_info(kernel_info):
if 'op_info' not in kernel_info or not kernel_info['op_info']:
raise ValueError("Json string Errors, key:op_info not found.")
if 'name' not in kernel_info['op_info'] or not kernel_info['op_info']['name']:
raise ValueError("Json string Errors, key:name not found.")
if 'kernel_name' not in kernel_info['op_info'] or not kernel_info['op_info']['kernel_name']:
raise ValueError("Json string Errors, key:kernel_name not found.")

View File

@ -0,0 +1,166 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""tbe compiler"""
import json
import os
import sys
from te.platform.cce_conf import te_set_version
from te.platform.fusion_manager import op_build_cfg_dis, op_build_cfg_en, set_current_op_name, \
init_op_pattern, set_op_params, set_op_build_type, get_op_pattern, set_current_op_func_name
from te.platform.fusion_util import fusion_op
from common import check_kernel_info, get_args, get_build_in_impl_path, get_ddk_version
ddk_version = get_ddk_version()
build_in_impl_path = get_build_in_impl_path()
# op function list
op_build = "compile"
op_pre_build = "pre_build"
def _initialize(impl_path):
"""Initialize"""
te_set_version(ddk_version)
if impl_path == "":
op_module_name = build_in_impl_path
else:
op_module_name = impl_path
if not op_module_name:
raise ValueError("Can not find the env TBE_IMPL_PATH")
sys.path.insert(0, op_module_name)
def build_op(build_type, json_str):
"""
call op functions with function name and input args json_str
Args:
build_type : op function name
json_str (str): op function input args
Raises:
Exception: If specific keyword is not found.
"""
kernel_info = json.loads(json_str)
check_kernel_info(kernel_info)
# import module
op_name = kernel_info['op_info']['name']
try:
custom_flag = False
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
impl_path = os.path.realpath(kernel_info['impl_path'])
if os.path.isfile(impl_path):
path, file_name = os.path.split(impl_path)
op_name, _ = os.path.splitext(file_name)
impl_path = path
custom_flag = True
else:
impl_path = ""
_initialize(impl_path)
inputs_args = get_args(kernel_info['op_info'], 'inputs')
outputs_args = get_args(kernel_info['op_info'], 'outputs')
attrs_args = get_args(kernel_info['op_info'], 'attrs')
kernel_name = kernel_info['op_info']['kernel_name']
if custom_flag:
op_module = __import__(op_name)
else:
op_module = __import__("impl."+op_name, globals(), locals(), [op_name], 0)
# get function
if build_type == op_pre_build:
# set op parameter
op_build_cfg_dis()
set_current_op_func_name(op_name)
set_current_op_name(kernel_name)
init_op_pattern()
set_op_params(*outputs_args, *attrs_args, kernel_name=kernel_name)
set_op_build_type('prebuild')
if custom_flag:
py_fn_name = kernel_info['op_info']['name']
else:
py_fn_name = op_name
elif build_type == op_build:
if custom_flag:
py_fn_name = kernel_info['op_info']['name']
else:
py_fn_name = op_name
else:
raise ValueError("function {} is not supported by Tbe op {}.".format(build_type, op_name))
op_func = getattr(op_module, py_fn_name, None)
if op_func is None:
raise ValueError("Op:{} function {} is not supported by Tbe.".format(op_name, build_type))
# pre build
if build_type == op_pre_build:
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name)
# disable only pattern configuration
op_build_cfg_en()
return get_op_pattern()
# call function
return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
except Exception as e:
if build_type == op_pre_build:
op_build_cfg_en()
raise RuntimeError(e)
def compile_fusion_op(json_str):
"""
compile fusion op with input args json_str
Args:
json_str (str): op function input args
Raises:
Exception: If specific keyword is not found.
"""
args = json.loads(json_str)
if 'fusion_op' not in args or not args['fusion_op']:
raise ValueError("Json string Errors, key:fusion_op not found.")
if 'prebuild_ops' not in args or not args['prebuild_ops']:
raise ValueError("Json string Errors, key:prebuild_ops not found.")
pre_build_op_list = args['prebuild_ops']
for op in pre_build_op_list:
build_op(op_pre_build, json.dumps(op))
fusion_op_arg = args['fusion_op']
return fusion_op(json.dumps(fusion_op_arg))
def compile_with_json(json_str):
"""
Compile tbe with json.
Args:
json_str (str): jason file path.
"""
json_info = json.loads(json_str)
if "fusion_op" in json_info:
ret = compile_fusion_op(json_str)
else:
ret = build_op(op_build, json_str)
return ret
if __name__ == "__main__":
in_args = sys.stdin.readline()
compile_with_json(in_args)

View File

@ -0,0 +1,210 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""tbe process"""
import traceback
import multiprocessing
import subprocess
import sys
import os
import json
from .common import check_kernel_info, get_args, get_build_in_impl_path
build_in_impl_path = get_build_in_impl_path()
def create_tbe_parallel_compiler():
"""
create TBEParallelCompiler object
Returns:
TBEParallelCompiler
"""
return compile_pool
def op_select_format(op_json: str):
"""
call op's op_select_format to get op supported format
Args:
op_json (str): json string of the op
Returns:
op supported format
"""
ret = ""
kernel_info = json.loads(op_json)
check_kernel_info(kernel_info)
# import module
op_name = kernel_info['op_info']['name']
impl_path = build_in_impl_path
custom_flag = False
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
op_impl_path = os.path.realpath(kernel_info['impl_path'])
if os.path.isfile(impl_path):
path, file_name = os.path.split(op_impl_path)
op_name, _ = os.path.splitext(file_name)
impl_path = path
custom_flag = True
sys.path.insert(0, impl_path)
if custom_flag:
op_module = __import__(op_name)
else:
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
# get function
if not hasattr(op_module, "op_select_format"):
return ""
op_func = getattr(op_module, "op_select_format", None)
# call function
inputs_args = get_args(kernel_info['op_info'], 'inputs')
outputs_args = get_args(kernel_info['op_info'], 'outputs')
attrs_args = get_args(kernel_info['op_info'], 'attrs')
kernel_name = kernel_info['op_info']['kernel_name']
ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
return ret
def check_supported(op_json: str):
"""
call op's check_supported to check supported or not
Args:
op_json (str): json string of the op
Returns:
true or false
"""
ret = ""
kernel_info = json.loads(op_json)
check_kernel_info(kernel_info)
# import module
op_name = kernel_info['op_info']['name']
impl_path = build_in_impl_path
custom_flag = False
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
op_impl_path = os.path.realpath(kernel_info['impl_path'])
if os.path.isfile(impl_path):
path, file_name = os.path.split(op_impl_path)
op_name, _ = os.path.splitext(file_name)
impl_path = path
custom_flag = True
sys.path.insert(0, impl_path)
if custom_flag:
op_module = __import__(op_name)
else:
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
# get function
if not hasattr(op_module, "check_supported"):
return ""
op_func = getattr(op_module, "check_supported", None)
# call function
inputs_args = get_args(kernel_info['op_info'], 'inputs')
outputs_args = get_args(kernel_info['op_info'], 'outputs')
attrs_args = get_args(kernel_info['op_info'], 'attrs')
kernel_name = kernel_info['op_info']['kernel_name']
ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
return ret
def run_compiler(op_json):
"""
run compiler to compile op with subprocess
Args:
op_json (str): json string of the op
Returns:
result type, result.
"""
try:
tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "compiler.py")
subprocess.run([sys.executable, tbe_compiler], input=op_json, timeout=300,
text=True, capture_output=True, check=True)
return "Success", "Success"
except subprocess.TimeoutExpired:
tb = traceback.format_exc()
return "TBEException", "CompileTimeOut: " + tb + "\ninput_args: " + op_json
except subprocess.CalledProcessError as e:
return "TBEException", "CompileProcessFailed:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json
class CompilerPool:
"""compiler pool"""
def __init__(self):
processes = multiprocessing.cpu_count()
if processes > 16:
processes = 16
self.__pool = multiprocessing.Pool(processes=processes)
self.__next_task_id = 1
self.__running_tasks = []
def __del__(self):
if self.__pool is not None:
self.__pool.terminate()
self.__pool.join()
del self.__pool
def exit(self):
return
# self.__pool.terminate()
# self.__pool.join()
# if self.__pool is not None:
# del self.__pool
def start_compile_op(self, op_json):
"""
start compile op async.
Args:
op_json (str): json string of the op
Returns:
int, task id(>0). -1 if error
"""
task_id = self.__next_task_id
self.__next_task_id = self.__next_task_id + 1
task_future = self.__pool.apply_async(func=run_compiler, args=(op_json,))
self.__running_tasks.append((task_id, task_future))
return task_id
def wait_one(self):
"""
wait until a compile task finish
Returns:
int, id of the finished task. -1 if error,0 if no unfinished task
str, result of compile task
"""
ret = 0, "Success"
if self.__running_tasks:
task_id, task_future = self.__running_tasks.pop(0)
ret_type, result = task_future.get(330)
if ret_type == "Success":
ret = task_id, "Success"
elif ret_type in ("Exception", "TBEException"):
ret = task_id, ret_type + ":" + result
else:
ret = task_id, "Exception: Not support return type:" + str(ret_type)
return ret
compile_pool = CompilerPool()

View File

@ -0,0 +1,32 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Interfaces for parser module in c++.
"""
from .parser import (Parser, create_obj_instance, generate_scope,
get_bprop_method_of_class, get_class_instance_type,
get_class_member_namespace_symbol,
get_dataclass_attributes, get_dataclass_methods,
get_module_namespace, get_obj_type, get_object_key,
get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol)
from .serialize import *
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
'get_dataclass_methods', 'get_scope_name']

View File

@ -0,0 +1,110 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the namespace of parse."""
import builtins
class Namespace:
"""
Base class of namespace for resolve variables.
Args:
name (str): The namespace's name.
dicts (dict): A list of dict containing the namespace's variable.
"""
def __init__(self, name, *dicts):
self.name = name
self.dicts = dicts
def __contains__(self, name):
for d in self.dicts:
if name in d:
return True
return False
def __getitem__(self, name):
for d in self.dicts:
if name in d:
return d[name]
raise NameError(name)
def __repr__(self):
return f'Namespace:{self.name}'
class CellNamespace(Namespace):
"""
Namespace for Cell object.
Args:
name (str): Valid module name, it can be imported.
"""
def __init__(self, name):
mod_dict = vars(__import__(name, fromlist=['_']))
builtins_dict = vars(builtins)
super().__init__(name, mod_dict, builtins_dict)
def __getstate__(self):
return (self.name,)
def __setstate__(self, state):
name, = state
mod_dict = vars(__import__(name, fromlist=['_']))
builtins_dict = vars(builtins)
super().__init__(name, mod_dict, builtins_dict)
class ClosureNamespace(Namespace):
"""
Namespace for function closure.
Args:
fn (Function): A python function.
"""
def __init__(self, fn):
name = f'{fn.__module__}..<{fn.__name__}>'
names = fn.__code__.co_freevars
cells = fn.__closure__
ns = dict(zip(names, cells or ()))
super().__init__(name, ns)
def __getitem__(self, name):
d, = self.dicts
try:
return d[name].cell_contents
except ValueError:
raise UnboundLocalError(name)
class ClassMemberNamespace(Namespace):
"""
Namespace of a class's closure.
Args:
obj (Object): A python class object.
"""
def __init__(self, obj):
label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>'
super().__init__(label, obj)
def __getitem__(self, name):
d, = self.dicts
try:
return getattr(d, name)
except ValueError:
raise UnboundLocalError(name)

View File

@ -0,0 +1,488 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The module of parser python object, called by c++."""
import ast
import types
import inspect
from textwrap import dedent
from dataclasses import is_dataclass
import asttokens
import mindspore.nn as nn
from mindspore import log as logger
from mindspore import ops
from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.api import _MindSporeFunction
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
# define return value
RET_SUCCESS = 0
RET_FAILURE = 0xFF
# define resolve type
RESOLVE_TYPE_NONE = 0 # resolve None
RESOLVE_TYPE_FUNCTION = 1 # resolve function
RESOLVE_TYPE_METHOD = 2 # resolve class method
RESOLVE_TYPE_CLASS_TYPE = 3 # resolve class type
RESOLVE_TYPE_CLASS_INSTANCE = 4 # resolve the class instance of common class
RESOLVE_TYPE_INVALID = 0xFF
# define the class instance detail type
# When the type is RESOLVE_TYPE_CLASS_INSTANCE
CLASS_INSTANCE_TYPE_CELL = 0 # class instance type is Cell
CLASS_INSTANCE_TYPE_PRIMITIVE = 1 # class instance type is Primitive
CLASS_INSTANCE_TYPE_INVALID = 0xFF
# Ast main type
AST_MAIN_TYPE_STMT = 0 # ast.Stmt
AST_MAIN_TYPE_EXPR = 1 # ast.Expr
AST_MAIN_TYPE_SLICE = 2 # ast.Slice
AST_MAIN_TYPE_UNKNOWN = 0xFF # unknown
# Ast sub type
AST_SUB_TYPE_AND = 3 # ast.And
AST_SUB_TYPE_OR = 4 # ast.Or
AST_SUB_TYPE_NAME = 5 # ast.Name
AST_SUB_TYPE_TUPLE = 6 # ast.Tuple
AST_SUB_TYPE_SUBSCRIPT = 7 # ast.Subscript
AST_SUB_TYPE_STARRED = 8 # ast.Starred
AST_SUB_TYPE_UNKNOWN = 0xFF # unknown
# Process expr statement white list
# add as needed, eg: "clear", "extend", "insert", "remove", "reverse"
parse_expr_statement_white_list = (
"append",
)
def parse_cb(func, parse_method=None):
"""Implements the function of parse."""
return Parser(func, parse_method)
def get_parse_method_of_class(obj, parse_method=None):
"""
Het parse method of class.
Args:
obj(Object): Instance of class.
parse_method(str): Save the method name. Cell object has default method named 'construct'.
Returns:
Function, obj's method.
"""
method = None
method_name = None
if parse_method is not None:
method_name = parse_method
else:
if isinstance(obj, nn.Cell):
method_name = "construct"
if method_name is not None:
if hasattr(obj, method_name):
method = getattr(obj, method_name)
return method
def get_bprop_method_of_class(obj, parse_method=None):
"""
Get bprop method of class.
Args:
obj (Object): Instance of class.
parse_method(str): Save the method name. Cell object has default method named 'bprop'.
Returns:
Function, obj's method.
"""
method = None
if isinstance(obj, nn.Cell):
method_name = "bprop"
if hasattr(obj, method_name):
method = getattr(obj, method_name)
return method
def resolve_symbol(namespace, symbol):
"""
Resolve a symbol.
Note:
Can't get function when use closure function. So save the fn on namespace.
Args:
namespace (Object): Symbol's namespace.
symbol (str): Need resolve symbol.
Returns:
Object, resolve result of symbol.
"""
# All exceptions need to be caught in this function
try:
resolve_ = namespace[symbol]
# list and dict is not hashable ,it can not be key for the map, just return the result
if isinstance(resolve_, (list, dict)):
return resolve_
# dataclass may not be hashable
if getattr(resolve_, "__hash__") is None:
return resolve_
# If need trope the obj
if resolve_ in convert_object_map:
resolve_ = convert_object_map.get(resolve_)
logger.debug("convert resolve = %r", resolve_)
if resolve_ == NO_IMPLEMENT:
raise NotImplementedError("not implemented for ", str(symbol))
except Exception as e:
if isinstance(e, NotImplementedError):
raise e
resolve_ = None
logger.debug("resolve exception occurred, value = %r", e)
logger.debug("resolve type is invalid, namespace = %s, symbol = %s",
namespace.__str__(), symbol)
if isinstance(resolve_, _MindSporeFunction):
logger.debug("resolve class _MindSporeFunction, resolve fn instead.")
resolve_ = resolve_.fn
return resolve_
def generate_scope(obj):
"""Generate the scope for every cell object in the network."""
if isinstance(obj, nn.Cell):
obj.generate_scope()
def get_scope_name(obj):
"""Returns the scope of a cell object in one network."""
if isinstance(obj, nn.Cell):
return obj.get_scope()
return None
def get_object_key(obj):
"""Return the function key: module + name."""
obj_key = ""
if hasattr(obj, "__name__"):
if hasattr(obj, "cell_init_args"):
obj_key = "%s_ID" % (str(obj.__class__.__name__) + str(obj.__name__) + obj.cell_init_args)
obj_id = "%s_ID%d" % (str(obj.__class__.__name__) + str(obj.__name__), id(obj))
else:
if hasattr(obj, "cell_init_args"):
obj_key = "%s_ID" % (str(obj.__class__.__name__) + obj.cell_init_args)
obj_id = "%s_ID%d" % (str(obj.__class__.__name__), id(obj))
logger.debug("obj_key %s obj_id = %s", obj_key, obj_id)
# method has same id of different instance
if isinstance(obj, types.MethodType):
method_instance = obj.__self__
instance_id = "%s_ID%d" % (str(method_instance.__class__.__name__), id(method_instance))
obj_id = instance_id + obj_id
return obj_id, obj_key
def is_class_member(node):
"""Check the attr is class member variable."""
type_ = node.__class__.__name__
if type_ == "Attribute":
if not hasattr(node.value, "id"):
return False
id_ = node.value.id
if id_ == "self":
return True
return False
def get_obj_type(obj):
"""Get the obj type."""
obj_type = RESOLVE_TYPE_INVALID
if obj is None:
obj_type = RESOLVE_TYPE_NONE
elif isinstance(obj, types.FunctionType):
obj_type = RESOLVE_TYPE_FUNCTION
elif isinstance(obj, types.MethodType):
obj_type = RESOLVE_TYPE_METHOD
elif isinstance(obj, type):
obj_type = RESOLVE_TYPE_CLASS_TYPE
elif _is_class_instance(obj):
obj_type = RESOLVE_TYPE_CLASS_INSTANCE
else:
# here for ndarray, just print its shape (in case of the array to large and print many data in screen)
is_ndarray = type(obj).__name__ == 'ndarray' and hasattr(obj, 'shape')
raise TypeError(f'Invalid object with type `{type(obj)}` and {"shape" if is_ndarray else "value"} '
f'`{obj.shape if is_ndarray else obj}`.')
return obj_type
def get_class_instance_type(obj):
"""Get the class instance detail type."""
# check the obj type
logger.debug("Get the class type(%r)", obj)
class_type = CLASS_INSTANCE_TYPE_INVALID
if _is_class_instance(obj):
if isinstance(obj, nn.Cell):
class_type = CLASS_INSTANCE_TYPE_CELL
elif isinstance(obj, ops.Primitive):
class_type = CLASS_INSTANCE_TYPE_PRIMITIVE
# Add the other type base requirement
return class_type
def _is_class_instance(obj):
"""Confirm the obj is class instance."""
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj)
def _is_dataclass_instance(obj):
"""check whether a class is an instance of a dataclass (and not a dataclass itself)"""
return is_dataclass(obj) and not isinstance(obj, type)
def create_obj_instance(cls_type, args_tuple=None):
"""Create python instance."""
obj = None
if isinstance(cls_type, type):
# check the type, now only support nn.Cell and Primitive
if issubclass(cls_type, (nn.Cell, ops.Primitive)):
if args_tuple is not None:
obj = cls_type(*args_tuple)
else:
obj = cls_type()
return obj
def get_module_namespace(obj):
"""Get the module's namespace."""
logger.debug("get module namespace, module = %r", obj)
mod_namespace = None
if isinstance(obj, types.ModuleType):
mod_namespace = CellNamespace(obj.__name__)
else:
logger.warning("Module(%r) is invalid, get namespace failure!", obj)
return mod_namespace
def get_class_member_namespace_symbol(obj):
"""Get obj class member type."""
logger.debug("get class instance namespace, object = %r", obj)
class_namespace = ClassMemberNamespace(obj)
logger.debug("class namesapce = %r", class_namespace)
return class_namespace
def get_dataclass_attributes(cls):
"""Get attributes of dataclass."""
fields = cls.__dataclass_fields__
attributes = {name: pytype_to_dtype(field.type)
for name, field in fields.items()}
return attributes
def get_dataclass_methods(cls):
"""Get functions of dataclass."""
methods = {name: getattr(cls, name)
for name in dir(cls)
if isinstance(getattr(cls, name), (types.FunctionType,))}
return methods
class Parser:
"""
Parser python code to ast tree.
Args:
fn(FunctionType/MethodType): Need parse object instance.
parse_method(ExtendInfoOfParseObj): Extend information for parse the function.
"""
def __init__(self, fn: (types.FunctionType, types.MethodType), parse_method=None) -> None:
self.fn = fn
self.parse_method = parse_method
_, self.line_offset = inspect.getsourcelines(self.fn)
self.filename: str = inspect.getfile(self.fn)
# Used to resolve the function's globals Namespace.
self.global_namespace = CellNamespace(fn.__module__)
self.function_module = fn.__module__
# Used to resolve the function's nonlocals.
self.closure_namespace = ClosureNamespace(fn)
self.function_name = fn.__name__
self.col_offset = 0
def parse(self):
"""Parse the function or method."""
logger.debug("fn = %r", self.fn)
tree = None
if isinstance(self.fn, (types.FunctionType, types.MethodType)):
original_src = inspect.getsource(self.fn)
src = dedent(original_src)
self.col_offset = \
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
logger.debug("get source = %s", src)
tree = asttokens.ASTTokens(src, parse=True).tree
else:
logger.error("Fn type is invalid")
return tree
def get_args(self, node):
"""Get the arg of parse object."""
args = []
# process position args
for arg in node.args.args:
args.append(arg)
# process kwonlyargs: kwonlyargs is append after position args
if node.args.kwonlyargs:
for kwarg in node.args.kwonlyargs:
args.append(kwarg)
# process vararg: vararg is append after kwonlyargs
if node.args.vararg:
args.append(node.args.vararg)
# process kwarg: kwarg is append after vararg
if node.args.kwarg:
args.append(node.args.kwarg)
return args
def get_args_default_values(self, node):
"""get the args'default values of parse object."""
nondefaults = [None] * (len(node.args.args) - len(node.args.defaults))
defaults = nondefaults + node.args.defaults + node.args.kw_defaults
if node.args.vararg:
defaults.append(None)
if node.args.kwarg:
defaults.append(None)
return defaults
def get_node_type(self, node):
"""Process an ast node."""
method_name = f'{node.__class__.__name__}'
node_type = [method_name]
# judge the ast main type
if isinstance(node, ast.stmt):
node_type.append(AST_MAIN_TYPE_STMT)
elif isinstance(node, (ast.expr, ast.slice)) or node is None:
# ast.slice and ast.expr should be expr
node_type.append(AST_MAIN_TYPE_EXPR)
else:
node_type.append(AST_MAIN_TYPE_UNKNOWN)
return node_type
def get_ast_type(self, node):
"""Get the ast type."""
ast_type = AST_SUB_TYPE_UNKNOWN
if isinstance(node, ast.And):
ast_type = AST_SUB_TYPE_AND
elif isinstance(node, ast.Or):
ast_type = AST_SUB_TYPE_OR
elif isinstance(node, ast.Name):
ast_type = AST_SUB_TYPE_NAME
elif isinstance(node, ast.Tuple):
ast_type = AST_SUB_TYPE_TUPLE
elif isinstance(node, ast.Subscript):
ast_type = AST_SUB_TYPE_SUBSCRIPT
elif isinstance(node, ast.Starred):
ast_type = AST_SUB_TYPE_STARRED
else:
ast_type = AST_SUB_TYPE_UNKNOWN
return ast_type
def get_namespace_symbol(self, var: str):
"""Get symbol type and namespace and symbol."""
if var in self.closure_namespace:
ops_info = (self.closure_namespace, var)
logger.debug("in closure_namespace")
elif var in self.global_namespace:
ops_info = (self.global_namespace, var)
logger.debug("in global_namespace")
else:
ops_info = parse_object_map.get(SYMBOL_UNDEFINE)
ops_info = [ops_info[0], var]
return ops_info
def get_operation_namespace_symbol(self, var: str):
"""Get operation namespace and symbol."""
ops_info = (trope_ns, var)
logger.debug("get operation ops info = %r", ops_info)
return ops_info
def get_ast_namespace_symbol(self, obj):
"""Get obj type and namespace and symbol."""
# step 1:get symbol from object map
ops_info = parse_object_map.get(type(obj), SYMBOL_UNDEFINE)
logger.debug("ops info = %r", ops_info)
return ops_info
def get_location(self, node):
"""
Get location of node start and end line no.
Args:
node: AST op node or tuple or List. This is a node in the ANF diagram,
here is the code location to get this node.
Returns:
List, [fileName, linestart, colstart, lineend, colend].
"""
ret = [self.filename]
err_exit = 0
if isinstance(node, (list, tuple)):
node_size = len(node)
if node_size == 0:
err_exit = 1
else:
start_node = node[0]
end_node = node[-1]
else:
start_node = node
end_node = node
if err_exit == 0:
if hasattr(start_node, "lineno") and \
hasattr(end_node, "col_offset"):
start_lineno, start_colno = start_node.first_token.start
end_lineno, end_colno = end_node.last_token.end
start_lineno += self.line_offset - 1
start_colno += self.col_offset
end_lineno += self.line_offset - 1
end_colno += self.col_offset
ret = ret + [start_lineno, start_colno, end_lineno, end_colno]
else:
ret = ret + [0, 0, 0, 0]
return ret
def expand_expr_statement(self, node):
"""
Process the expr statement and expand it.
Returns:
tuple, (True, expr.value, x)/(False, None, None).
"""
if isinstance(node, ast.Expr) and hasattr(node, "value"):
expr_value = node.value
if isinstance(expr_value, ast.Call):
func = expr_value.func
if isinstance(func, ast.Attribute) and \
hasattr(func, "attr") and \
hasattr(func, "value"):
method = func.attr
target = func.value
if method in parse_expr_statement_white_list:
logger.debug("Expand expr, target:%s, method:%s", target, method)
return True, expr_value, target
return True, expr_value
return False, None, None

View File

@ -0,0 +1,137 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Resources for ast tree parse."""
import ast
import math
from mindspore.ops.composite import multitype_ops
from mindspore.ops import functional as F, composite as C
from . import standard_method as M
from . import trope as T
from .namespace import CellNamespace
# namespace define
functional_ns = CellNamespace('mindspore.ops.functional')
composite_ns = CellNamespace('mindspore.ops.composite')
trope_ns = CellNamespace('mindspore._extends.parse.trope')
NO_IMPLEMENT = None # not implemented
SYMBOL_UNDEFINE = 0xFF # Undefined var and function
# ops map: {op.type:(Namespace, symbol)}
# Some space set aside for readability of code
parse_object_map = {
# ast grammar
ast.Add: (trope_ns, 'add'),
ast.Sub: (trope_ns, 'sub'),
ast.Mult: (trope_ns, 'mul'),
ast.Div: (trope_ns, 'truediv'),
ast.FloorDiv: (trope_ns, 'floordiv'),
ast.Mod: (trope_ns, 'mod'),
ast.Pow: (trope_ns, 'pow'),
ast.MatMult: (trope_ns, 'matmul'),
ast.LShift: (trope_ns, 'lshift'),
ast.RShift: (trope_ns, 'rshift'),
ast.BitAnd: (trope_ns, 'and_'),
ast.BitOr: (trope_ns, 'or_'),
ast.BitXor: (trope_ns, 'xor'),
ast.UAdd: (trope_ns, 'pos'),
ast.USub: (trope_ns, 'neg'),
ast.Invert: (trope_ns, 'invert'),
ast.Not: (trope_ns, 'not_'),
ast.Eq: (trope_ns, 'eq'),
ast.NotEq: (trope_ns, 'ne'),
ast.Lt: (trope_ns, 'lt'),
ast.Gt: (trope_ns, 'gt'),
ast.LtE: (trope_ns, 'le'),
ast.GtE: (trope_ns, 'ge'),
ast.Is: (trope_ns, 'is_'),
ast.IsNot: (trope_ns, 'is_not'),
ast.In: (trope_ns, 'contains'),
ast.NotIn: (trope_ns, 'not_contains'),
# operation symbol type
'getitem': (composite_ns, 'getitem'),
'ms_iter': (composite_ns, 'ms_iter'),
'ms_next': (composite_ns, 'ms_next'),
'hasnext': (composite_ns, 'hasnext'),
# undefined type
SYMBOL_UNDEFINE: (None, 'undefine'),
}
# convert map: {obj:(Namespace, symbol)}
# Escape an object to another object, eg: system function(len,xxx)
# Some space set aside for readability of code
convert_object_map = {
T.add: multitype_ops.add,
T.sub: multitype_ops.sub,
T.mul: multitype_ops.mul,
T.truediv: multitype_ops.div,
T.getitem: multitype_ops.getitem,
T.floordiv: NO_IMPLEMENT,
T.mod: F.scalar_mod,
T.pow: F.scalar_pow,
T.matmul: F.dot,
T.lshift: NO_IMPLEMENT,
T.rshift: NO_IMPLEMENT,
T.and_: F.bool_and,
T.or_: F.bool_or,
T.xor: NO_IMPLEMENT,
T.pos: F.scalar_uadd,
T.neg: multitype_ops.negative,
T.invert: NO_IMPLEMENT,
T.not_: F.bool_not,
T.eq: multitype_ops.equal,
T.ne: F.scalar_ne,
T.lt: multitype_ops.less,
T.gt: F.scalar_gt,
T.le: multitype_ops.less_equal,
T.ge: F.scalar_ge,
T.is_: F.is_,
T.is_not: F.is_not,
T.contains: NO_IMPLEMENT,
T.not_contains: NO_IMPLEMENT,
# system function
T.len: M.ms_len,
T.bool: M.bool_,
T.map: C.HyperMap(),
T.partial: F.partial,
T.zip: C.zip_operation,
# custom define operation
T.iter: M.ms_iter,
T.next: M.ms_next,
T.hasnext: M.hasnext,
T.setitem: M.setitem,
T.make_tuple: F.make_tuple,
T.make_dict: F.make_dict,
T.make_list: F.make_list,
T.make_slice: F.make_slice,
T.range: F.make_range,
# lib function
math.floor: NO_IMPLEMENT,
math.trunc: NO_IMPLEMENT,
math.exp: NO_IMPLEMENT,
math.log: F.scalar_log,
math.sin: NO_IMPLEMENT,
math.cos: NO_IMPLEMENT,
math.tan: NO_IMPLEMENT,
}

View File

@ -0,0 +1,45 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The functions in this file is used to dump and load python object in anf graphs."""
import pickle
import os
import stat
def dump_obj(obj, path):
"""Dump object to file."""
file_name = hex(id(obj))
file_path = path + file_name
with open(file_path, 'wb') as f:
os.chmod(file_path, stat.S_IWUSR | stat.S_IRUSR)
pickle.dump(obj, f)
return file_name
def load_obj(file_path):
"""Load object from file."""
obj = None
try:
real_file_path = os.path.realpath(file_path)
except Exception as ex:
raise RuntimeError(ex)
with open(real_file_path, 'rb') as f:
obj = pickle.load(f)
return obj
__all__ = ['dump_obj', 'load_obj']

View File

@ -0,0 +1,235 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""standard_method"""
from dataclasses import dataclass
from ...ops import functional as F
from ...ops import operations as P
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
zeros_like, ones_like
from ...ops.composite.base import _append
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
trans = P.Transpose()
def transpose(x):
"""Implementation of `transpose`."""
shape = F.shape(x)
length = F.tuple_len(shape)
perm = F.make_range(0, length)
revert_perm = F.tuple_reversed(perm)
out = trans(x, revert_perm)
return out
def getitem(data, item):
"""Implementation of `getitem`."""
return data.__getitem__(item)
def setitem(data, item, value):
"""Implementation of `setitem`."""
return data.__setitem__(item, value)
def ms_iter(xs):
"""Implementation of `iter`."""
return xs.__ms_iter__()
def ms_next(it):
"""Implementation of `next`."""
return it.__ms_next__()
def hasnext(it):
"""Implementation of `hasnext`."""
return it.__ms_hasnext__()
def ms_len(data):
"""Implementation of `len`."""
return data.__len__()
def floor(x):
"""Implementation of `floor`."""
return x.__floor__()
def trunc(x):
"""Implementation of `trunc`."""
return x.__trunc__()
def uadd(x):
"""Implementation of `uadd`."""
return x.__pos__()
def usub(x):
"""Implementation of `usub`."""
return x.__neg__()
def scalar_truediv(x, y):
"""Implementation of `scalar_truediv`."""
return x.__truediv__(y)
def scalar_floordiv(x, y):
"""Implementation of `scalar_floordiv`."""
return x.__floordiv__(y)
def bool_(x):
"""Implementation of `bool`."""
return x.__bool__()
def tensor_bool(x):
"""return immedate x, x is a tensor of bool value"""
return x
def and_(x, y):
"""Implementation of `and` (`&`)."""
return x.__and__(y)
def or_(x, y):
"""Implementation of `or` (`|`)."""
return x.__or__(y)
def matmul(x, y):
"""Implementation of `matmul` (`@`)."""
return x.__matmul__(y)
def float_bool(x):
"""Implementation of `float_bool`."""
return x != 0.0
def int_bool(x):
"""Implementation of `int_bool`."""
return x != 0
def str_bool(x):
"""Implementation of `str_bool`."""
if x == "":
return False
return True
def list_bool(x):
"""Implementation of `tuple_bool`."""
return len(x) != 0
def tuple_bool(x):
"""Implementation of `tuple_bool`."""
return len(x) != 0
def dict_bool(x):
"""Implementation of `dict_bool`."""
return len(x) != 0
def none_bool(x):
"""Implementation of `none_bool`."""
return False
def float_floordiv(x, y):
"""Implementation of `float_floordiv`."""
return floor(x / y)
#############
# Iteration #
#############
@dataclass(frozen=True)
class SequenceIterator:
"""
SequenceIterator is a util dataclass for iterating sequence object.
Iterator to use for sequences like List, Array.
"""
idx: int
seq: list
@core(ignore_values=True)
def __ms_hasnext__(self):
"""Whether the index is past the length of the sequence."""
return self.idx < ms_len(self.seq)
@core(ignore_values=True)
def __ms_next__(self):
"""Return the next element and a new iterator."""
return self.seq[self.idx], SequenceIterator(self.idx + 1, self.seq)
def list_iter(xs):
"""Iterator for List."""
return SequenceIterator(0, xs)
def array_iter(xs):
"""Iterator for Array."""
return SequenceIterator(0, xs)
def tuple_next(xs):
"""Next tuple."""
return xs[0], tail(xs)
def tuple_hasnext(xs):
"""Whether the tuple is empty or not."""
return len(xs) > 0
def list_next(xs):
"""Next list."""
return xs[0], tail(xs)
def list_hasnext(xs):
"""Whether the list is empty or not."""
return len(xs) > 0
def list_append(self_, item):
return _append(self_, item)
#################
# Array methods #
#################
def to_array(x):
"""Implementation of `to_array`."""
return x.__ms_to_array__()

View File

@ -0,0 +1,93 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Trope some system function symbol to ops."""
# This operation function is not meant to be called directly
# support operator symbol, ast
from operator import ( # noqa
add, sub, mul, truediv, floordiv, mod, eq, ne, lt, gt, le, ge, pos, neg,
not_, and_, or_, xor, lshift, rshift, invert, is_, is_not, contains,
matmul, getitem, setitem
)
# support system function call
from builtins import ( # noqa
bool, getattr, setattr, len, iter, next, pow, range, map, zip
)
# support functools
from functools import ( # noqa
partial
)
# support numpy symbol
from numpy import ( # noqa
exp, log, sin, cos, tan
)
__all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt', 'gt', 'le', 'ge', 'pos', 'neg',
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
'matmul', 'getitem', 'setitem',
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'partial',
'exp', 'log', 'sin', 'cos', 'tan']
def make_tuple(*elts): # pragma: no cover
"""Tuple builder."""
raise RuntimeError('This operation is not meant to be called directly.')
def make_dict(key, value): # pragma: no cover
"""Dict builder."""
raise RuntimeError('This operation is not meant to be called directly.')
def make_list(*elts): # pragma: no cover
"""List builder."""
raise RuntimeError('This operation is not meant to be called directly.')
def make_slice(*elts): # pragma: no cover
"""Slice builder."""
raise RuntimeError('This operation is not meant to be called directly.')
def make_range(*elts): # pragma: no cover
"""Range tuple builder."""
raise RuntimeError('This operation is not meant to be called directly.')
def switch(cond, tb, fb): # pragma: no cover
"""Switch statement, returns one of the two values."""
raise RuntimeError('This operation is not meant to be called directly.')
def hasnext(it): # pragma: no cover
"""Hasnext function."""
raise RuntimeError('This operation is not meant to be called directly.')
def to_array(x): # pragma: no cover
"""The to_array function."""
raise RuntimeError('This operation is not meant to be called directly.')
def not_contains(x): # pragma: no cover
"""Not in function."""
raise RuntimeError('This operation is not meant to be called directly.')

View File

@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pynative mode help module."""
from inspect import signature
from functools import wraps
def args_type_check(*type_args, **type_kwargs):
"""Check whether input data type is correct."""
def type_check(func):
sig = signature(func)
bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal bound_types
bound_values = sig.bind(*args, **kwargs)
argument_dict = bound_values.arguments
if "kwargs" in bound_types:
bound_types = bound_types["kwargs"]
if "kwargs" in argument_dict:
argument_dict = argument_dict["kwargs"]
for name, value in argument_dict.items():
if name in bound_types:
if value is not None and not isinstance(value, bound_types[name]):
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
return func(*args, **kwargs)
return wrapper
return type_check

101
mindspore/_extends/utils.py Normal file
View File

@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Some utils."""
import hashlib
import logging
import os
import inspect
from functools import wraps
def cal_sha256(file_path):
"""
Calculate sha256 value of input file.
Args:
file_path (str): file path to calculate.
Returns:
str, returns sha256 value of input file or empty str if cal failed.
"""
buf_size = 64 * 1024 # once read 64kb
sha256 = hashlib.sha256()
file_real_path = os.path.realpath(file_path)
if not os.path.isfile(file_real_path) or not os.access(file_real_path, os.R_OK):
return ""
_, extension = os.path.splitext(file_path)
try:
if extension == '.ptx':
with open(file_real_path, 'r') as kf:
while True:
data = kf.read(buf_size)
if not data:
break
sha256.update(data.encode("utf-8"))
else:
with open(file_real_path, 'rb') as kf:
while True:
data = kf.read(buf_size)
if not data:
break
sha256.update(data)
except IOError:
logging.error("Open file %s failed.", file_path)
return ""
return sha256.hexdigest()
def cell_attr_register(fn=None, attrs=None):
"""
Cell init attributes register.
Registering the decorator of the built-in operator cell __init__
function will add save all the parameters of __init__ as operator attributes.
Args:
fn (function): __init__ function of cell.
attrs (list(string) | string): attr list.
Returns:
function, original function.
"""
def wrap_cell(fn):
@wraps(fn)
def deco(self, *args, **kwargs):
arguments = []
if attrs is None:
bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
arguments = bound_args.arguments
del arguments['self']
arguments = arguments.values()
fn(self, *args, **kwargs)
if attrs is not None:
if isinstance(attrs, list):
for item in attrs:
if not isinstance(item, str):
raise ValueError(f"attr must be a string")
if hasattr(self, item):
arguments.append(getattr(self, item))
elif isinstance(attrs, str):
if hasattr(self, attrs):
arguments = getattr(self, attrs)
else:
raise ValueError(f"attrs must be list or string")
self.cell_init_args = type(self).__name__ + str(arguments)
return deco
if fn is not None:
return wrap_cell(fn)
return wrap_cell

Some files were not shown because too many files have changed in this diff Show More