forked from mindspore-Ecosystem/mindspore
!24 Synchronization code to ms-incubator
Merge pull request !24 from changzherui/sy-code
This commit is contained in:
commit 22cc03a54a
@ -0,0 +1,26 @@
<!-- Thanks for sending a pull request! Here are some tips for you:

If this is your first time, please read our contributor guidelines: https://gitee.com/mindspore/mindspore/blob/master/CONTRIBUTING.md
-->

**What type of PR is this?**
> Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line:
>
> /kind bug
> /kind task
> /kind feature


**What does this PR do / why do we need it**:


**Which issue(s) this PR fixes**:
<!--
*Automatically closes linked issue when PR is merged.
Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`.
-->
Fixes #

**Special notes for your reviewers**:

@ -0,0 +1,19 @@
---
name: RFC
about: Use this template for a new feature or enhancement
labels: kind/feature or kind/enhancement

---

## Background
- Describe the status of the problem you wish to solve
- Attach the relevant issue, if any

## Introduction
- Describe the general solution, design and/or pseudo-code

## Trail
| No. | Task Description | Related Issue(URL) |
| --- | ---------------- | ------------------ |
| 1   |                  |                    |
| 2   |                  |                    |
@ -0,0 +1,43 @@
|
|||
---
|
||||
name: Bug Report
|
||||
about: Use this template for reporting a bug
|
||||
labels: kind/bug
|
||||
|
||||
---
|
||||
|
||||
<!-- Thanks for sending an issue! Here are some tips for you:
|
||||
|
||||
If this is your first time, please read our contributor guidelines: https://github.com/mindspore-ai/mindspore/blob/master/CONTRIBUTING.md
|
||||
-->
|
||||
|
||||
## Environment
|
||||
### Hardware Environment(`Ascend`/`GPU`/`CPU`):
|
||||
> Uncomment only one ` /device <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line:
|
||||
>
|
||||
> `/device ascend`</br>
|
||||
> `/device gpu`</br>
|
||||
> `/device cpu`</br>
|
||||
|
||||
### Software Environment:
|
||||
- **MindSpore version (source or binary)**:
|
||||
- **Python version (e.g., Python 3.7.5)**:
|
||||
- **OS platform and distribution (e.g., Linux Ubuntu 16.04)**:
|
||||
- **GCC/Compiler version (if compiled from source)**:
|
||||
|
||||
## Describe the current behavior
|
||||
|
||||
|
||||
## Describe the expected behavior
|
||||
|
||||
|
||||
## Steps to reproduce the issue
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
## Related log / screenshot
|
||||
|
||||
|
||||
## Special notes for this issue
|
||||
|
||||
|
|
@ -0,0 +1,19 @@
---
name: Task
about: Use this template for task tracking
labels: kind/task

---

## Task Description


## Task Goal


## Sub Task
| No. | Task Description | Issue ID |
| --- | ---------------- | -------- |
| 1   |                  |          |
| 2   |                  |          |

@ -0,0 +1,24 @@
<!-- Thanks for sending a pull request! Here are some tips for you:

If this is your first time, please read our contributor guidelines: https://github.com/mindspore-ai/mindspore/blob/master/CONTRIBUTING.md
-->

**What type of PR is this?**
> Uncomment only one ` /kind <>` line, hit enter to put that in a new line, and remove leading whitespaces from that line:
>
> `/kind bug`</br>
> `/kind task`</br>
> `/kind feature`</br>

**What does this PR do / why do we need it**:


**Which issue(s) this PR fixes**:
<!--
*Automatically closes linked issue when PR is merged.
Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`.
-->
Fixes #

**Special notes for your reviewers**:

@ -38,15 +38,17 @@ set(MS_CCSRC_BUILD_PATH ${BUILD_PATH}/mindspore/mindspore/ccsrc)

 if (ENABLE_GE)
     link_directories(${CMAKE_SOURCE_DIR}/third_party/ge/lib)
-else()
+elseif(ENABLE_D OR ENABLE_TESTCASES)
     include(${CMAKE_SOURCE_DIR}/cmake/dependency_graphengine.cmake)
 endif()

-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc/external)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc/framework)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
+if (ENABLE_GE OR ENABLE_D OR ENABLE_TESTCASES)
+    include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc)
+    include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc/external)
+    include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/inc/framework)
+    include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc)
+    include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
+endif()

 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
 add_subdirectory(mindspore/ccsrc)

@ -78,7 +78,7 @@ Please follow this style to make MindSpore easy to review, maintain and develop.

 * Pull a request to MindSpore repository

-    In the last step, your need to pull a compare request between your new branch and MindSpore `master` branch. After finishing the pull request, the Jekins CI will be automatically set up for building test.
+    In the last step, you need to pull a compare request between your new branch and MindSpore `master` branch. After finishing the pull request, the Jenkins CI will be automatically set up for building test.

 ### Report issues

@ -3,7 +3,7 @@
 ## Main Features

 ### Ascend 910 Training and Inference Framework
-* Recommended OS: Ubuntu 16.04 (or later) or EulerOS 2.0
+* Recommended OS: Ubuntu 16.04 (or later) or EulerOS 2.5 or EulerOS 2.8
 * Python version: 3.7.5
 * Preset models
     * ResNet-50: residual structure-based convolutional neural network (CNN) for image classification, which is widely used.

@ -38,19 +38,19 @@ elseif (DEFINED ENV{D_LINK_PATH})
     find_library(cce libcce.so ${GE_LIB_PATH})
     find_library(resource libresource.so ${GE_LIB_PATH})
 else()
-    set(HIAI_INSTALLED_DIR /usr/local/HiAI)
-    set(HIAI_DRIVER_DIR ${HIAI_INSTALLED_DIR}/driver/lib64)
-    set(HIAI_RUNTIME_DIR ${HIAI_INSTALLED_DIR}/runtime/lib64)
-    find_library(c_sec libc_sec.so ${HIAI_DRIVER_DIR})
-    find_library(slog libslog.so ${HIAI_DRIVER_DIR})
-    find_library(mmpa libmmpa.so ${HIAI_DRIVER_DIR})
-
-    find_library(cce libcce.so ${HIAI_RUNTIME_DIR})
-    find_library(hccl libhccl.so ${HIAI_RUNTIME_DIR})
-    find_library(runtime libruntime.so ${HIAI_RUNTIME_DIR})
-    find_library(msprof libmsprof.so ${HIAI_RUNTIME_DIR})
-    find_library(register libregister.so ${HIAI_RUNTIME_DIR})
-    find_library(resource libresource.so ${HIAI_RUNTIME_DIR})
+    # Ascend mode
+    set(ASCEND_PATH /usr/local/Ascend)
+    set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
+    set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64)
+    find_library(c_sec libc_sec.so ${ASCEND_DRIVER_PATH})
+    find_library(slog libslog.so ${ASCEND_DRIVER_PATH})
+    find_library(mmpa libmmpa.so ${ASCEND_DRIVER_PATH})
+    find_library(cce libcce.so ${ASCEND_RUNTIME_PATH})
+    find_library(hccl libhccl.so ${ASCEND_RUNTIME_PATH})
+    find_library(runtime libruntime.so ${ASCEND_RUNTIME_PATH})
+    find_library(msprof libmsprof.so ${ASCEND_RUNTIME_PATH})
+    find_library(register libregister.so ${ASCEND_RUNTIME_PATH})
+    find_library(resource libresource.so ${ASCEND_RUNTIME_PATH})
 endif()

 # compile libraries from following directories

@ -40,7 +40,7 @@ if (ENABLE_GE)
     include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include)
     include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include/external)
     include_directories(${CMAKE_SOURCE_DIR}/third_party/ge/include/external/graph)
-else()
+elseif(ENABLE_D OR ENABLE_TESTCASES)
     include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc)
     include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/ops)
     include_directories(${CMAKE_SOURCE_DIR}/graphengine/inc/external)

@ -16,16 +16,34 @@ function(mindspore_add_submodule_obj des_submodule_objs sub_dir submodule_name_o

 endfunction()

-get_filename_component(_MS_LIB_CACHE ~/.mslib REALPATH)
+if (DEFINED ENV{MSLIBS_CACHE_PATH})
+    set(_MS_LIB_CACHE $ENV{MSLIBS_CACHE_PATH})
+else()
+    set(_MS_LIB_CACHE ${CMAKE_BINARY_DIR}/.mslib)
+endif ()
+message("MS LIBS CACHE PATH: ${_MS_LIB_CACHE}")
+
 if (NOT EXISTS ${_MS_LIB_CACHE})
     file(MAKE_DIRECTORY ${_MS_LIB_CACHE})
 endif ()
 # set(FETCHCONTENT_BASE_DIR ${_MS_LIB_CACHE})
 # set(CMAKE_PREFIX_PATH ${_MS_LIB_CACHE})

 if (DEFINED ENV{MSLIBS_SERVER})
     set(LOCAL_LIBS_SERVER $ENV{MSLIBS_SERVER})
     message("LOCAL_LIBS_SERVER: ${LOCAL_LIBS_SERVER}")
 endif ()

+include(ProcessorCount)
+ProcessorCount(N)
+if (JOBS)
+    set(THNUM ${JOBS})
+else()
+    set(JOBS 8)
+    if (${JOBS} GREATER ${N})
+        set(THNUM ${N})
+    endif()
+endif ()
+message("set make thread num: ${THNUM}")
+
 if(LOCAL_LIBS_SERVER)
     if (NOT ENV{no_proxy})
         set(ENV{no_proxy} "${LOCAL_LIBS_SERVER}")

@ -287,7 +305,7 @@ function(mindspore_add_pkg pkg_name )
                 -DCMAKE_INSTALL_PREFIX=${${pkg_name}_BASE_DIR} ..
                 WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build)

-        __exec_cmd(COMMAND ${CMAKE_COMMAND} --build . --target install -- -j8
+        __exec_cmd(COMMAND ${CMAKE_COMMAND} --build . --target install -- -j${THNUM}
                 WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build)

     else()

@ -318,7 +336,7 @@ function(mindspore_add_pkg pkg_name )
             ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS})
         endif ()
         # build
-        __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} ${${pkg_name}_BUILD_OPTION} -j8
+        __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} ${${pkg_name}_BUILD_OPTION} -j${THNUM}
             WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR})

         if (PKG_INSTALL_INCS OR PKG_INSTALL_LIBS)
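The two hunks above swap the hard-coded `-j8` for `-j${THNUM}`, the job count derived from `ProcessorCount` in the new block at the top of this file. As a minimal sketch of that selection logic (the Python rendering and names below are mine, not from the build scripts; note the CMake branch only assigns THNUM when the default exceeds the processor count, which min() covers uniformly):

import os

def pick_build_jobs(jobs=None, default_jobs=8):
    # Honor an explicit JOBS value; otherwise cap the default at the
    # number of available processors, mirroring the ProcessorCount logic.
    ncpu = os.cpu_count() or 1
    if jobs:
        return jobs
    return min(default_jobs, ncpu)

print(pick_build_jobs())        # 8 on a 16-core machine, 4 on a 4-core one
print(pick_build_jobs(jobs=2))  # 2: an explicit JOBS value always wins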
@ -6,17 +6,17 @@
         "net_name": "ResNet50",
         "mode": 0,
         "iteration": 0,
-        "kernels": ["TensorAdd"]
+        "kernels": ["Default/Conv2D-op2", "Default/TensorAdd-op10"]
     },

     "DumpSettingsSpec": {
-        "enable": "true: dump enable false: dump disable",
-        "trans_flag": "true: trans to host format,false: not trans format",
+        "enable": "true: dump enable, false: dump disable",
+        "trans_flag": "true: trans to host format, false: not trans format",
         "path": "the dump file folder",
         "net_name": "net name eg:ResNet50",
-        "mode": "0: dump all kernels 1: dump kernels in kernels list",
-        "iteration": "0: all iteration others: specified iteration ",
-        "kernels": "kernel name list need to be dump"
+        "mode": "0: dump all kernels, 1: dump kernels in kernels list",
+        "iteration": "0: all iteration, others: specified iteration",
+        "kernels": "op's full scope name which need to be dump"
     },
     "other": {}
 }
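As context for the three dump-config hunks in this commit, here is a minimal sketch of how a file in this format might be loaded and sanity-checked before training. It assumes the data sits under a `DumpSettings` key mirroring `DumpSettingsSpec`; the file name and helper are hypothetical, not MindSpore API:

import json

def load_dump_settings(path="e2e_dump_config.json"):
    # Assumed layout: settings under "DumpSettings", documented by "DumpSettingsSpec".
    with open(path) as f:
        cfg = json.load(f)
    settings = cfg["DumpSettings"]
    assert settings["mode"] in (0, 1)  # 0: dump all kernels, 1: dump listed kernels
    if settings["mode"] == 1:
        # New format: entries are full scope names such as "Default/Conv2D-op2",
        # not bare operator types like "TensorAdd".
        assert all("/" in name for name in settings["kernels"])
    return settings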
@ -6,17 +6,17 @@
         "net_name": "ResNet50",
         "mode": 0,
         "iteration": 0,
-        "kernels": ["AllReduce","BiasAddGrad","Conv2DBackpropFilter","SparseSoftmaxCrossEntropyWithLogits"]
+        "kernels": ["Default/Conv2D-op2", "Default/TensorAdd-op10"]
     },

     "DumpSettingsSpec": {
-        "enable": "true: dump enable false: dump disable",
-        "trans_flag": "true: trans to host format,false: not trans format",
+        "enable": "true: dump enable, false: dump disable",
+        "trans_flag": "true: trans to host format, false: not trans format",
         "path": "the dump file folder",
         "net_name": "net name eg:ResNet50",
-        "mode": "0: dump all kernels 1: dump kernels in kernels list",
-        "iteration": "0: all iteration others: specified iteration ",
-        "kernels": "kernel name list need to be dump"
+        "mode": "0: dump all kernels, 1: dump kernels in kernels list",
+        "iteration": "0: all iteration, others: specified iteration",
+        "kernels": "op's full scope name which need to be dump"
     },
     "other": {}
 }
 }
@ -6,17 +6,17 @@
         "net_name": "ResNet50",
         "mode": 0,
         "iteration": 0,
-        "kernels": ["AllReduce","BiasAddGrad","Conv2DBackpropFilter","SparseSoftmaxCrossEntropyWithLogits"]
+        "kernels": ["Default/Conv2D-op2", "Default/TensorAdd-op10"]
     },

     "DumpSettingsSpec": {
-        "enable": "true: dump enable false: dump disable",
-        "trans_flag": "true: trans to host format,false: not trans format",
+        "enable": "true: dump enable, false: dump disable",
+        "trans_flag": "true: trans to host format, false: not trans format",
         "path": "the dump file folder",
         "net_name": "net name eg:ResNet50",
-        "mode": "0: dump all kernels 1: dump kernels in kernels list",
-        "iteration": "0: all iteration others: specified iteration ",
-        "kernels": "kernel name list need to be dump"
+        "mode": "0: dump all kernels, 1: dump kernels in kernels list",
+        "iteration": "0: all iteration, others: specified iteration",
+        "kernels": "op's full scope name which need to be dump"
     },
     "other": {}
 }
 }

@ -12,20 +12,22 @@ RUN apt update \
     && DEBIAN_FRONTEND=noninteractive apt install -y \
     vim \
     wget \
     curl \
     xz-utils \
     net-tools \
     openssh-client \
     git \
     subversion \
     ntpdate \
     tzdata \
     tcl \
-    sudo
+    sudo \
+    bash-completion

 # Install compile tools
 RUN DEBIAN_FRONTEND=noninteractive apt install -y \
     gcc \
     g++ \
     zlibc \
     make \
     libgmp-dev \
     patch \
@ -39,7 +41,8 @@ RUN echo "dash dash/sh boolean false" | debconf-set-selections
 RUN DEBIAN_FRONTEND=noninteractive dpkg-reconfigure dash

 # Install python (v3.7.5)
-RUN apt install -y --no-install-recommends libffi-dev libssl-dev zlib1g-dev libbz2-dev libncurses5-dev libgdbm-dev liblzma-dev libreadline-dev \
+RUN apt install -y libffi-dev libssl-dev zlib1g-dev libbz2-dev libncurses5-dev \
+    libgdbm-dev libgdbm-compat-dev liblzma-dev libreadline-dev libsqlite3-dev \
     && cd /tmp \
     && wget https://github.com/python/cpython/archive/v3.7.5.tar.gz \
     && tar -xvf v3.7.5.tar.gz \
@ -62,12 +65,12 @@ RUN mkdir -pv /root/.pip \
     && echo "index-url=http://mirrors.aliyun.com/pypi/simple/" >> /root/.pip/pip.conf

 # Install pip package
-RUN pip install numpy \
-    && pip install wheel \
-    && pip install nose \
-    && pip install pytest \
-    && pip install pytest-xdist \
-    && pip list
+RUN pip install --no-cache-dir \
+    numpy \
+    wheel \
+    nose \
+    pytest \
+    pytest-xdist

 # Install cmake (v3.14.1)
 RUN cd /tmp \
@ -77,4 +80,4 @@ RUN cd /tmp \
     && rm -f /tmp/cmake-3.14.1-Linux-x86_64.sh

 # Install MindSpore cpu whl package
-RUN pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/0.1.0-alpha/MindSpore/cpu/ubuntu-x86/mindspore-0.1.0-cp37-cp37m-linux_x86_64.whl
+RUN pip install --no-cache-dir https://ms-release.obs.cn-north-4.myhuaweicloud.com/0.1.0-alpha/MindSpore/cpu/ubuntu-x86/mindspore-0.1.0-cp37-cp37m-linux_x86_64.whl

@ -12,20 +12,22 @@ RUN apt update \
     && DEBIAN_FRONTEND=noninteractive apt install -y \
     vim \
     wget \
     curl \
     xz-utils \
     net-tools \
     openssh-client \
     git \
     subversion \
     ntpdate \
     tzdata \
     tcl \
-    sudo
+    sudo \
+    bash-completion

 # Install compile tools
 RUN DEBIAN_FRONTEND=noninteractive apt install -y \
     gcc \
     g++ \
     zlibc \
     make \
     libgmp-dev \
     patch \
@ -39,7 +41,8 @@ RUN echo "dash dash/sh boolean false" | debconf-set-selections
 RUN DEBIAN_FRONTEND=noninteractive dpkg-reconfigure dash

 # Install python (v3.7.5)
-RUN apt install -y --no-install-recommends libffi-dev libssl-dev zlib1g-dev libbz2-dev libncurses5-dev libgdbm-dev liblzma-dev libreadline-dev \
+RUN apt install -y libffi-dev libssl-dev zlib1g-dev libbz2-dev libncurses5-dev \
+    libgdbm-dev libgdbm-compat-dev liblzma-dev libreadline-dev libsqlite3-dev \
     && cd /tmp \
     && wget https://github.com/python/cpython/archive/v3.7.5.tar.gz \
     && tar -xvf v3.7.5.tar.gz \
@ -62,12 +65,12 @@ RUN mkdir -pv /root/.pip \
     && echo "index-url=http://mirrors.aliyun.com/pypi/simple/" >> /root/.pip/pip.conf

 # Install pip package
-RUN pip install numpy \
-    && pip install wheel \
-    && pip install nose \
-    && pip install pytest \
-    && pip install pytest-xdist \
-    && pip list
+RUN pip install --no-cache-dir \
+    numpy \
+    wheel \
+    nose \
+    pytest \
+    pytest-xdist

 # Install cmake (v3.14.1)
 RUN cd /tmp \
@ -77,4 +80,4 @@ RUN cd /tmp \
     && rm -f /tmp/cmake-3.14.1-Linux-x86_64.sh

 # Install MindSpore cuda-10.1 whl package
-RUN pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/0.1.0-alpha/MindSpore/gpu/cuda-10.1/mindspore-0.1.0-cp37-cp37m-linux_x86_64.whl
+RUN pip install --no-cache-dir https://ms-release.obs.cn-north-4.myhuaweicloud.com/0.1.0-alpha/MindSpore/gpu/cuda-10.1/mindspore-0.1.0-cp37-cp37m-linux_x86_64.whl

@ -12,20 +12,22 @@ RUN apt update \
     && DEBIAN_FRONTEND=noninteractive apt install -y \
     vim \
     wget \
     curl \
     xz-utils \
     net-tools \
     openssh-client \
     git \
     subversion \
     ntpdate \
     tzdata \
     tcl \
-    sudo
+    sudo \
+    bash-completion

 # Install compile tools
 RUN DEBIAN_FRONTEND=noninteractive apt install -y \
     gcc \
     g++ \
     zlibc \
     make \
     libgmp-dev \
     patch \
@ -39,7 +41,8 @@ RUN echo "dash dash/sh boolean false" | debconf-set-selections
 RUN DEBIAN_FRONTEND=noninteractive dpkg-reconfigure dash

 # Install python (v3.7.5)
-RUN apt install -y --no-install-recommends libffi-dev libssl-dev zlib1g-dev libbz2-dev libncurses5-dev libgdbm-dev liblzma-dev libreadline-dev \
+RUN apt install -y libffi-dev libssl-dev zlib1g-dev libbz2-dev libncurses5-dev \
+    libgdbm-dev libgdbm-compat-dev liblzma-dev libreadline-dev libsqlite3-dev \
     && cd /tmp \
     && wget https://github.com/python/cpython/archive/v3.7.5.tar.gz \
     && tar -xvf v3.7.5.tar.gz \
@ -62,12 +65,12 @@ RUN mkdir -pv /root/.pip \
     && echo "index-url=http://mirrors.aliyun.com/pypi/simple/" >> /root/.pip/pip.conf

 # Install pip package
-RUN pip install numpy \
-    && pip install wheel \
-    && pip install nose \
-    && pip install pytest \
-    && pip install pytest-xdist \
-    && pip list
+RUN pip install --no-cache-dir \
+    numpy \
+    wheel \
+    nose \
+    pytest \
+    pytest-xdist

 # Install cmake (v3.14.1)
 RUN cd /tmp \
@ -77,4 +80,4 @@ RUN cd /tmp \
     && rm -f /tmp/cmake-3.14.1-Linux-x86_64.sh

 # Install MindSpore cuda-9.2 whl package
-RUN pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/0.1.0-alpha/MindSpore/gpu/cuda-9.2/mindspore-0.1.0-cp37-cp37m-linux_x86_64.whl
+RUN pip install --no-cache-dir https://ms-release.obs.cn-north-4.myhuaweicloud.com/0.1.0-alpha/MindSpore/gpu/cuda-9.2/mindspore-0.1.0-cp37-cp37m-linux_x86_64.whl

Binary file not shown. (Image replaced: 121 KiB before, 35 KiB after.)
@ -1,55 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
network config setting, will be used in main.py
"""

from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
bert_cfg = edict({
    'epoch_size': 10,
    'num_warmup_steps': 0,
    'start_learning_rate': 1e-4,
    'end_learning_rate': 1,
    'decay_steps': 1000,
    'power': 10.0,
    'save_checkpoint_steps': 2000,
    'keep_checkpoint_max': 10,
    'checkpoint_prefix': "checkpoint_bert",
    'DATA_DIR' = "/your/path/examples.tfrecord"
    'SCHEMA_DIR' = "/your/path/datasetSchema.json"
    'bert_config': BertConfig(
        batch_size=16,
        seq_length=128,
        vocab_size=21136,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        use_relative_positions=True,
        input_mask_from_dataset=True,
        token_type_ids_from_dataset=True,
        dtype=mstype.float32,
        compute_type=mstype.float16,
    )
})
@ -0,0 +1,57 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
network config setting, will be used in train.py
"""

from easydict import EasyDict as edict
import mindspore.common.dtype as mstype
from mindspore.model_zoo.Bert_NEZHA import BertConfig
bert_train_cfg = edict({
    'epoch_size': 10,
    'num_warmup_steps': 0,
    'start_learning_rate': 1e-4,
    'end_learning_rate': 0.0,
    'decay_steps': 1000,
    'power': 10.0,
    'save_checkpoint_steps': 2000,
    'keep_checkpoint_max': 10,
    'checkpoint_prefix': "checkpoint_bert",
    # please add your own dataset path
    'DATA_DIR': "/your/path/examples.tfrecord",
    # please add your own dataset schema path
    'SCHEMA_DIR': "/your/path/datasetSchema.json"
})
bert_net_cfg = BertConfig(
    batch_size=16,
    seq_length=128,
    vocab_size=21136,
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    intermediate_size=4096,
    hidden_act="gelu",
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=True,
    input_mask_from_dataset=True,
    token_type_ids_from_dataset=True,
    dtype=mstype.float32,
    compute_type=mstype.float16,
)
@ -14,7 +14,8 @@
 # ============================================================================

 """
-NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language model currently based on BERT developed by Huawei.
+NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language
+model currently based on BERT developed by Huawei.
 1. Prepare data
 Following the data preparation as in BERT, run command as below to get dataset for training:
 python ./create_pretraining_data.py \
@ -28,36 +29,30 @@ Following the data preparation as in BERT, run command as below to get dataset f
     --random_seed=12345 \
     --dupe_factor=5
 2. Pretrain
-First, prepare the distributed training environment, then adjust configurations in config.py, finally run main.py.
+First, prepare the distributed training environment, then adjust configurations in config.py, finally run train.py.
 """

 import os
-import pytest
 import numpy as np
-from numpy import allclose
-from config import bert_cfg as cfg
 import mindspore.common.dtype as mstype
+from config import bert_train_cfg, bert_net_cfg
 import mindspore.dataset.engine.datasets as de
-import mindspore._c_dataengine as deMap
+import mindspore.dataset.transforms.c_transforms as C
 from mindspore import context
 from mindspore.common.tensor import Tensor
 from mindspore.train.model import Model
-from mindspore.train.callback import Callback
-from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
+from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
 from mindspore.nn.optim import Lamb
-from mindspore import log as logger
 _current_dir = os.path.dirname(os.path.realpath(__file__))
-DATA_DIR = [cfg.DATA_DIR]
-SCHEMA_DIR = cfg.SCHEMA_DIR

-def me_de_train_dataset(batch_size):
-    """test me de train dataset"""
+def create_train_dataset(batch_size):
+    """create train dataset"""
     # apply repeat operations
-    repeat_count = cfg.epoch_size
-    ds = de.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
-                                                               "next_sentence_labels", "masked_lm_positions",
-                                                               "masked_lm_ids", "masked_lm_weights"])
-    type_cast_op = deMap.TypeCastOp("int32")
+    repeat_count = bert_train_cfg.epoch_size
+    ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
+                           columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
+                                         "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
+    type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
     ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
     ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
@ -69,43 +64,32 @@ def me_de_train_dataset(batch_size):
     ds = ds.repeat(repeat_count)
     return ds


-def weight_variable(shape):
-    """weight variable"""
-    np.random.seed(1)
-    ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
-    return Tensor(ones)


-class ModelCallback(Callback):
-    def __init__(self):
-        super(ModelCallback, self).__init__()
-        self.loss_list = []

-    def step_end(self, run_context):
-        cb_params = run_context.original_args()
-        self.loss_list.append(cb_params.net_outputs.asnumpy()[0])
-        logger.info("epoch: {}, outputs are {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))

-def test_bert_tdt():
-    """test bert tdt"""
+def train_bert():
+    """train bert"""
     context.set_context(mode=context.GRAPH_MODE)
     context.set_context(device_target="Ascend")
     context.set_context(enable_task_sink=True)
     context.set_context(enable_loop_sink=True)
     context.set_context(enable_mem_reuse=True)
-    parallel_callback = ModelCallback()
-    ds = me_de_train_dataset(cfg.bert_config.batch_size)
-    config = cfg.bert_config
-    netwithloss = BertNetworkWithLoss(config, True)
-    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=cfg.decay_steps, start_learning_rate=cfg.start_learning_rate,
-                     end_learning_rate=cfg.end_learning_rate, power=cfg.power, warmup_steps=cfg.num_warmup_steps, decay_filter=lambda x: False)
+    ds = create_train_dataset(bert_net_cfg.batch_size)
+    netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
+    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps,
+                     start_learning_rate=bert_train_cfg.start_learning_rate,
+                     end_learning_rate=bert_train_cfg.end_learning_rate, power=bert_train_cfg.power,
+                     warmup_steps=bert_train_cfg.num_warmup_steps, decay_filter=lambda x: False)
     netwithgrads = BertTrainOneStepCell(netwithloss, optimizer=optimizer)
     netwithgrads.set_train(True)
     model = Model(netwithgrads)
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix=cfg.checkpoint_prefix, config=config_ck)
-    model.train(ds.get_repeat_count(), ds, callbacks=[parallel_callback, ckpoint_cb], dataset_sink_mode=False)
+    config_ck = CheckpointConfig(save_checkpoint_steps=bert_train_cfg.save_checkpoint_steps,
+                                 keep_checkpoint_max=bert_train_cfg.keep_checkpoint_max)
+    ckpoint_cb = ModelCheckpoint(prefix=bert_train_cfg.checkpoint_prefix, config=config_ck)
+    model.train(ds.get_repeat_count(), ds, callbacks=[LossMonitor(), ckpoint_cb], dataset_sink_mode=False)

 if __name__ == '__main__':
-    test_bert_tdt()
+    train_bert()
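The Lamb call above threads through start_learning_rate, end_learning_rate, decay_steps, power and warmup_steps from bert_train_cfg. Assuming the usual polynomial-decay-with-warmup schedule (my reading of these arguments, not taken from the optimizer source), the per-step learning rate works out roughly as:

def lamb_lr(step, start_lr=1e-4, end_lr=0.0, decay_steps=1000, power=10.0, warmup_steps=0):
    # Linear warmup, then polynomial decay from start_lr to end_lr.
    if warmup_steps and step < warmup_steps:
        return start_lr * step / warmup_steps
    p = min(step, decay_steps) / decay_steps
    return (start_lr - end_lr) * (1 - p) ** power + end_lr

print(lamb_lr(0))     # 1e-04
print(lamb_lr(500))   # ~9.8e-08: power=10 decays very aggressively
print(lamb_lr(1000))  # 0.0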
@ -0,0 +1,32 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""

from easydict import EasyDict as edict

alexnet_cfg = edict({
    'num_classes': 10,
    'learning_rate': 0.002,
    'momentum': 0.9,
    'epoch_size': 1,
    'batch_size': 32,
    'buffer_size': 1000,
    'image_height': 227,
    'image_width': 227,
    'save_checkpoint_steps': 1562,
    'keep_checkpoint_max': 10,
})
@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Produce the dataset
"""

from config import alexnet_cfg as cfg
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore.common import dtype as mstype


def create_dataset(data_path, batch_size=32, repeat_size=1, status="train"):
    """
    create dataset for train or test
    """
    cifar_ds = ds.Cifar10Dataset(data_path)
    rescale = 1.0 / 255.0
    shift = 0.0

    resize_op = CV.Resize((cfg.image_height, cfg.image_width))
    rescale_op = CV.Rescale(rescale, shift)
    normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    if status == "train":
        random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4])
        random_horizontal_op = CV.RandomHorizontalFlip()
    channel_swap_op = CV.HWC2CHW()
    typecast_op = C.TypeCast(mstype.int32)
    cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op)
    if status == "train":
        cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op)
        cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op)

    cifar_ds = cifar_ds.shuffle(buffer_size=cfg.buffer_size)
    cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
    cifar_ds = cifar_ds.repeat(repeat_size)
    return cifar_ds
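A hypothetical usage sketch for the create_dataset helper above (the path is illustrative; Cifar10Dataset expects the extracted CIFAR-10 binary folder):

from dataset import create_dataset

ds_train = create_dataset("/path/to/cifar-10-batches-bin", batch_size=32,
                          repeat_size=1, status="train")
for batch in ds_train.create_dict_iterator():
    # After Resize and HWC2CHW: (batch, channels, height, width)
    print(batch["image"].shape)  # (32, 3, 227, 227)
    break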
@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## eval alexnet example ########################
eval alexnet according to model file:
python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
"""

import argparse
from config import alexnet_cfg as cfg
from dataset import create_dataset
import mindspore.nn as nn
from mindspore import context
from mindspore.model_zoo.alexnet import AlexNet
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented (default: Ascend)')
    parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
    parser.add_argument('--ckpt_path', type=str, default="./ckpt",
                        help='path where the trained ckpt file is saved (required for test)')
    parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)

    network = AlexNet(cfg.num_classes)
    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    repeat_size = cfg.epoch_size
    opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
    model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})  # test

    print("============== Starting Testing ==============")
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)
    ds_eval = create_dataset(args.data_path,
                             cfg.batch_size,
                             1,
                             "test")
    acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
    print("============== Accuracy:{} ==============".format(acc))
@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train alexnet example ########################
train alexnet and get network model files(.ckpt) :
python train.py --data_path /YourDataPath
"""

import argparse
from config import alexnet_cfg as cfg
from dataset import create_dataset
import mindspore.nn as nn
from mindspore import context
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.model_zoo.alexnet import AlexNet
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented (default: Ascend)')
    parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
    parser.add_argument('--ckpt_path', type=str, default="./ckpt",
                        help='path where the trained ckpt file will be saved')
    parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)

    network = AlexNet(cfg.num_classes)
    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
    model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Training ==============")
    ds_train = create_dataset(args.data_path,
                              cfg.batch_size,
                              cfg.epoch_size,
                              "train")
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck)
    model.train(cfg.epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()],
                dataset_sink_mode=args.dataset_sink_mode)
@ -1,125 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train and test lenet example ########################
1. train lenet and get network model files(.ckpt) :
python main.py --data_path /home/workspace/mindspore_dataset/Tutorial_Network/Lenet/MNIST_Data

2. test lenet according to model file:
python main.py --data_path /home/workspace/mindspore_dataset/Tutorial_Network/Lenet/MNIST_Data
--mode test --ckpt_path checkpoint_lenet_1-1_1875.ckpt
"""
import os
import argparse
from config import mnist_cfg as cfg

import mindspore.dataengine as de
import mindspore.nn as nn
from mindspore.model_zoo.lenet import LeNet5
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model
import mindspore.ops.operations as P
import mindspore.transforms.c_transforms as C
from mindspore.transforms import Inter
from mindspore.nn.metrics import Accuracy
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype


class CrossEntropyLoss(nn.Cell):
    """
    Define loss for network
    """
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean()
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)

    def construct(self, logits, label):
        label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value)
        loss = self.cross_entropy(logits, label)[0]
        loss = self.mean(loss, (-1,))
        return loss


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test
    """
    # define dataset
    ds1 = de.MnistDataset(data_path)

    # apply map operations on images
    ds1 = ds1.map(input_columns="label", operations=C.TypeCast(mstype.int32))
    ds1 = ds1.map(input_columns="image", operations=C.Resize((cfg.image_height, cfg.image_width),
                                                             interpolation=Inter.LINEAR),
                  num_parallel_workers=num_parallel_workers)
    ds1 = ds1.map(input_columns="image", operations=C.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081),
                  num_parallel_workers=num_parallel_workers)
    ds1 = ds1.map(input_columns="image", operations=C.Rescale(1.0 / 255.0, 0.0),
                  num_parallel_workers=num_parallel_workers)
    ds1 = ds1.map(input_columns="image", operations=C.HWC2CHW(), num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    ds1 = ds1.shuffle(buffer_size=cfg.buffer_size)  # 10000 as in LeNet train script
    ds1 = ds1.batch(batch_size, drop_remainder=True)
    ds1 = ds1.repeat(repeat_size)

    return ds1

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: Ascend)')
    parser.add_argument('--mode', type=str, default="train", choices=['train', 'test'],
                        help='implement phase, set to train or test')
    parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                        help='path where the dataset is saved')
    parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide\
                        path where the trained ckpt file')

    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    network = LeNet5(cfg.num_classes)
    network.set_train()
    # net_loss = nn.SoftmaxCrossEntropyWithLogits()  # support this loss soon
    net_loss = CrossEntropyLoss()
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    if args.mode == 'train':  # train
        ds = create_dataset(os.path.join(args.data_path, args.mode), batch_size=cfg.batch_size,
                            repeat_size=cfg.epoch_size)
        print("============== Starting Training ==============")
        model.train(cfg['epoch_size'], ds, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False)
    elif args.mode == 'test':  # test
        print("============== Starting Testing ==============")
        param_dict = load_checkpoint(args.ckpt_path)
        load_param_into_net(network, param_dict)
        ds_eval = create_dataset(os.path.join(args.data_path, "test"), 32, 1)
        acc = model.eval(ds_eval, dataset_sink_mode=False)
        print("============== Accuracy:{} ==============".format(acc))
    else:
        raise RuntimeError('mode should be train or test, rather than {}'.format(args.mode))
@ -13,8 +13,9 @@
 # limitations under the License.
 # ============================================================================
 """
-network config setting, will be used in main.py
+network config setting, will be used in train.py
 """
+
 from easydict import EasyDict as edict

 mnist_cfg = edict({
@ -23,7 +24,6 @@ mnist_cfg = edict({
     'momentum': 0.9,
     'epoch_size': 1,
     'batch_size': 32,
-    'repeat_size': 1,
     'buffer_size': 1000,
     'image_height': 32,
     'image_width': 32,
@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Produce the dataset
"""

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
from mindspore.common import dtype as mstype


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
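The two Rescale ops above compose into the standard MNIST normalization (x/255 - 0.1307) / 0.3081: rescale_op maps x to x/255, and rescale_nml_op then applies (y - 0.1307) / 0.3081. A quick numeric check with the same constants:

import numpy as np

x = np.array([0, 128, 255], dtype=np.float32)
step1 = x * (1.0 / 255.0) + 0.0                        # rescale_op
step2 = step1 * (1 / 0.3081) + (-1 * 0.1307 / 0.3081)  # rescale_nml_op
direct = (x / 255.0 - 0.1307) / 0.3081
print(np.allclose(step2, direct))  # True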
@ -0,0 +1,64 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## eval lenet example ########################
eval lenet according to model file:
python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
"""

import os
import argparse
from dataset import create_dataset
from config import mnist_cfg as cfg
import mindspore.nn as nn
from mindspore.model_zoo.lenet import LeNet5
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: Ascend)')
    parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                        help='path where the dataset is saved')
    parser.add_argument('--ckpt_path', type=str, default="",
                        help='path where the trained ckpt file is saved (required for test)')
    parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')

    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)

    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    repeat_size = cfg.epoch_size
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Testing ==============")
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)
    ds_eval = create_dataset(os.path.join(args.data_path, "test"),
                             cfg.batch_size,
                             1)
    acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
    print("============== Accuracy:{} ==============".format(acc))
@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train lenet example ########################
train lenet and get network model files(.ckpt) :
python train.py --data_path /YourDataPath
"""

import os
import argparse
from config import mnist_cfg as cfg
from dataset import create_dataset
import mindspore.nn as nn
from mindspore.model_zoo.lenet import LeNet5
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: Ascend)')
    parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                        help='path where the dataset is saved')
    parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')

    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)

    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    ds_train = create_dataset(os.path.join(args.data_path, "train"),
                              cfg.batch_size,
                              cfg.epoch_size)
    print("============== Starting Training ==============")
    model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor()],
                dataset_sink_mode=args.dataset_sink_mode)
|
|
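train.py reads every hyperparameter from mnist_cfg, so the config module only has to provide the fields referenced here: num_classes, lr, momentum, save_checkpoint_steps, keep_checkpoint_max, batch_size and epoch_size. A sketch of such a config, following the edict pattern used elsewhere in this PR (the concrete values are illustrative assumptions, not the shipped defaults):

    from easydict import EasyDict as edict

    mnist_cfg = edict({
        'num_classes': 10,             # ten digit classes in MNIST
        'lr': 0.01,                    # illustrative value
        'momentum': 0.9,               # illustrative value
        'epoch_size': 1,               # illustrative value
        'batch_size': 32,              # illustrative value
        'save_checkpoint_steps': 1875, # illustrative value
        'keep_checkpoint_max': 10
    })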
@@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict as edict

cifar_cfg = edict({
    'num_classes': 10,
    'lr_init': 0.05,
    'batch_size': 64,
    'epoch_size': 70,
    'momentum': 0.9,
    'weight_decay': 5e-4,
    'buffer_size': 10,
    'image_height': 224,
    'image_width': 224,
    'keep_checkpoint_max': 10
})
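easydict exposes each entry through both attribute and key access, which is why the scripts in this PR can freely mix cfg.batch_size with cfg['epoch_size']. A quick illustration:

    from easydict import EasyDict as edict

    cfg = edict({'batch_size': 64, 'epoch_size': 70})
    assert cfg.batch_size == cfg['batch_size'] == 64  # same entry, two spellings
    cfg.epoch_size = 10                               # attribute writes update the dict as well
    assert cfg['epoch_size'] == 10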
@@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.common.dtype as mstype
from config import cifar_cfg as cfg

def create_dataset(data_home, repeat_num=1, training=True):
    """Data operations."""
    ds.config.set_seed(1)
    data_dir = os.path.join(data_home, "cifar-10-batches-bin")
    if not training:
        data_dir = os.path.join(data_home, "cifar-10-verify-bin")
    data_set = ds.Cifar10Dataset(data_dir)
    resize_height = cfg.image_height
    resize_width = cfg.image_width
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
    random_horizontal_op = vision.RandomHorizontalFlip()
    resize_op = vision.Resize((resize_height, resize_width))  # interpolation default BILINEAR
    rescale_op = vision.Rescale(rescale, shift)
    normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
    changeswap_op = vision.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    c_trans = []
    if training:
        c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op,
                changeswap_op]

    # apply map operations on images
    data_set = data_set.map(input_columns="label", operations=type_cast_op)
    data_set = data_set.map(input_columns="image", operations=c_trans)

    # apply repeat operations
    data_set = data_set.repeat(repeat_num)

    # apply shuffle operations
    data_set = data_set.shuffle(buffer_size=10)

    # apply batch operations
    data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)

    return data_set
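Only the training pipeline gets the random crop and horizontal flip; evaluation data goes straight through resize, rescale, normalize and HWC2CHW. A short usage sketch (the ./cifar path is an assumption):

    from dataset import create_dataset

    ds_train = create_dataset("./cifar", repeat_num=1, training=True)   # with augmentation
    ds_eval = create_dataset("./cifar", repeat_num=1, training=False)   # deterministic pipeline
    print(ds_train.get_dataset_size())  # batches per epoch at batch_size 64, drop_remainder=True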
@@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############test vgg16 example on cifar10#################
python eval.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
"""
import argparse
import mindspore.nn as nn
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.vgg import vgg16
from config import cifar_cfg as cfg
import dataset

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Cifar10 classification')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved')
    parser.add_argument('--checkpoint_path', type=str, default=None, help='checkpoint file path.')
    parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    context.set_context(device_id=args_opt.device_id)
    context.set_context(enable_mem_reuse=True, enable_hccl=False)

    net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
                   weight_decay=cfg.weight_decay)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    dataset = dataset.create_dataset(args_opt.data_path, 1, False)
    res = model.eval(dataset)
    print("result: ", res)
@@ -0,0 +1,78 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train vgg16 example on cifar10########################
python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
"""
import argparse
import random
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.vgg import vgg16
import dataset
from config import cifar_cfg as cfg
random.seed(1)
np.random.seed(1)

def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
    """Set learning rate."""
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
    for i in range(total_steps):
        if i < decay_epoch_index[0]:
            lr_each_step.append(lr_max)
        elif i < decay_epoch_index[1]:
            lr_each_step.append(lr_max * 0.1)
        elif i < decay_epoch_index[2]:
            lr_each_step.append(lr_max * 0.01)
        else:
            lr_each_step.append(lr_max * 0.001)
    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]

    return learning_rate


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Cifar10 classification')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented. (Default: Ascend)')
    parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved')
    parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    context.set_context(device_id=args_opt.device_id)
    context.set_context(enable_mem_reuse=True, enable_hccl=False)

    net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
    lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=50000 // cfg.batch_size)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

    dataset = dataset.create_dataset(args_opt.data_path, cfg.epoch_size)
    batch_num = dataset.get_dataset_size()
    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="train_vgg_cifar10", directory="./", config=config_ck)
    loss_cb = LossMonitor()
    model.train(cfg.epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
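lr_steps precomputes the learning rate for every step and slices the array from global_step onward, so a resumed run falls on the same decay boundaries as a fresh one. A small worked example with illustrative numbers:

    lr = lr_steps(0, lr_max=0.05, total_epochs=10, steps_per_epoch=100)
    # total_steps = 1000, decay boundaries at steps 300, 600 and 800
    print(lr[0], lr[300], lr[600], lr[800])  # ~0.05, 0.005, 0.0005, 5e-05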
@@ -26,6 +26,7 @@ class ConfigYOLOV3ResNet18:
    img_shape = [352, 640]
    feature_shape = [32, 3, 352, 640]
    num_classes = 80
    nms_max_num = 50

    backbone_input_shape = [64, 64, 128, 256]
    backbone_shape = [64, 128, 256, 512]

@@ -33,6 +34,8 @@ class ConfigYOLOV3ResNet18:
    backbone_stride = [1, 2, 2, 2]

    ignore_threshold = 0.5
    obj_threshold = 0.3
    nms_threshold = 0.4

    anchor_scales = [(10, 13),
                     (16, 30),
@@ -16,16 +16,14 @@
"""YOLOv3 dataset"""
from __future__ import division

import abc
import io
import os
import math
import json
import numpy as np
from PIL import Image
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
import mindspore.dataset as de
from mindspore.mindrecord import FileWriter
import mindspore.dataset.transforms.vision.py_transforms as P
import mindspore.dataset.transforms.vision.c_transforms as C
from config import ConfigYOLOV3ResNet18

iter_cnt = 0

@@ -114,6 +112,29 @@ def preprocess_fn(image, box, is_training):

        return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2

    def _infer_data(img_data, input_shape, box):
        w, h = img_data.size
        input_h, input_w = input_shape
        scale = min(float(input_w) / float(w), float(input_h) / float(h))
        nw = int(w * scale)
        nh = int(h * scale)
        img_data = img_data.resize((nw, nh), Image.BICUBIC)

        new_image = np.zeros((input_h, input_w, 3), np.float32)
        new_image.fill(128)
        img_data = np.array(img_data)
        if len(img_data.shape) == 2:
            img_data = np.expand_dims(img_data, axis=-1)
            img_data = np.concatenate([img_data, img_data, img_data], axis=-1)

        dh = int((input_h - nh) / 2)
        dw = int((input_w - nw) / 2)
        new_image[dh:(nh + dh), dw:(nw + dw), :] = img_data
        new_image /= 255.
        new_image = np.transpose(new_image, (2, 0, 1))
        new_image = np.expand_dims(new_image, 0)
        return new_image, np.array([h, w], np.float32), box

    def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)):
        """Data augmentation function."""
        if not isinstance(image, Image.Image):

@@ -124,32 +145,7 @@ def preprocess_fn(image, box, is_training):
        h, w = image_size

        if not is_training:
            image = image.resize((w, h), Image.BICUBIC)
            image_data = np.array(image) / 255.
            if len(image_data.shape) == 2:
                image_data = np.expand_dims(image_data, axis=-1)
                image_data = np.concatenate([image_data, image_data, image_data], axis=-1)
            image_data = image_data.astype(np.float32)

            # correct boxes
            box_data = np.zeros((max_boxes, 5))
            if len(box) >= 1:
                np.random.shuffle(box)
                if len(box) > max_boxes:
                    box = box[:max_boxes]
                # xmin ymin xmax ymax
                box[:, [0, 2]] = box[:, [0, 2]] * float(w) / float(iw)
                box[:, [1, 3]] = box[:, [1, 3]] * float(h) / float(ih)
                box_data[:len(box)] = box
            else:
                image_data, box_data = None, None

            # preprocess bounding boxes
            bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
                _preprocess_true_boxes(box_data, anchors, image_size)

            return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
                ori_image_shape, gt_box1, gt_box2, gt_box3
            return _infer_data(image, image_size, box)

        flip = _rand() < .5
        # correct boxes
@@ -235,12 +231,16 @@ def preprocess_fn(image, box, is_training):
        return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
            ori_image_shape, gt_box1, gt_box2, gt_box3

    images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
    return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3
    if is_training:
        images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
        return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3

    images, shape, anno = _data_aug(image, box, is_training)
    return images, shape, anno


def anno_parser(annos_str):
    """Annotation parser."""
    """Parse annotation from string to list."""
    annos = []
    for anno_str in annos_str:
        anno = list(map(int, anno_str.strip().split(',')))
@@ -248,142 +248,71 @@ def anno_parser(annos_str):
    return annos


def expand_path(path):
    """Get file list from path."""
    files = []
    if os.path.isdir(path):
        for file in os.listdir(path):
            if os.path.isfile(os.path.join(path, file)):
                files.append(file)
    else:
def filter_valid_data(image_dir, anno_path):
    """Filter valid image files, which are both in image_dir and anno_path."""
    image_files = []
    image_anno_dict = {}
    if not os.path.isdir(image_dir):
        raise RuntimeError("Path given is not valid.")
    return files
    if not os.path.isfile(anno_path):
        raise RuntimeError("Annotation file is not valid.")

    with open(anno_path, "rb") as f:
        lines = f.readlines()
    for line in lines:
        line_str = line.decode("utf-8").strip()
        line_split = str(line_str).split(' ')
        file_name = line_split[0]
        if os.path.isfile(os.path.join(image_dir, file_name)):
            image_anno_dict[file_name] = anno_parser(line_split[1:])
            image_files.append(file_name)
    return image_files, image_anno_dict


def read_image(img_path):
    """Read image with PIL."""
    with open(img_path, "rb") as f:
        img = f.read()
    data = io.BytesIO(img)
    img = Image.open(data)
    return np.array(img)
def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix="yolo.mindrecord", file_num=8):
    """Create MindRecord files by image_dir and anno_path."""
    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    image_files, image_anno_dict = filter_valid_data(image_dir, anno_path)

    yolo_json = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int64", "shape": [-1, 5]},
    }
    writer.add_schema(yolo_json, "yolo_json")

    for image_name in image_files:
        image_path = os.path.join(image_dir, image_name)
        with open(image_path, 'rb') as f:
            img = f.read()
        annos = np.array(image_anno_dict[image_name])
        row = {"image": img, "annotation": annos}
        writer.write_raw_data([row])
    writer.commit()


class BaseDataset():
    """BaseDataset for GeneratorDataset iterator."""
    def __init__(self, image_dir, anno_path):
        self.image_dir = image_dir
        self.anno_path = anno_path
        self.cur_index = 0
        self.samples = []
        self.image_anno_dict = {}
        self._load_samples()

    def __getitem__(self, item):
        sample = self.samples[item]
        return self._next_data(sample, self.image_dir, self.image_anno_dict)

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _next_data(sample, image_dir, image_anno_dict):
        """Get next data."""
        image = read_image(os.path.join(image_dir, sample))
        annos = image_anno_dict[sample]
        return [np.array(image), np.array(annos)]

    @abc.abstractmethod
    def _load_samples(self):
        """Base load samples."""


class YoloDataset(BaseDataset):
    """YoloDataset for GeneratorDataset iterator."""
    def _load_samples(self):
        """Load samples."""
        image_files_raw = expand_path(self.image_dir)
        self.samples = self._filter_valid_data(self.anno_path, image_files_raw)
        self.dataset_size = len(self.samples)
        if self.dataset_size == 0:
            raise RuntimeError("Valid dataset is none!")

    def _filter_valid_data(self, anno_path, image_files_raw):
        """Filter valid data."""
        image_files = []
        anno_dict = {}
        print("Start filter valid data.")
        with open(anno_path, "rb") as f:
            lines = f.readlines()
        for line in lines:
            line_str = line.decode("utf-8")
            line_split = str(line_str).split(' ')
            anno_dict[line_split[0].split("/")[-1]] = line_split[1:]
        anno_set = set(anno_dict.keys())
        image_set = set(image_files_raw)
        for image_file in (anno_set & image_set):
            image_files.append(image_file)
            self.image_anno_dict[image_file] = anno_parser(anno_dict[image_file])
        image_files.sort()
        print("Filter valid data done!")
        return image_files


class DistributedSampler():
    """DistributedSampler for YOLOv3"""
    def __init__(self, dataset_size, batch_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            num_replicas = 1
        if rank is None:
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank % num_replicas
        self.epoch = 0
        self.num_samples = max(batch_size, int(math.ceil(dataset_size * 1.0 / self.num_replicas)))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            indices = indices.tolist()
        else:
            indices = list(range(self.dataset_size))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


def create_yolo_dataset(image_dir, anno_path, batch_size=32, repeat_num=10, device_num=1, rank=0,
def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0,
                        is_training=True, num_parallel_workers=8):
    """Create YOLOv3 dataset with GeneratorDataset."""
    yolo_dataset = YoloDataset(image_dir=image_dir, anno_path=anno_path)
    distributed_sampler = DistributedSampler(yolo_dataset.dataset_size, batch_size, device_num, rank)
    ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
    ds.set_dataset_size(len(distributed_sampler))
    """Create YOLOv3 dataset with MindDataset."""
    ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
                        num_parallel_workers=num_parallel_workers, shuffle=is_training)
    decode = C.Decode()
    ds = ds.map(input_columns=["image"], operations=decode)
    compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
    hwc_to_chw = P.HWC2CHW()
    ds = ds.map(input_columns=["image", "annotation"],
                output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                operations=compose_map_func, num_parallel_workers=num_parallel_workers)
    ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
    ds = ds.shuffle(buffer_size=256)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_num)

    if is_training:
        hwc_to_chw = P.HWC2CHW()
        ds = ds.map(input_columns=["image", "annotation"],
                    output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                    columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
                    operations=compose_map_func, num_parallel_workers=num_parallel_workers)
        ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
        ds = ds.shuffle(buffer_size=256)
        ds = ds.batch(batch_size, drop_remainder=True)
        ds = ds.repeat(repeat_num)
    else:
        ds = ds.map(input_columns=["image", "annotation"],
                    output_columns=["image", "image_shape", "annotation"],
                    columns_order=["image", "image_shape", "annotation"],
                    operations=compose_map_func, num_parallel_workers=num_parallel_workers)
    return ds
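DistributedSampler pads the (optionally shuffled) index list up to total_size so every shard receives exactly num_samples indices, then strides through it by num_replicas. A quick sketch of the resulting sharding (values illustrative):

    sampler = DistributedSampler(dataset_size=10, batch_size=2, num_replicas=2, rank=0, shuffle=False)
    print(list(iter(sampler)))  # [0, 2, 4, 6, 8] -- every second index, starting at rank 0
    sampler.set_epoch(1)        # with shuffle=True this reseeds the per-epoch permutation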
@@ -0,0 +1,107 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Evaluation for yolo_v3"""
import os
import argparse
import time
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithEval
from dataset import create_yolo_dataset, data_to_mindrecord_byte_image
from config import ConfigYOLOV3ResNet18
from util import metrics

def yolo_eval(dataset_path, ckpt_path):
    """Yolov3 evaluation."""

    ds = create_yolo_dataset(dataset_path, is_training=False)
    config = ConfigYOLOV3ResNet18()
    net = yolov3_resnet18(config)
    eval_net = YoloWithEval(net, config)
    print("Load Checkpoint!")
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)


    eval_net.set_train(False)
    i = 1.
    total = ds.get_dataset_size()
    start = time.time()
    pred_data = []
    print("\n========================================\n")
    print("total images num: ", total)
    print("Processing, please wait a moment.")
    for data in ds.create_dict_iterator():
        img_np = data['image']
        image_shape = data['image_shape']
        annotation = data['annotation']

        eval_net.set_train(False)
        output = eval_net(Tensor(img_np), Tensor(image_shape))
        for batch_idx in range(img_np.shape[0]):
            pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
                              "box_scores": output[1].asnumpy()[batch_idx],
                              "annotation": annotation})
        percent = round(i / total * 100, 2)

        print('    %s [%d/%d]' % (str(percent) + '%', i, total), end='\r')
        i += 1
    print('    %s [%d/%d] cost %d ms' % (str(100.0) + '%', total, total, int((time.time() - start) * 1000)), end='\n')

    precisions, recalls = metrics(pred_data)
    print("\n========================================\n")
    for i in range(config.num_classes):
        print("class {} precision is {:.2f}%, recall is {:.2f}%".format(i, precisions[i] * 100, recalls[i] * 100))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Yolov3 evaluation')
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_eval",
                        help="Mindrecord directory. If the mindrecord_dir is empty, it will generate mindrecord files by "
                             "image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
                             "rather than image_dir and anno_path. Default is ./Mindrecord_eval")
    parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, "
                                                                  "the absolute image path is joined by the image_dir "
                                                                  "and the relative path in anno_path.")
    parser.add_argument("--anno_path", type=str, default="", help="Annotation path.")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Checkpoint path.")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True,
                        enable_auto_mixed_precision=False)

    # It will generate mindrecord files in args_opt.mindrecord_dir,
    # named yolo.mindrecord0, 1, ... file_num.
    if not os.path.isdir(args_opt.mindrecord_dir):
        os.makedirs(args_opt.mindrecord_dir)

    prefix = "yolo.mindrecord"
    mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0")
    if not os.path.exists(mindrecord_file):
        if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path):
            print("Create Mindrecord")
            data_to_mindrecord_byte_image(args_opt.image_dir,
                                          args_opt.anno_path,
                                          args_opt.mindrecord_dir,
                                          prefix=prefix,
                                          file_num=8)
            print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir))
        else:
            print("image_dir or anno_path does not exist.")
    print("Start Eval!")
    yolo_eval(mindrecord_file, args_opt.ckpt_path)
@@ -14,17 +14,26 @@
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: sh run_distribute_train.sh 8 100 ./dataset/coco/train2017 ./dataset/train.txt ./hccl.json"
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: sh run_distribute_train.sh 8 100 /data/Mindrecord_train /data /data/train.txt /data/hccl.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="

EPOCH_SIZE=$2
MINDRECORD_DIR=$3
IMAGE_DIR=$4
ANNO_PATH=$5

# Before starting distributed training, first create the mindrecord files.
python train.py --only_create_dataset=1 --mindrecord_dir=$MINDRECORD_DIR --image_dir=$IMAGE_DIR \
--anno_path=$ANNO_PATH

echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"

export MINDSPORE_HCCL_CONFIG_PATH=$6
export RANK_SIZE=$1
EPOCH_SIZE=$2
IMAGE_DIR=$3
ANNO_PATH=$4
export MINDSPORE_HCCL_CONFIG_PATH=$5


for((i=0;i<RANK_SIZE;i++))
do

@@ -40,6 +49,7 @@ do
    --distribute=1 \
    --device_num=$RANK_SIZE \
    --device_id=$DEVICE_ID \
    --mindrecord_dir=$MINDRECORD_DIR \
    --image_dir=$IMAGE_DIR \
    --epoch_size=$EPOCH_SIZE \
    --anno_path=$ANNO_PATH > log.txt 2>&1 &
@@ -0,0 +1,23 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_eval.sh DEVICE_ID CKPT_PATH MINDRECORD_DIR IMAGE_DIR ANNO_PATH"
echo "for example: sh run_eval.sh 0 yolo.ckpt ./Mindrecord_eval ./dataset ./dataset/eval.txt"
echo "=============================================================================================================="

python eval.py --device_id=$1 --ckpt_path=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5
@@ -14,8 +14,10 @@
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE IMAGE_DIR ANNO_PATH"
echo "for example: sh run_standalone_train.sh 0 50 ./dataset/coco/train2017 ./dataset/train.txt"
echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH"
echo "for example: sh run_standalone_train.sh 0 50 ./Mindrecord_train ./dataset ./dataset/train.txt"
echo "=============================================================================================================="

python train.py --device_id=$1 --epoch_size=$2 --image_dir=$3 --anno_path=$4
python train.py --device_id=$1 --epoch_size=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5
@@ -16,26 +16,30 @@
"""
######################## train YOLOv3 example ########################
train YOLOv3 and get network model files (.ckpt):
python train.py --image_dir dataset/coco/coco/train2017 --anno_path dataset/coco/train_coco.txt
python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train

If the mindrecord_dir is empty, it will generate mindrecord files by image_dir and anno_path.
Note if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path.
"""

import os
import argparse
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.common.initializer import initializer
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common.initializer import initializer

from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
from dataset import create_yolo_dataset
from dataset import create_yolo_dataset, data_to_mindrecord_byte_image
from config import ConfigYOLOV3ResNet18


def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
    """Set learning rate"""
    """Set learning rate."""
    lr_each_step = []
    lr = learning_rate
    for i in range(global_step):

@@ -57,7 +61,9 @@ def init_net_param(net, init='ones'):


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="YOLOv3")
    parser = argparse.ArgumentParser(description="YOLOv3 train")
    parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create "
                                                                                "Mindrecord, default is false.")
    parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")

@@ -67,12 +73,19 @@ if __name__ == '__main__':
    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
    parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
    parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
    parser.add_argument("--image_dir", type=str, required=True, help="Dataset image dir.")
    parser.add_argument("--anno_path", type=str, required=True, help="Dataset anno path.")
    parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train",
                        help="Mindrecord directory. If the mindrecord_dir is empty, it will generate mindrecord files by "
                             "image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
                             "rather than image_dir and anno_path. Default is ./Mindrecord_train")
    parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, "
                                                                  "the absolute image path is joined by the image_dir "
                                                                  "and the relative path in anno_path")
    parser.add_argument("--anno_path", type=str, default="", help="Annotation path.")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True)
    context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True,
                        enable_auto_mixed_precision=False)
    if args_opt.distribute:
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()

@@ -80,36 +93,65 @@ if __name__ == '__main__':
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                          device_num=device_num)
        init()
        rank = args_opt.device_id
        rank = args_opt.device_id % device_num
    else:
        context.set_context(enable_hccl=False)
        rank = 0
        device_num = 1

    loss_scale = float(args_opt.loss_scale)
    dataset = create_yolo_dataset(args_opt.image_dir, args_opt.anno_path, repeat_num=args_opt.epoch_size,
                                  batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
    dataset_size = dataset.get_dataset_size()
    net = yolov3_resnet18(ConfigYOLOV3ResNet18())
    net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
    init_net_param(net, "XavierUniform")
    print("Start create dataset!")

    # checkpoint
    ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
    ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config)
    if args_opt.checkpoint_path != "":
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(net, param_dict)
    # It will generate mindrecord files in args_opt.mindrecord_dir,
    # named yolo.mindrecord0, 1, ... file_num.
    if not os.path.isdir(args_opt.mindrecord_dir):
        os.makedirs(args_opt.mindrecord_dir)

    lr = Tensor(get_lr(learning_rate=0.001, start_step=0, global_step=args_opt.epoch_size * dataset_size,
                       decay_step=1000, decay_rate=0.95))
    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
    net = TrainingWrapper(net, opt, loss_scale)
    callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
    prefix = "yolo.mindrecord"
    mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0")
    if not os.path.exists(mindrecord_file):
        if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path):
            print("Create Mindrecord.")
            data_to_mindrecord_byte_image(args_opt.image_dir,
                                          args_opt.anno_path,
                                          args_opt.mindrecord_dir,
                                          prefix=prefix,
                                          file_num=8)
            print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir))
        else:
            print("image_dir or anno_path does not exist.")

    model = Model(net)
    dataset_sink_mode = False
    if args_opt.mode == "graph":
        dataset_sink_mode = True
    print("Start train YOLOv3.")
    model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
    if not args_opt.only_create_dataset:
        loss_scale = float(args_opt.loss_scale)

        # When creating MindDataset, use the first mindrecord file, such as yolo.mindrecord0.
        dataset = create_yolo_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
                                      batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
        dataset_size = dataset.get_dataset_size()
        print("Create dataset done!")

        net = yolov3_resnet18(ConfigYOLOV3ResNet18())
        net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
        init_net_param(net, "XavierUniform")

        # checkpoint
        ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
        ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config)

        lr = Tensor(get_lr(learning_rate=0.001, start_step=0, global_step=args_opt.epoch_size * dataset_size,
                           decay_step=1000, decay_rate=0.95))
        opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
        net = TrainingWrapper(net, opt, loss_scale)

        if args_opt.checkpoint_path != "":
            param_dict = load_checkpoint(args_opt.checkpoint_path)
            load_param_into_net(net, param_dict)

        callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]

        model = Model(net)
        dataset_sink_mode = False
        if args_opt.mode == "graph":
            print("In graph mode, one epoch returns one loss value.")
            dataset_sink_mode = True
        print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
        model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
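Because save_checkpoint_steps is derived from the dataset size, checkpoints always land on epoch boundaries regardless of batch size. The arithmetic, with illustrative numbers:

    dataset_size = 160                            # batches per epoch, illustrative
    save_checkpoint_epochs = 5                    # the script's default
    print(dataset_size * save_checkpoint_epochs)  # 800 -> one checkpoint every five epochs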
@@ -0,0 +1,146 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics utils"""

import numpy as np
from config import ConfigYOLOV3ResNet18


def calc_iou(bbox_pred, bbox_ground):
    """Calculate iou of predicted bbox and ground truth."""
    x1 = bbox_pred[0]
    y1 = bbox_pred[1]
    width1 = bbox_pred[2] - bbox_pred[0]
    height1 = bbox_pred[3] - bbox_pred[1]

    x2 = bbox_ground[0]
    y2 = bbox_ground[1]
    width2 = bbox_ground[2] - bbox_ground[0]
    height2 = bbox_ground[3] - bbox_ground[1]

    endx = max(x1 + width1, x2 + width2)
    startx = min(x1, x2)
    width = width1 + width2 - (endx - startx)

    endy = max(y1 + height1, y2 + height2)
    starty = min(y1, y2)
    height = height1 + height2 - (endy - starty)

    if width <= 0 or height <= 0:
        iou = 0
    else:
        area = width * height
        area1 = width1 * height1
        area2 = width2 * height2
        iou = area * 1. / (area1 + area2 - area)

    return iou


def apply_nms(all_boxes, all_scores, thres, max_boxes):
    """Apply NMS to bboxes."""
    x1 = all_boxes[:, 0]
    y1 = all_boxes[:, 1]
    x2 = all_boxes[:, 2]
    y2 = all_boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    order = all_scores.argsort()[::-1]
    keep = []

    while order.size > 0:
        i = order[0]
        keep.append(i)

        if len(keep) >= max_boxes:
            break

        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thres)[0]

        order = order[inds + 1]
    return keep


def metrics(pred_data):
    """Calculate precision and recall of predicted bboxes."""
    config = ConfigYOLOV3ResNet18()
    num_classes = config.num_classes
    count_corrects = [1e-6 for _ in range(num_classes)]
    count_grounds = [1e-6 for _ in range(num_classes)]
    count_preds = [1e-6 for _ in range(num_classes)]

    for i, sample in enumerate(pred_data):
        gt_anno = sample["annotation"]
        box_scores = sample['box_scores']
        boxes = sample['boxes']
        mask = box_scores >= config.obj_threshold
        boxes_ = []
        scores_ = []
        classes_ = []
        max_boxes = config.nms_max_num
        for c in range(num_classes):
            class_boxes = np.reshape(boxes, [-1, 4])[np.reshape(mask[:, c], [-1])]
            class_box_scores = np.reshape(box_scores[:, c], [-1])[np.reshape(mask[:, c], [-1])]
            nms_index = apply_nms(class_boxes, class_box_scores, config.nms_threshold, max_boxes)
            class_boxes = class_boxes[nms_index]
            class_box_scores = class_box_scores[nms_index]
            classes = np.ones_like(class_box_scores, 'int32') * c
            boxes_.append(class_boxes)
            scores_.append(class_box_scores)
            classes_.append(classes)

        boxes = np.concatenate(boxes_, axis=0)
        classes = np.concatenate(classes_, axis=0)

        # metric
        count_correct = [1e-6 for _ in range(num_classes)]
        count_ground = [1e-6 for _ in range(num_classes)]
        count_pred = [1e-6 for _ in range(num_classes)]

        for anno in gt_anno:
            count_ground[anno[4]] += 1

        for box_index, box in enumerate(boxes):
            bbox_pred = [box[1], box[0], box[3], box[2]]
            count_pred[classes[box_index]] += 1

            for anno in gt_anno:
                class_ground = anno[4]

                if classes[box_index] == class_ground:
                    iou = calc_iou(bbox_pred, anno)
                    if iou >= 0.5:
                        count_correct[class_ground] += 1
                        break

        count_corrects = [count_corrects[i] + count_correct[i] for i in range(num_classes)]
        count_preds = [count_preds[i] + count_pred[i] for i in range(num_classes)]
        count_grounds = [count_grounds[i] + count_ground[i] for i in range(num_classes)]

    precision = np.array([count_corrects[ix] / count_preds[ix] for ix in range(num_classes)])
    recall = np.array([count_corrects[ix] / count_grounds[ix] for ix in range(num_classes)])
    return precision, recall
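calc_iou derives the intersection from the combined extents of the two boxes rather than from explicit intersection corners; the two formulations are equivalent. A hand-checked example:

    iou = calc_iou([0, 0, 10, 10], [5, 5, 15, 15])
    # intersection is the 5x5 square from (5, 5) to (10, 10): 25 / (100 + 100 - 25)
    print(iou)  # ~0.1429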
@@ -1 +1 @@
Subproject commit 092c7a1f6548cac7d40e677af3498c3c49ea2bfd
Subproject commit 32f06965993ca25ea936239d284e992c5512d2ff
@@ -15,6 +15,7 @@
"""Check parameters."""
import re
from enum import Enum
from functools import reduce
from itertools import repeat
from collections import Iterable

@@ -93,8 +94,131 @@ rel_strs = {
}


class Validator:
    """validator for checking input parameters"""

    @staticmethod
    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None):
        """
        Method for judging the relation between two int values or lists/tuples made up of ints.

        This method is not suitable for judging the relation between floats, since it does not consider float error.
        """

        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
            msg_prefix = f'For {prim_name} the' if prim_name else "The"
            raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')

    @staticmethod
    def check_integer(arg_name, arg_value, value, rel, prim_name):
        """Integer value judgment."""
        rel_fn = Rel.get_fns(rel)
        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
        if type_mismatch or not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(value)
            raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},'
                             f' but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
        """Method for checking whether an int value is in some range."""
        rel_fn = Rel.get_fns(rel)
        type_mismatch = not isinstance(arg_value, int)
        if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
                             f' but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_subclass(arg_name, type_, template_type, prim_name):
        """Check whether some type is a subclass of another type."""
        if not isinstance(template_type, Iterable):
            template_type = (template_type,)
        if not any([mstype.issubclass_(type_, x) for x in template_type]):
            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
            raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')

    @staticmethod
    def check_tensor_type_same(args, valid_values, prim_name):
        """Check whether the element types of input tensors are the same."""
        def _check_tensor_type(arg):
            arg_key, arg_val = arg
            Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
            elem_type = arg_val.element_type()
            if not elem_type in valid_values:
                raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
                                f' but `{arg_key}` is {elem_type}.')
            return (arg_key, elem_type)

        def _check_types_same(arg1, arg2):
            arg1_name, arg1_type = arg1
            arg2_name, arg2_type = arg2
            if arg1_type != arg2_type:
                raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,'
                                f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
            return arg1

        elem_types = map(_check_tensor_type, args.items())
        reduce(_check_types_same, elem_types)


    @staticmethod
    def check_scalar_or_tensor_type_same(args, valid_values, prim_name):
        """Check whether the types of inputs are the same. If the input args are tensors, check their element types."""
        def _check_argument_type(arg):
            arg_key, arg_val = arg
            if isinstance(arg_val, type(mstype.tensor)):
                arg_val = arg_val.element_type()
            if not arg_val in valid_values:
                raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
                                f' but `{arg_key}` is {arg_val}.')
            return arg

        def _check_types_same(arg1, arg2):
            arg1_name, arg1_type = arg1
            arg2_name, arg2_type = arg2
            excp_flag = False
            if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
                arg1_type = arg1_type.element_type()
                arg2_type = arg2_type.element_type()
            elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
                pass
            else:
                excp_flag = True

            if excp_flag or arg1_type != arg2_type:
                raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
                                f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
            return arg1
        reduce(_check_types_same, map(_check_argument_type, args.items()))

    @staticmethod
    def check_value_type(arg_name, arg_value, valid_types, prim_name):
        """Check whether a value is an instance of some types."""
        def raise_error_msg():
            """Func for raising an error message when the check fails."""
            type_names = [t.__name__ for t in valid_types]
            num_types = len(valid_types)
            raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be '
                            f'{"one of " if num_types > 1 else ""}'
                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')

        # Notice: bool is a subclass of int, so `check_value_type('x', True, [int])` will fail, while
        # `check_value_type('x', True, [bool, int])` will pass.
        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
            raise_error_msg()
        if isinstance(arg_value, tuple(valid_types)):
            return arg_value
        raise_error_msg()


class ParamValidator:
    """Parameter validator."""
    """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""

    @staticmethod
    def equal(arg_name, arg_value, cond_str, cond):
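The bool guard in check_value_type exists because isinstance(True, int) holds in Python, so without it a bool would slip through an int-only validation. A sketch of the behaviour the comment describes (the 'TestOp' prim_name here is an arbitrary label):

    Validator.check_value_type('x', 1, [int], 'TestOp')           # returns 1
    Validator.check_value_type('x', True, [bool, int], 'TestOp')  # returns True
    Validator.check_value_type('x', True, [int], 'TestOp')        # raises TypeError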
@@ -50,7 +50,7 @@ def get_build_in_impl_path():
    tbe_impl_path = os.environ.get("TBE_IMPL_PATH")
    if tbe_impl_path is None:
        default_install_path = '/usr/local/HiAI/runtime/ops/op_impl/built-in/ai_core/tbe/'
        backup_install_path = '/usr/local/Ascend/Ascend/opp/op_impl/built-in/ai_core/tbe/'
        backup_install_path = '/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/'
        if os.path.exists(default_install_path):
            tbe_impl_path = default_install_path
        elif os.path.exists(backup_install_path):
@ -0,0 +1,114 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""tbe process"""
|
||||
import sys
|
||||
import os
|
||||
from .common import get_args, get_build_in_impl_path, TBEException
|
||||
|
||||
build_in_impl_path = get_build_in_impl_path()
|
||||
|
||||
|
||||
def _op_select_format(kernel_info):
|
||||
"""
|
||||
call op's op_select_format to get op supported format
|
||||
|
||||
Args:
|
||||
kernel_info (dict): kernel info load by json string
|
||||
|
||||
Returns:
|
||||
op supported format
|
||||
"""
|
||||
try:
|
||||
# import module
|
||||
op_name = kernel_info['op_info']['name']
|
||||
impl_path = build_in_impl_path
|
||||
custom_flag = False
|
||||
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
|
||||
op_impl_path = os.path.realpath(kernel_info['impl_path'])
|
||||
if os.path.isfile(op_impl_path):
|
||||
path, file_name = os.path.split(op_impl_path)
|
||||
op_name, _ = os.path.splitext(file_name)
|
||||
impl_path = path
|
||||
custom_flag = True
|
||||
if impl_path not in sys.path:
|
||||
sys.path.insert(0, impl_path)
|
||||
|
||||
if custom_flag:
|
||||
op_module = __import__(op_name)
|
||||
else:
|
||||
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
|
||||
# get function
|
||||
if not hasattr(op_module, "op_select_format"):
|
||||
return ""
|
||||
op_func = getattr(op_module, "op_select_format", None)
|
||||
|
||||
# call function
|
||||
inputs_args = get_args(kernel_info['op_info'], 'inputs')
|
||||
outputs_args = get_args(kernel_info['op_info'], 'outputs')
|
||||
attrs_args = get_args(kernel_info['op_info'], 'attrs')
|
||||
kernel_name = kernel_info['op_info']['kernel_name']
|
||||
ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
|
||||
except Exception as e:
|
||||
raise TBEException(str(e))
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def _check_supported(kernel_info):
|
||||
"""
|
||||
call op's check_supported to check supported or not
|
||||
|
||||
Args:
|
||||
kernel_info (dict): kernel info load by json string
|
||||
|
||||
Returns:
|
||||
bool: check result, true or false
|
||||
"""
|
||||
try:
|
||||
# import module
|
||||
op_name = kernel_info['op_info']['name']
|
||||
impl_path = build_in_impl_path
|
||||
custom_flag = False
|
||||
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
|
||||
op_impl_path = os.path.realpath(kernel_info['impl_path'])
|
||||
if os.path.isfile(op_impl_path):
|
||||
path, file_name = os.path.split(op_impl_path)
|
||||
op_name, _ = os.path.splitext(file_name)
|
||||
impl_path = path
|
||||
custom_flag = True
|
||||
if impl_path not in sys.path:
|
||||
sys.path.insert(0, impl_path)
|
||||
|
||||
if custom_flag:
|
||||
op_module = __import__(op_name)
|
||||
else:
|
||||
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
|
||||
# get function
|
||||
if not hasattr(op_module, "check_supported"):
|
||||
return ""
|
||||
op_func = getattr(op_module, "check_supported", None)
|
||||
|
||||
# call function
|
||||
inputs_args = get_args(kernel_info['op_info'], 'inputs')
|
||||
outputs_args = get_args(kernel_info['op_info'], 'outputs')
|
||||
attrs_args = get_args(kernel_info['op_info'], 'attrs')
|
||||
kernel_name = kernel_info['op_info']['kernel_name']
|
||||
ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
|
||||
except Exception as e:
|
||||
raise TBEException(str(e))
|
||||
|
||||
return ret
|
|
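For reference, a hedged sketch of how the new helpers expect their input, with a hand-built kernel_info dict (field values are illustrative, not taken from a real kernel):

    import json

    # Hypothetical kernel description; real payloads are produced by the compiler front end.
    op_json = json.dumps({
        "op_info": {
            "name": "add",
            "kernel_name": "add_fp32",
            "inputs": [], "outputs": [], "attrs": [],
        },
        "impl_path": None,
    })
    kernel_info = json.loads(op_json)
    # _op_select_format(kernel_info) would import impl.add and call its
    # op_select_format(*inputs, *outputs, *attrs, kernel_name=...) if present.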
@@ -19,10 +19,8 @@ import subprocess
import sys
import os
import json
from .common import check_kernel_info, get_args, get_build_in_impl_path

build_in_impl_path = get_build_in_impl_path()

from .common import check_kernel_info, TBEException
from .helper import _op_select_format, _check_supported

def create_tbe_parallel_compiler():
    """
@@ -41,40 +39,17 @@ def op_select_format(op_json: str):
        op_json (str): json string of the op

    Returns:
        op supported format
        op supported format or exception message
    """
    ret = ""
    kernel_info = json.loads(op_json)
    check_kernel_info(kernel_info)
    try:
        kernel_info = json.loads(op_json)
        check_kernel_info(kernel_info)
        ret = _op_select_format(kernel_info)

        # import module
        op_name = kernel_info['op_info']['name']
        impl_path = build_in_impl_path
        custom_flag = False
        if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
            op_impl_path = os.path.realpath(kernel_info['impl_path'])
            if os.path.isfile(op_impl_path):
                path, file_name = os.path.split(op_impl_path)
                op_name, _ = os.path.splitext(file_name)
                impl_path = path
                custom_flag = True
        sys.path.insert(0, impl_path)
    except TBEException as e:
        return "TBEException: " + str(e)

    if custom_flag:
        op_module = __import__(op_name)
    else:
        op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
    # get function
    if not hasattr(op_module, "op_select_format"):
        return ""
    op_func = getattr(op_module, "op_select_format", None)

    # call function
    inputs_args = get_args(kernel_info['op_info'], 'inputs')
    outputs_args = get_args(kernel_info['op_info'], 'outputs')
    attrs_args = get_args(kernel_info['op_info'], 'attrs')
    kernel_name = kernel_info['op_info']['kernel_name']
    ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
    return ret
@@ -86,40 +61,18 @@ def check_supported(op_json: str):
        op_json (str): json string of the op

    Returns:
        true or false
        bool: check result, true or false
        str: exception message when catch an Exception
    """
    ret = ""
    kernel_info = json.loads(op_json)
    check_kernel_info(kernel_info)
    try:
        kernel_info = json.loads(op_json)
        check_kernel_info(kernel_info)
        ret = _check_supported(kernel_info)

        # import module
        op_name = kernel_info['op_info']['name']
        impl_path = build_in_impl_path
        custom_flag = False
        if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
            op_impl_path = os.path.realpath(kernel_info['impl_path'])
            if os.path.isfile(op_impl_path):
                path, file_name = os.path.split(op_impl_path)
                op_name, _ = os.path.splitext(file_name)
                impl_path = path
                custom_flag = True
        sys.path.insert(0, impl_path)
    except TBEException as e:
        return "TBEException: " + str(e)

    if custom_flag:
        op_module = __import__(op_name)
    else:
        op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
    # get function
    if not hasattr(op_module, "check_supported"):
        return ""
    op_func = getattr(op_module, "check_supported", None)

    # call function
    inputs_args = get_args(kernel_info['op_info'], 'inputs')
    outputs_args = get_args(kernel_info['op_info'], 'outputs')
    attrs_args = get_args(kernel_info['op_info'], 'attrs')
    kernel_name = kernel_info['op_info']['kernel_name']
    ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
    return ret
@@ -149,12 +102,12 @@ class CompilerPool:
    """compiler pool"""

    def __init__(self):
        processes = multiprocessing.cpu_count()
        self.__processe_num = multiprocessing.cpu_count()
        # max_processes_num: Set the maximum number of concurrent processes for compiler
        max_processes_num = 16
        if processes > max_processes_num:
            processes = max_processes_num
        self.__pool = multiprocessing.Pool(processes=processes)
        if self.__processe_num > max_processes_num:
            self.__processe_num = max_processes_num
        self.__pool = None
        self.__next_task_id = 1
        self.__running_tasks = []
@@ -165,11 +118,10 @@ class CompilerPool:
        del self.__pool

    def exit(self):
        return
        # self.__pool.terminate()
        # self.__pool.join()
        # if self.__pool is not None:
        #     del self.__pool
        if self.__pool is not None:
            self.__pool.terminate()
            self.__pool.join()
            del self.__pool

    def start_compile_op(self, op_json):
        """

@@ -183,6 +135,8 @@ class CompilerPool:
        """
        task_id = self.__next_task_id
        self.__next_task_id = self.__next_task_id + 1
        if self.__pool is None:
            self.__pool = multiprocessing.Pool(processes=self.__processe_num)
        task_future = self.__pool.apply_async(func=run_compiler, args=(op_json,))
        self.__running_tasks.append((task_id, task_future))
        return task_id
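The change above defers pool creation until the first compile request, so no worker processes are forked when the pool is never used. A self-contained sketch of the same lazy-initialization pattern (names are illustrative, not the project's API):

    import multiprocessing

    def run_task(x):
        # Stand-in for run_compiler in the hunk above.
        return x * x

    class LazyPool:
        def __init__(self, max_processes=16):
            # Cap the worker count, as the diff does with max_processes_num.
            self._num = min(multiprocessing.cpu_count(), max_processes)
            self._pool = None

        def submit(self, arg):
            if self._pool is None:  # create workers only on first use
                self._pool = multiprocessing.Pool(processes=self._num)
            return self._pool.apply_async(run_task, (arg,))

        def close(self):
            if self._pool is not None:
                self._pool.terminate()
                self._pool.join()

    # On spawn-based platforms, drive this from under `if __name__ == '__main__':`.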
@@ -92,16 +92,16 @@ convert_object_map = {
    T.and_: multitype_ops.logical_and,
    T.or_: multitype_ops.logical_or,
    T.xor: NO_IMPLEMENT,
    T.pos: F.scalar_uadd,
    T.pos: multitype_ops.uadd,
    T.neg: multitype_ops.negative,
    T.invert: NO_IMPLEMENT,
    T.not_: F.bool_not,
    T.not_: multitype_ops.logical_not,
    T.eq: multitype_ops.equal,
    T.ne: F.scalar_ne,
    T.ne: multitype_ops.not_equal,
    T.lt: multitype_ops.less,
    T.gt: F.scalar_gt,
    T.gt: multitype_ops.greater,
    T.le: multitype_ops.less_equal,
    T.ge: F.scalar_ge,
    T.ge: multitype_ops.greater_equal,
    T.is_: F.is_,
    T.is_not: F.is_not,
    T.contains: NO_IMPLEMENT,
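This hunk swaps scalar-only functional ops (F.*) for multitype dispatchers that can also handle tensor operands. A hedged Python sketch of the dispatch-on-operand-type idea (the registry and names here are illustrative only; lists stand in for tensors):

    registry = {}

    def register(op_name, types):
        def deco(fn):
            registry[(op_name, types)] = fn
            return fn
        return deco

    @register('greater', (int, int))
    def greater_scalar(x, y):
        return x > y

    @register('greater', (list, list))
    def greater_elementwise(x, y):
        return [a > b for a, b in zip(x, y)]

    def greater(x, y):
        # Select an implementation by the runtime operand types,
        # the way multitype_ops.greater does for scalars vs tensors.
        return registry[('greater', (type(x), type(y)))](x, y)

    print(greater(3, 2))            # True
    print(greater([1, 4], [2, 3]))  # [False, True]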
@@ -34,6 +34,8 @@ if(ENABLE_GPU)
        "device/gpu/*.cu"
        "kernel/gpu/*.cu"
        "kernel/akg/gpu/*.cc"
        "kernel/akg/akgkernelbuild.cc"
        "kernel/akg/akg_kernel_attrs_process.cc"
        )
    file(GLOB_RECURSE GPU_KERNEL_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
        "kernel/gpu/*.cc"
@@ -85,7 +87,22 @@ ms_build_flatbuffers("${FLATBUFFER_IN}" "${FLATBUFFER_IN}" GENERATED_OUTPUT_DIR
file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
        "ir/*.cc"
        "ir/dtype/*.cc"
        "utils/*.cc"
        "utils/context/ms_context.cc"
        "utils/symbolic.cc"
        "utils/tensorprint_utils.cc"
        "utils/convert_utils.cc"
        "utils/graph_utils.cc"
        "utils/misc.cc"
        "utils/callbacks.cc"
        "utils/profile.cc"
        "utils/base_ref.cc"
        "utils/summary/event_writer.cc"
        "utils/log_adapter.cc"
        "utils/comm_manager.cc"
        "utils/any.cc"
        "utils/config_manager.cc"
        "utils/system/file_system.cc"
        "utils/system/crc32c.cc"
        "common/*.cc"
        "parallel/*.cc"
        "pipeline/pipeline.cc"
@@ -100,14 +117,14 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
        "debug/*.cc"
        "onnx/onnx_exporter.cc"
        "operator/*.cc"
        "transform/*.cc"
        "session/kernel_graph.cc"
        "utils/node_utils.cc"
        "session/session_basic.cc"
        "session/session_factory.cc"
        "session/anf_runtime_algorithm.cc"
        "vm/*.cc"
        "pynative/*.cc"
        "pynative/base.cc"
        "pynative/pynative_execute.cc"
        "pybind_api/*.cc"
        "device/common/*.cc"
        "kernel/kernel_query.cc"
@@ -117,7 +134,6 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
        "device/kernel_runtime.cc"
        "device/kernel_runtime_manager.cc"
        "device/convert_tensor_utils.cc"
        "pre_activate/ascend/*.cc"
        "pre_activate/common/*.cc"
        "pre_activate/pass/*.cc"
        "pre_activate/gpu/*.cc"
@@ -168,6 +184,16 @@ if(ENABLE_DUMP_PROTO)
    add_compile_definitions(ENABLE_DUMP_PROTO)
endif()

if(ENABLE_GE)
    file(GLOB_RECURSE GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
        "transform/*.cc"
        "pynative/pynative_execute_ge.cc"
        "utils/callbacks_ge.cc"
        "pipeline/pipeline_ge.cc"
        )
    list(APPEND MINDSPORE_SRC_LIST ${GE_SRC_LIST})
endif()

if(ENABLE_D)
    include_directories("${CMAKE_BINARY_DIR}/kernel/aicpu")
    file(GLOB_RECURSE PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
@@ -185,10 +211,12 @@ if(ENABLE_D)
        "device/ascend/*.cc"
        "device/ascend/profiling/*.cc"
        "device/ascend/tasksink/*.cc"
        "kernel/akg/cce/*.cc"
        "device/kernel_adjust.cc"
        "kernel/kernel_fusion.cc"
        "kernel/tbe/*.cc"
        "pre_activate/ascend/*.cc"
        "transform/*.cc"
        "pipeline/pipeline_ge.cc"
        )
    list(APPEND MINDSPORE_SRC_LIST ${D_SRC_LIST})
    list(APPEND MINDSPORE_PROTO_AICPU_LIST ${PROTOSRCS})
@@ -247,9 +275,11 @@ if (ENABLE_GE)
        target_link_libraries(mindspore graph ge_client)
    endif()
    target_link_libraries(mindspore tsdclient)
else()
elseif(ENABLE_D)
    add_compile_definitions(NO_GE_CLIENT)
    target_link_libraries(mindspore graph)
else()
    add_compile_definitions(NO_GE_CLIENT)
endif()

if(ENABLE_D)
@@ -265,15 +295,17 @@ if(ENABLE_D)
    endif()
    else()
        MESSAGE("use system default lib")
        set(D_LIB_PATH "/usr/local/HiAI/runtime/lib64/")
        set(ASCEND_PATH /usr/local/Ascend)
        set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
        set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64)
    endif()

    MESSAGE("USE DAV LIB PATH: ${D_LIB_PATH}")
    find_library(HCCL hccl ${D_LIB_PATH})
    find_library(CCE_LIB cce ${D_LIB_PATH})
    find_library(RUNTIME_LIB runtime ${D_LIB_PATH})
    find_library(TSDCLIENT tsdclient ${D_LIB_PATH})
    find_library(PROFILING msprof ${D_LIB_PATH})
    MESSAGE("USE DAV LIB PATH: ${ASCEND_PATH}")
    find_library(HCCL hccl ${ASCEND_RUNTIME_PATH})
    find_library(CCE_LIB cce ${ASCEND_RUNTIME_PATH})
    find_library(RUNTIME_LIB runtime ${ASCEND_RUNTIME_PATH})
    find_library(TSDCLIENT tsdclient ${ASCEND_RUNTIME_PATH})
    find_library(PROFILING msprof ${ASCEND_DRIVER_PATH})
    target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${PROFILING} ${HCCL} ${TSDCLIENT})
endif()
@@ -289,8 +321,6 @@ endif()
set(PYTHON_MODULE_SOURCE
    pipeline/init.cc
    kernel/oplib/oplib.cc
    kernel/akg/akgkernelbuild.cc
    kernel/akg/akg_kernel_attrs_process.cc
    ${MS_STEPS_SRC_LIST} ${MS_CCE_SRC_LIST} ${MS_AICPU_SRC_LIST} ${MS_TASKINFO_LIST} ${MS_RT_SRC_LIST}
    ${GPU_NCCL_LIST} ${MS_HCCL_SRC_LIST} ${MS_PREDICT_SRC_LIST} ${CPU_SRC_LIST} ${MEM_REUSE_SRC_LIST} ${GPU_KERNEL_SRC_LIST})
@@ -351,6 +381,7 @@ if(ENABLE_GPU)
    assign_source_group("Include" ${GROUP_INCLUDE})

    file(GLOB COMPILER_SRCS
        "pre_activate/gpu/*.cc"
        ${TVM_DIR}/src/api/*.cc
        ${TVM_DIR}/src/arithmetic/*.cc
        ${TVM_DIR}/src/autotvm/*.cc
@@ -361,7 +392,6 @@ if(ENABLE_GPU)
        ${TVM_DIR}/src/node/*.cc
        ${TVM_DIR}/src/schedule/*.cc
        ${TVM_DIR}/src/runtime/*.cc
        ${TVM_DIR}/src/runtime/cce/*.cc
        ${TVM_DIR}/src/runtime/vm/*.cc
        ${TVM_DIR}/src/runtime/vm/profiler/*.cc
        ${TVM_DIR}/src/codegen/stackvm/*.cc)
@@ -379,7 +409,6 @@ if(ENABLE_GPU)

    file(GLOB RUNTIME_SRCS
        ${TVM_DIR}/src/runtime/*.cc
        ${TVM_DIR}/src/runtime/cce/*.cc
        ${TVM_DIR}/src/runtime/vm/*.cc
        ${TVM_DIR}/src/runtime/stub/*.cc
        ${TVM_DIR}/src/runtime/stackvm/*.cc)
@@ -470,17 +499,19 @@ add_dependencies(add_ms_lib _c_expression)

if (NOT ENABLE_GE)
    if (ENABLE_D)
        set(ASCEND_PATH /usr/local/Ascend)
        set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
        add_custom_target(add_ge_lib ALL
            COMMAND cp ${MS_CCSRC_BUILD_PATH}/../../graphengine/src/common/graph/libgraph.so ${MS_LIB_PATH}
            COMMAND cp ${MS_CCSRC_BUILD_PATH}/../../graphengine/src/ge/common/libge_common.so ${MS_LIB_PATH}
            COMMAND cp ${MS_CCSRC_BUILD_PATH}/../../graphengine/src/ge/ge_runtime/libge_runtime.so ${MS_LIB_PATH}
            COMMAND cp /usr/local/HiAI/driver/lib64/libslog.so ${MS_LIB_PATH}
            COMMAND cp /usr/local/HiAI/driver/lib64/libc_sec.so ${MS_LIB_PATH}
            COMMAND cp ${ASCEND_DRIVER_PATH}/libslog.so ${MS_LIB_PATH}
            COMMAND cp ${ASCEND_DRIVER_PATH}/libc_sec.so ${MS_LIB_PATH}
            )
        add_dependencies(add_ge_lib add_ms_lib)
        add_dependencies(add_ge_lib graph)
        add_dependencies(add_ge_lib ge_runtime)
    else()
    elseif(ENABLE_TESTCASES)
        add_custom_target(add_ge_lib ALL
            COMMAND cp ${MS_CCSRC_BUILD_PATH}/../../graphengine/src/common/graph/libgraph.so ${MS_LIB_PATH}
            COMMAND cp ${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libslog.so ${MS_LIB_PATH}
@@ -53,6 +53,7 @@ enum DataTypeTransMode {
  FROM_INT8_TO_FLOAT,
  FROM_INT8_TO_INT32,
  FROM_INT64_TO_INT32,
  FROM_UINT16_TO_INT32,
};

const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{

@@ -68,7 +69,8 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
  {std::pair<TypeId, TypeId>(kNumberTypeUInt8, kNumberTypeInt32), FROM_UINT8_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeFloat32), FROM_INT8_TO_FLOAT},
  {std::pair<TypeId, TypeId>(kNumberTypeInt8, kNumberTypeInt32), FROM_INT8_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32}};
  {std::pair<TypeId, TypeId>(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32},
  {std::pair<TypeId, TypeId>(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32}};

template <typename SrcT, typename DstT>
void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) {
@@ -116,8 +118,11 @@ bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const
    case FROM_INT64_TO_INT32:
      TransDataSrc2Dst<int64_t, int32_t>(args, dst, data_size);
      break;
    case FROM_UINT16_TO_INT32:
      TransDataSrc2Dst<uint16_t, int32_t>(args, dst, data_size);
      break;
    default:
      MS_LOG(ERROR) << "unsupported datatype trans";
      MS_LOG(ERROR) << "Unsupported datatype trans";
      return false;
  }
  return true;
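The uint16-to-int32 path above extends a (source, destination) lookup table plus a switch over cast modes. A minimal Python sketch of the same table-driven cast, with numpy standing in for the typed memory copy (dtype names are illustrative):

    import numpy as np

    # (src, dst) -> converter, mirroring mode_map + CastKernel above.
    cast_table = {
        ('int64', 'int32'): lambda a: a.astype(np.int32),
        ('uint16', 'int32'): lambda a: a.astype(np.int32),
        ('int8', 'float32'): lambda a: a.astype(np.float32),
    }

    def trans_data_type(arr, dst_dtype):
        key = (arr.dtype.name, dst_dtype)
        if key not in cast_table:
            raise ValueError('Unsupported datatype trans: {}'.format(key))
        return cast_table[key](arr)

    print(trans_data_type(np.array([1, 2], dtype=np.uint16), 'int32').dtype)  # int32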
@@ -127,7 +132,7 @@ size_t CubeSizeByType(const TypeId data_type) {
  const size_t default_error = 0;
  auto dt_size = TypeIdSize(data_type);
  if (dt_size < 1) {
    MS_LOG(ERROR) << "illegal dtype.";
    MS_LOG(ERROR) << "Illegal dtype.";
    return default_error;
  } else if (dt_size == 1) {
    return kCubeSize * 2;

@@ -141,12 +146,12 @@ size_t ShapeSize(const std::vector<size_t> &shape) {
}

size_t TypeIdSize(const TypeId data_type) {
  const size_t unsupport_type_error = 0;
  const size_t unsupported_type_error = 0;
  auto iter = type_map.find(data_type);
  if (iter != type_map.end()) {
    return iter->second;
  }
  return unsupport_type_error;
  return unsupported_type_error;
}

std::vector<size_t> TransShapeTo4d(const std::vector<size_t> &shape) {

@@ -169,7 +174,7 @@ std::vector<size_t> TransShapeTo4d(const std::vector<size_t> &shape) {
      }
      break;
    default:
      MS_LOG(EXCEPTION) << "Unexpeted shape size = " << shape.size();
      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
  }
  return shape_4d;
}

@@ -178,7 +183,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
  std::vector<size_t> device_shape;
  if (format == kOpFormat_FRAC_NZ) {
    if (shape.size() < 2) {
      MS_EXCEPTION(NotSupportError) << "format " << format << " is not support shape " << shape.size();
      MS_EXCEPTION(NotSupportError) << "Format " << format << " is not support shape " << shape.size();
    }
    if (shape.size() > 2) {
      (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));

@@ -226,37 +231,37 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
}

bool TransDataType(const TypeIdArgs &args, void *result) {
  MS_LOG(DEBUG) << "begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
  MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to "
                << TypeIdLabel(args.device_data_type);
  MS_EXCEPTION_IF_NULL(result);
  std::pair<TypeId, TypeId> type_info(args.host_data_type, args.device_data_type);
  auto iter = mode_map.find(type_info);
  if (iter == mode_map.end()) {
    MS_LOG(ERROR) << "unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type)
    MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type)
                  << ", dst_type:" << TypeIdLabel(args.device_data_type);
    return false;
  }
  auto trans_mode = iter->second;
  auto type_size = TypeIdSize(args.device_data_type);
  if (type_size < 1) {
    MS_LOG(ERROR) << "invalid host data type.";
    MS_LOG(ERROR) << "Invalid host data type.";
    return false;
  }
  if (args.host_shape_size < 1) {
    MS_LOG(ERROR) << "invalid host data size.";
    MS_LOG(ERROR) << "Invalid host data size.";
    return false;
  }
  if (!CastKernel(args, result, args.host_shape_size, trans_mode)) {
    MS_LOG(ERROR) << "failed to trans datatype..";
    MS_LOG(ERROR) << "Failed to trans datatype..";
    return false;
  }
  return true;
}

bool TransFormat(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "start trans format.";
  MS_LOG(DEBUG) << "Start trans format.";
  if (TypeIdSize(args.src_data_type) < 1) {
    MS_LOG(ERROR) << "invalid datatype..";
    MS_LOG(ERROR) << "Invalid datatype..";
    return false;
  }
  if ((args.host_format == kOpFormat_NCHW || args.host_format == kOpFormat_ND) &&

@@ -271,9 +276,9 @@ bool TransFormat(const FormatArgs &args, void *result) {
}

bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "start trans format.";
  MS_LOG(DEBUG) << "Start trans format.";
  if (TypeIdSize(args.src_data_type) < 1) {
    MS_LOG(ERROR) << "invalid datatype..";
    MS_LOG(ERROR) << "Invalid datatype..";
    return false;
  }
  if ((args.host_format == kOpFormat_NCHW || args.host_format == kOpFormat_ND) &&

@@ -288,15 +293,15 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
}

bool NchwToFracZ(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "trans format from nchw to frac_z";
  MS_LOG(DEBUG) << "Trans format from nchw to frac_z";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  size_t size = TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "illegal dtype.";
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  auto n = args.host_shape[0];

@@ -306,7 +311,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {

  size_t c0 = CubeSizeByType(args.src_data_type);
  if (c0 < 1) {
    MS_LOG(ERROR) << "illegal dtype.";
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  size_t c1 = Ceil(c, c0);

@@ -322,7 +327,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {

  size_t dst_size = total_ele_cnt * size;
  if (dst_size != args.device_size) {
    MS_LOG(ERROR) << "illegal total data size."
    MS_LOG(ERROR) << "Illegal total data size."
                  << "dst size is :" << dst_size << "device size is :" << args.device_size;
    return false;
  }

@@ -364,20 +369,20 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
}

bool FracZToNchw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "trans format from frac_z to nchw";
  MS_LOG(DEBUG) << "Trans format from frac_z to nchw";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  size_t size = TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "illegal dtype.";
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  size_t total_size = ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }

@@ -430,7 +435,7 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *hw_shape) {
  MS_EXCEPTION_IF_NULL(hw_shape);
  if (host_shape.empty()) {
    MS_LOG(ERROR) << "size of vector is 0.";
    MS_LOG(ERROR) << "Size of vector is 0.";
    return false;
  }
  switch (host_shape.size()) {

@@ -442,7 +447,7 @@ bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *
    default:
      auto size = host_shape.size();
      if (size < 2) {
        MS_LOG(ERROR) << "illegal size.";
        MS_LOG(ERROR) << "Illegal size.";
        return false;
      }
      size_t times = 1;

@@ -457,26 +462,26 @@ bool TransShapeToNz(const std::vector<size_t> &host_shape, std::vector<size_t> *
}

bool NchwToFracNz(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "trans format from nchw to frac_nz.";
  MS_LOG(DEBUG) << "Trans format from nchw to frac_nz.";
  MS_EXCEPTION_IF_NULL(result);
  std::vector<size_t> hw_shape;
  if (!TransShapeToNz(args.host_shape, &hw_shape)) {
    MS_LOG(ERROR) << "trans shape failed..";
    MS_LOG(ERROR) << "Trans shape failed..";
    return false;
  }
  if (hw_shape.size() < 3 || args.device_shape.size() < 4) {
    MS_LOG(ERROR) << "invalid shape size.";
    MS_LOG(ERROR) << "Invalid shape size.";
    return false;
  }
  auto size = TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "illegal dtype";
    MS_LOG(ERROR) << "Illegal dtype";
    return false;
  }

  auto dst_size = ShapeSize(args.device_shape) * size;
  if (dst_size != args.device_size) {
    MS_LOG(ERROR) << "illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    return false;
  }
  auto times = hw_shape.at(0);

@@ -533,26 +538,26 @@ bool NchwToFracNz(const FormatArgs &args, void *result) {
}

bool FracNzToNchw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "trans format from frac_nz to nchw";
  MS_LOG(DEBUG) << "Trans format from frac_nz to nchw";
  MS_EXCEPTION_IF_NULL(result);
  std::vector<size_t> hw_shape;
  if (!TransShapeToNz(args.host_shape, &hw_shape)) {
    MS_LOG(ERROR) << "trans shape failed..";
    MS_LOG(ERROR) << "Trans shape failed..";
    return false;
  }
  if (hw_shape.size() < 3 || args.device_shape.size() < 4) {
    MS_LOG(ERROR) << "invalid shape size.";
    MS_LOG(ERROR) << "Invalid shape size.";
    return false;
  }
  auto size = TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "illegal dtype";
    MS_LOG(ERROR) << "Illegal dtype";
    return false;
  }

  auto dst_size = ShapeSize(args.device_shape) * size;
  if (dst_size != args.device_size) {
    MS_LOG(ERROR) << "illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
    return false;
  }
  auto times = hw_shape.at(0);

@@ -609,20 +614,20 @@ bool FracNzToNchw(const FormatArgs &args, void *result) {
}

bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "trans format from nchw to Nc1h1wc0";
  MS_LOG(DEBUG) << "Trans format from nchw to Nc1h1wc0";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  size_t size = TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "illegal dtype.";
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  auto total_size = ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }

@@ -632,7 +637,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
  auto w = args.host_shape[3];
  size_t c0 = CubeSizeByType(args.src_data_type);
  if (c0 < 1) {
    MS_LOG(ERROR) << "illegal dtype.";
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  size_t c1 = Ceil(c, c0);
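The c1 = Ceil(c, c0) step above is the heart of the NCHW to NC1HWC0 layout: the channel axis is split into c1 blocks of width c0. A small Python sketch of the device-shape computation, assuming (per the CubeSizeByType hunk above) that 1-byte dtypes get a cube size of 32 and others 16:

    import math

    def nc1hwc0_shape(n, c, h, w, dtype_size):
        # Assumption from the hunks above: CubeSizeByType returns kCubeSize * 2
        # for 1-byte dtypes and kCubeSize (16) otherwise.
        c0 = 32 if dtype_size == 1 else 16
        c1 = math.ceil(c / c0)
        return [n, c1, h, w, c0]

    print(nc1hwc0_shape(1, 3, 224, 224, 4))  # [1, 1, 224, 224, 16]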
@@ -682,20 +687,20 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
}

bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
  MS_LOG(DEBUG) << "trans format from nc1h1wc0 to nchw";
  MS_LOG(DEBUG) << "Trans format from nc1h1wc0 to nchw";
  MS_EXCEPTION_IF_NULL(result);
  if (args.host_shape.size() != kNchwDims) {
    MS_LOG(ERROR) << "invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
    return false;
  }
  size_t size = TypeIdSize(args.src_data_type);
  if (size < 1) {
    MS_LOG(ERROR) << "illegal dtype.";
    MS_LOG(ERROR) << "Illegal dtype.";
    return false;
  }
  size_t total_size = ShapeSize(args.device_shape) * size;
  if (total_size != args.device_size) {
    MS_LOG(ERROR) << "illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
    return false;
  }
@@ -105,6 +105,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
      data_schema_(std::move(data_schema)),
      filename_index_(make_unique<StringIndex>()),
      load_io_block_queue_(true),
      load_jagged_connector_(true),
      num_rows_(0),
      num_rows_per_shard_(0),
      equal_rows_per_shard_(equal_rows_per_shard) {

@@ -203,6 +204,25 @@ Status TFReaderOp::operator()() {
      buffer_id++;
      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
    } else {
      // user specified the number of rows they want, and we read enough rows
      //
      // IOBlockQueue thread needs to:
      // -stop pushing stuff to IOBlockQueue
      // -call PostEndOfEpoch (will send EOE)
      // -wait for reset
      //
      // Worker threads need to:
      // -stop reading the file they are currently reading and throw it away
      // -keep pulling, but don't read other files (eventually skips all IOBlocks and will get EOE)
      //
      // Master thread needs to:
      // -tell IOBlockQueue thread to stop pushing
      // -tell worker threads to stop reading the file they are currently reading
      // -keep pulling until EOE

      // don't think we need a lock for now
      load_jagged_connector_ = false;

      std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
      load_io_block_queue_ = false;
    }
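The comment block above describes cooperative shutdown through two flags, load_jagged_connector_ and load_io_block_queue_. A hedged Python sketch of the same master/worker flag pattern (threading stands in for the op's thread groups; names mirror the members above):

    import threading

    class StopFlags:
        """Sketch of the two-flag shutdown the comments above describe."""
        def __init__(self):
            self.load_jagged_connector = True  # read by worker threads
            self.load_io_block_queue = True    # guarded by a lock, read by the IO thread
            self.io_lock = threading.Lock()

        def request_stop(self):
            # Master: workers stop reading files; IO thread stops pushing blocks.
            self.load_jagged_connector = False
            with self.io_lock:
                self.load_io_block_queue = False

        def reset(self):
            # Mirror TFReaderOp::Reset(): re-enable workers first, then the IO queue.
            self.load_jagged_connector = True
            with self.io_lock:
                self.load_io_block_queue = True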
@@ -245,12 +265,14 @@ Status TFReaderOp::WorkerEntry(int32_t worker_id) {

  while (!io_block->eof()) {
    if (!io_block->eoe()) {
      std::string filename;
      RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
      int64_t start_offset = io_block->GetStartOffset();
      int64_t end_offset = io_block->GetEndOffset();
      RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
      MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << common::SafeCStr(filename) << ".";
      if (load_jagged_connector_) {
        std::string filename;
        RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
        int64_t start_offset = io_block->GetStartOffset();
        int64_t end_offset = io_block->GetEndOffset();
        RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
        MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << filename << ".";
      }
    } else {
      std::unique_ptr<DataBuffer> eoe_buffer = mindspore::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE);
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
@@ -478,6 +500,10 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off
  std::unique_ptr<TensorQTable> new_tensor_table = make_unique<TensorQTable>();

  while (reader.peek() != EOF) {
    if (!load_jagged_connector_) {
      break;
    }

    // read length
    int64_t record_length = 0;
    (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));

@@ -599,6 +625,9 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table
// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
Status TFReaderOp::Reset() {
  // start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true
  load_jagged_connector_ = true;

  {
    std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
    load_io_block_queue_ = true;
@@ -369,6 +369,7 @@ class TFReaderOp : public ParallelOp {
  std::unique_ptr<DataSchema> data_schema_;
  std::unique_ptr<StringIndex> filename_index_;
  bool load_io_block_queue_;
  bool load_jagged_connector_;

  std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
  QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
@@ -141,7 +141,7 @@ void DumpKernelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo>

void DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, OrderedMap<AnfNodePtr, int32_t> *para_map) {
  if (graph == nullptr) {
    MS_LOG(INFO) << "param graph is nullptr.";
    MS_LOG(INFO) << "Param graph is nullptr.";
    return;
  }
  std::vector<AnfNodePtr> parameters = graph->parameters();

@@ -175,17 +175,17 @@ void DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, OrderedMa
    if (para_map != nullptr) {
      (*para_map)[p] = para++;
    }
    MS_LOG(DEBUG) << "record param: " << p->ToString() << " graph belong : " << p->func_graph()->ToString();
    MS_LOG(DEBUG) << "Record param: " << p->ToString() << " graph belong : " << p->func_graph()->ToString();
  }
}

void DumpOperator(const AnfNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &gsub) {
  if (op == nullptr) {
    MS_LOG(INFO) << "param op is nullptr";
    MS_LOG(INFO) << "Param op is nullptr";
    return;
  }
  if (gsub == nullptr) {
    MS_LOG(INFO) << "param gsub is nullptr";
    MS_LOG(INFO) << "Param gsub is nullptr";
    return;
  }

@@ -338,7 +338,7 @@ void DumpCNode(const CNodePtr &nd, const FuncGraphPtr &sub_graph, OrderedMap<Anf
  }

  if (nd->inputs().empty()) {
    MS_LOG(EXCEPTION) << "input of apply node is empty";
    MS_LOG(EXCEPTION) << "Input of apply node is empty";
  }

  // print operator

@@ -376,7 +376,7 @@ void DumpIRInSubgraph(const std::vector<AnfNodePtr> &nodes, OrderedMap<AnfNodePt
    MS_EXCEPTION_IF_NULL(nd);
    FuncGraphPtr sub_graph = nd->func_graph();
    if (sub_graph == nullptr) {
      MS_LOG(DEBUG) << "node[" << nd->ToString() << "] belongs to no graph!";
      MS_LOG(DEBUG) << "Node[" << nd->ToString() << "] belongs to no graph!";
      continue;
    }
    std::shared_ptr<SubGraphIRInfo> gsub = (*sub_graphs)[sub_graph];

@@ -430,12 +430,12 @@ void DumpIR(const std::string &filename, const FuncGraphPtr &graph, bool dump_fu
    return;
  }
  if (filename.size() > PATH_MAX) {
    MS_LOG(ERROR) << "file path " << filename << " is too long.";
    MS_LOG(ERROR) << "File path " << filename << " is too long.";
    return;
  }
  char real_path[PATH_MAX] = {0};
  if (nullptr == realpath(filename.c_str(), real_path)) {
    MS_LOG(DEBUG) << "dir " << filename << " does not exit.";
    MS_LOG(DEBUG) << "Dir " << filename << " does not exit.";
  }

  OrderedMap<AnfNodePtr, int32_t> para_map;
@@ -34,6 +34,7 @@
#include "utils/utils.h"
#include "debug/trace.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"

namespace mindspore {
// max number of elements in sequence

@@ -48,7 +49,7 @@ std::string GetMsIrPath(void) {
    path = path_ptr;
    char real_path[PATH_MAX] = {0};
    if (path.size() > PATH_MAX || nullptr == realpath(path.c_str(), real_path)) {
      MS_LOG(EXCEPTION) << "MS IR Path error, " << path_ptr;
      MS_LOG(EXCEPTION) << "MS IR path error, " << path_ptr;
    }
    path = real_path;
  }
@@ -69,7 +70,7 @@ py::object load_obj(const std::string& path) {

// ============================================= MindSpore IR Exporter =============================================

std::string GetNodeType(const AnfNodePtr& nd) {
std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) {
  abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(nd->Shape());
  TypePtr type = dyn_cast<Type>(nd->Type());
  std::ostringstream oss;

@@ -102,7 +103,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr&
  FuncGraphPtr fg = func_graph;
  while (fg != nullptr) {
    if (exported.find(fg) == exported.end()) {
      if (!export_used_) {
      if (!check_integrity_) {
        break;
      }
      MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'";
@@ -143,8 +144,8 @@ std::string AnfExporter::GetValueNodeText(const FuncGraphPtr& fg, const ValueNod
}

std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph) {
  auto py_funs = mt_func_graph->GetPyFunctions();
  if (py_funs.empty()) {
  auto py_funcs = mt_func_graph->GetPyFunctions();
  if (py_funcs.empty()) {
    return "";
  }

@@ -152,7 +153,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap

  oss << "{";
  bool is_first = true;
  for (const auto& py_func : py_funs) {
  for (const auto& py_func : py_funcs) {
    if (is_first) {
      is_first = false;
    } else {
@@ -255,15 +256,15 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) {
  }

  // output primitive attributes
  auto attrs = prim->attrs();
  if (attrs.size() > 0) {
    oss << "[";
    int i = 0;
    for (auto& attr : attrs) {
      oss << (i > 0 ? ", " : "") << attr.first << "=" << attr.second->DumpText();
      i++;
  oss << prim->GetAttrsText();

  if (prim->isa<prim::DoSignaturePrimitive>()) {
    auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(prim);
    auto& func = do_signature->function();
    if (func->isa<Primitive>()) {
      auto sig_prim = dyn_cast<Primitive>(func);
      oss << sig_prim->GetAttrsText();
    }
    oss << "]";
  }

  return oss.str();
@@ -351,7 +352,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value
std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) {
  std::ostringstream oss;

  if (export_used_) {
  if (check_integrity_) {
    MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText();
  }
  oss << value->type_name() << "[" << value->DumpText() << "]";

@@ -420,7 +421,7 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An
    }
    oss << "%" << iter->second;
  } else if (node->isa<Parameter>()) {
    oss << "%para" << GetParamIndex(func_graph, node, export_used_);
    oss << "%para" << GetParamIndex(func_graph, node, check_integrity_);
  } else if (IsValueNode<FuncGraph>(node)) {
    FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
    oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id();
@@ -625,7 +626,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPt
    ofs << "\n\n";
    (void)func_graph_set.erase(fg);
  }
  ofs << "# num of total funcgraphs: " << exported.size();
  ofs << "# num of total function graphs: " << exported.size();

  ofs.close();
}

@@ -650,7 +651,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector
    ofs << "\n\n";
  }

  ofs << "# num of total funcgraphs: " << graphs.size();
  ofs << "# num of total function graphs: " << graphs.size();

  ofs.close();
}
@@ -762,7 +763,7 @@ class Lexer {
      fin.close();
    }
  } catch (const std::exception& e) {
    MS_LOG(ERROR) << "exception when closing file";
    MS_LOG(ERROR) << "Exception when closing file";
  } catch (...) {
    std::string exName(abi::__cxa_current_exception_type()->name());
    MS_LOG(ERROR) << "Error occurred when closing file. Exception name: " << exName;

@@ -801,7 +802,7 @@ class Lexer {
    Token token = GetNextTokenInner();
    const char* str = token_text[token];
    std::string text = (str == nullptr ? GetTokenText() : str);
    MS_LOG(DEBUG) << "------parse token] " << text;
    MS_LOG(DEBUG) << "------Parse token] " << text;
    return token;
  }
@@ -1641,7 +1642,7 @@ class IrParser {
    MS_LOG(EXCEPTION) << "Expect @file at line " << lexer_.GetLineNo();
  }

  // load prameter default value from serialized file
  // load parameter default value from serialized file
  py::object default_obj = LoadObject(lexer_.GetTokenText());
  param->set_default_param(default_obj);

@@ -1949,7 +1950,7 @@ class IrParser {
    return TOK_ERROR;
  }

  // restore python funciton of PrimitivePy from serialized file
  // restore python function of PrimitivePy from serialized file
  py::object py_obj = LoadObject(lexer_.GetTokenText());
  PrimitivePyPtr ptr = nullptr;
  if (py::hasattr(py_obj, "__setattr_flag__") && py::hasattr(py_obj, "_clone")) {

@@ -1957,7 +1958,7 @@ class IrParser {
    py::object new_obj = clone_fn();
    ptr = new_obj.cast<PrimitivePyPtr>();
    if (ptr == nullptr) {
      MS_LOG(EXCEPTION) << "cast to type 'PrimitivePyPtr' error";
      MS_LOG(EXCEPTION) << "Cast to type 'PrimitivePyPtr' error";
    }
  } else {
    ptr = std::make_shared<PrimitivePy>(id.substr(strlen("PrimitivePy::")), py_obj);
@@ -2220,15 +2221,15 @@ class IrParser {
};

std::vector<FuncGraphPtr> ImportIR(const std::string& filename) {
  IrParser paser(filename.c_str());
  paser.ParseFile();
  return paser.GetFuncGraphs();
  IrParser parser(filename.c_str());
  parser.ParseFile();
  return parser.GetFuncGraphs();
}

#ifdef ENABLE_DUMP_IR
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "func graph is nullptr";
    MS_LOG(ERROR) << "Func graph is nullptr";
    return;
  }
  auto ms_context = MsContext::GetInstance();

@@ -2242,16 +2243,16 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) {
  }
  std::string file_path = save_graphs_path + "/" + "ms_output_" + suffix + ".pb";
  if (file_path.size() > PATH_MAX) {
    MS_LOG(ERROR) << "file path " << file_path << " is too long.";
    MS_LOG(ERROR) << "File path " << file_path << " is too long.";
    return;
  }
  char real_path[PATH_MAX] = {0};
  if (nullptr == realpath(file_path.c_str(), real_path)) {
    MS_LOG(DEBUG) << "dir " << file_path << " does not exit.";
    MS_LOG(DEBUG) << "Dir " << file_path << " does not exit.";
  } else {
    std::string path_string = real_path;
    if (chmod(common::SafeCStr(path_string), S_IRUSR | S_IWUSR) == -1) {
      MS_LOG(ERROR) << "modify file:" << real_path << " to rw fail.";
      MS_LOG(ERROR) << "Modify file:" << real_path << " to rw fail.";
      return;
    }
  }
@@ -64,17 +64,18 @@ struct ParamPtrHasher {

class AnfExporter {
 public:
  explicit AnfExporter(const std::string& id, bool export_used = true)
      : param_index(-1), id_(id), export_used_(export_used) {
  explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false)
      : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) {
    func_graph_set.clear();
    exported.clear();
  }
  ~AnfExporter() {}
  virtual ~AnfExporter() {}

  void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
  void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs);

 private:
 protected:
  virtual std::string GetNodeType(const AnfNodePtr& nd);
  int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true);
  int GetParamIndexFromExported(const AnfNodePtr& param);
  std::string DumpObject(const py::object& obj, const std::string& category) const;

@@ -101,8 +102,10 @@ class AnfExporter {
  OrderedSet<FuncGraphPtr> func_graph_set{};
  OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>> exported;
  std::string id_;
  bool export_used_ = true;  // whether export function graphs used in current exporting function graph
  bool export_used_ = true;       // whether export function graphs used in current exporting function graph
  bool check_integrity_ = false;  // whether to check integrity; when dumping ir for loading, it must be set to true
  TaggedNodeMap tagged_cnodes_;
  abstract::AnfNodeConfigPtr node_cfg_ = nullptr;
};

void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph);

@@ -115,7 +118,6 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph);
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix);

std::string GetOnnxProtoString(const FuncGraphPtr& func_graph);
std::string GetNodeType(const AnfNodePtr& nd);
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_
@@ -362,7 +362,7 @@ Digraph::~Digraph() {
      fout_.close();
    }
  } catch (const std::exception& e) {
    MS_LOG(ERROR) << "exception when closing file " << filename_;
    MS_LOG(ERROR) << "Exception when closing file " << filename_;
  }
}
@@ -208,7 +208,7 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value
    TypePtr elem_type = dyn_cast<TensorType>(val)->element();
    type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type));
  } else {
    MS_LOG(WARNING) << "Not supported type " << val->type_name();
    MS_LOG(WARNING) << "Unsupported type " << val->type_name();
  }
}
@@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) {
  return false;
}

bool Dump::ParseDumpConfig(const string& dump_config_file) {
bool Dump::ParseDumpConfig(const std::string& dump_config_file) {
  std::ifstream jsonFile(dump_config_file);
  if (!jsonFile.is_open()) {
    MS_LOG(ERROR) << dump_config_file << " open failed.";

@@ -101,7 +101,7 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) {
  auto kernels = dumpSettings.at("kernels");
  if (!(enable.is_boolean() && trans_flag.is_boolean() && mode.is_number() && path.is_string() &&
        net_name.is_string() && iteration.is_number() && kernels.is_array())) {
    MS_LOG(ERROR) << "element's type in Dump config json is invalid.";
    MS_LOG(ERROR) << "Element's type in Dump config json is invalid.";
    dump_enable_ = false;
    return false;
  }

@@ -121,7 +121,7 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) {
bool Dump::SetDumpConfFromJsonFile() {
  const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
  if (config_path_str != nullptr) {
    MS_LOG(INFO) << "getenv MINDSPORE_CONFIG_PATH :" << config_path_str;
    MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str;
  } else {
    MS_LOG(INFO) << "No need E2E Dump. please export MINDSPORE_CONFIG_PATH eg: MINDSPORE_CONFIG_PATH=/etc";
    dump_enable_ = false;

@@ -132,7 +132,7 @@ bool Dump::SetDumpConfFromJsonFile() {
  auto id = context_ptr->device_id();
  char real_path[PATH_MAX] = {0};
  if (nullptr == realpath(config_path_str, real_path)) {
    MS_LOG(ERROR) << "env e2e dump path error, " << config_path_str;
    MS_LOG(ERROR) << "Env e2e dump path error, " << config_path_str;
    dump_enable_ = false;
    return false;
  }

@@ -150,20 +150,20 @@ bool Dump::SetDumpConfFromJsonFile() {

bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) {
  if (filename.empty() || data == nullptr || len == 0) {
    MS_LOG(ERROR) << "incorrect parameter.";
    MS_LOG(ERROR) << "Incorrect parameter.";
    return false;
  }

  std::string realpath;
  bool ret = GetRealPath(filename, &realpath);
  if (!ret) {
    MS_LOG(ERROR) << "get real path failed.";
    MS_LOG(ERROR) << "Get real path failed.";
    return false;
  }
  std::ofstream fd;
  fd.open(realpath, std::ios::binary | std::ios::out);
  if (!fd.is_open()) {
    MS_LOG(ERROR) << "open file " << realpath << " fail.";
    MS_LOG(ERROR) << "Open file " << realpath << " fail.";
    return false;
  }
  (void)fd.write(reinterpret_cast<const char*>(data), SizeToLong(len));

@@ -182,7 +182,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) {
  if (path_split_pos != std::string::npos) {
    std::string prefix_path = inpath.substr(0, path_split_pos);
    if (prefix_path.length() >= PATH_MAX) {
      MS_LOG(ERROR) << "prefix path is too longer!";
      MS_LOG(ERROR) << "Prefix path is too longer!";
      return false;
    }
    std::string last_path = inpath.substr(path_split_pos, inpath.length() - path_split_pos);

@@ -201,11 +201,11 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) {

  if (path_split_pos == std::string::npos) {
    if (inpath.length() >= PATH_MAX) {
      MS_LOG(ERROR) << "prefix path is too longer!";
      MS_LOG(ERROR) << "Prefix path is too longer!";
      return false;
    }
    if (nullptr == realpath(inpath.c_str(), real_path)) {
      MS_LOG(ERROR) << "file " << inpath << " does not exit, it will be created.";
      MS_LOG(ERROR) << "File " << inpath << " does not exit, it will be created.";
    }
    *outpath = std::string(real_path);
  }

@@ -218,7 +218,7 @@ bool Dump::CreateNotExistDirs(const std::string& path) {
  MS_EXCEPTION_IF_NULL(fs);
  char temp_path[PATH_MAX] = {0};
  if (path.length() > PATH_MAX) {
    MS_LOG(ERROR) << "path lens is max than " << PATH_MAX;
    MS_LOG(ERROR) << "Path lens is max than " << PATH_MAX;
    return false;
  }
  for (uint32_t i = 0; i < path.length(); i++) {

@@ -229,7 +229,7 @@ bool Dump::CreateNotExistDirs(const std::string& path) {
      temp_path[i] = '\0';
      std::string path_handle(temp_path);
      if (!fs->FileExist(temp_path)) {
        MS_LOG(INFO) << "dir " << path_handle << " does not exit, creating...";
        MS_LOG(INFO) << "Dir " << path_handle << " does not exit, creating...";
        if (!fs->CreateDir(temp_path)) {
          MS_LOG(ERROR) << "Create " << path_handle << " dir error";
          return false;

@@ -241,7 +241,7 @@ bool Dump::CreateNotExistDirs(const std::string& path) {
  }

  if (!fs->FileExist(path)) {
    MS_LOG(INFO) << "dir " << path << " does not exit, creating...";
    MS_LOG(INFO) << "Dir " << path << " does not exit, creating...";
    if (!fs->CreateDir(path)) {
      MS_LOG(ERROR) << "Create " << path << " dir error";
      return false;
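CreateNotExistDirs above walks the path and creates each missing component in turn. For comparison, a Python sketch of the equivalent behavior using the standard library (a sketch, not the project's API):

    import os

    def create_not_exist_dirs(path):
        # Create every missing directory component of `path`, like the loop in
        # Dump::CreateNotExistDirs; exist_ok makes re-creation a no-op.
        os.makedirs(path, exist_ok=True)
        return os.path.isdir(path)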
@@ -193,7 +193,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) {
  }
  TraceContextPtr context = std::make_shared<TraceContext>(trace_info);
  if (trace_info->debug_info() == nullptr) {
    MS_LOG(EXCEPTION) << "trace debug info is null";
    MS_LOG(EXCEPTION) << "Trace debug info is null";
  }
  TraceManager::trace_context_stack_.push(context);
}

@@ -205,7 +205,7 @@ void TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr
  auto cloned_info = trace_info->clone();
  cloned_info->set_debug_info(debug_info);
  if (cloned_info->debug_info() == nullptr) {
    MS_LOG(EXCEPTION) << "trace debug info is null with cloned trace";
    MS_LOG(EXCEPTION) << "Trace debug info is null with cloned trace";
  }
  TraceContextPtr context = std::make_shared<TraceContext>(cloned_info);
  TraceManager::trace_context_stack_.push(context);
@@ -17,6 +17,7 @@
#include "debug/trace.h"

#include <iostream>
#include <fstream>
#include <map>
#include <unordered_map>
#include <vector>

@@ -88,7 +89,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
  return "";
}

// a trace info identifys a node transform, so we can trace the node transform through
// a trace info identifies a node transform, so we can trace the node transform through
// a link of trace info and debug info
std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceLineTip tip) {
  if (info_vec.size() < 1) {

@@ -172,7 +173,7 @@ void DumpInferStack(std::ostringstream& oss) {
    }
    auto graph_context = graph_infer->graph_context();
    if (graph_context == nullptr) {
      MS_LOG(INFO) << "null context continue";
      MS_LOG(INFO) << "Null context continue";
      continue;
    }
    auto graph = graph_context->func_graph();
@@ -194,37 +195,116 @@ void TraceGraphInfer() {
  MS_LOG(INFO) << "\n*************************************************************************************";
}

void OutputAnalysisGraphInfo() {
  MS_LOG(INFO) << "Output analysis graph begin";
  std::unordered_map<FuncGraphPtr, size_t> index_map;
  std::vector<TaggedGraph> tagged_graphs;
class AnalyzedFuncGraphExporter : public AnfExporter {
 public:
  AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {}
  ~AnalyzedFuncGraphExporter() override = default;

  void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs);

 private:
  std::string GetNodeType(const AnfNodePtr& nd) override;
};

std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() {
  std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
  auto& list = GetCNodeDebugStack();
  for (size_t i = 0; i < list.size(); ++i) {
    auto& node_cfg = list[i];
    auto node_cfg = list[i];
    auto fg = node_cfg->context()->func_graph();
    auto node = node_cfg->node();
    auto idx = tagged_graphs.size();
    std::pair<FuncGraphPtr, size_t> item(fg, idx);
    if (index_map.insert(item).second) {
      tagged_graphs.emplace_back(TaggedGraph(fg, TaggedNodeMap()));
    }
    tagged_graphs[index_map[fg]].second[node] = i;
    tagged_func_graphs[fg][node] = i;
  }
  return tagged_func_graphs;
}
void OutputAnalyzedGraphWithType() {
|
||||
AnalyzedFuncGraphExporter exporter;
|
||||
exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack());
|
||||
}
|
||||
|
||||
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
|
||||
if (node_cfg_ == nullptr) {
|
||||
return AnfExporter::GetNodeType(node);
|
||||
}
|
||||
auto ctx = node_cfg_->context();
|
||||
auto engine = node_cfg_->engine();
|
||||
auto cfg = engine->MakeConfig(node, ctx);
|
||||
auto abs = engine->cache().GetValue(cfg);
|
||||
|
||||
if (abs == nullptr) {
|
||||
return "Undefined";
|
||||
}
|
||||
auto dtype = abs->BuildType();
|
||||
auto shape = abs->BuildShape();
|
||||
std::ostringstream oss;
|
||||
if (dtype != nullptr && abs->isa<abstract::AbstractTensor>() && shape != nullptr) {
|
||||
oss << dtype->DumpText() << shape->DumpText();
|
||||
} else if (dtype != nullptr) {
|
||||
oss << dtype->DumpText();
|
||||
} else {
|
||||
oss << "Undefined";
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
|
||||
const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) {
|
||||
if (node_cfgs.empty()) {
|
||||
MS_LOG(DEBUG) << "Node configs is empty";
|
||||
return;
|
||||
}
|
||||
|
||||
ExportIR("analyze_fail.dat", tagged_graphs);
|
||||
MS_LOG(INFO) << "Output analysis graph *end*";
|
||||
std::ofstream ofs(filename);
|
||||
if (!ofs.is_open()) {
|
||||
MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
|
||||
return;
|
||||
}
|
||||
|
||||
param_index = 1;
|
||||
auto tagged_func_graphs = CalcTaggedFuncGraphs();
|
||||
|
||||
// first output graph on the analysis stack
|
||||
for (const auto& node_cfg : node_cfgs) {
|
||||
auto fg = node_cfg->context()->func_graph();
|
||||
// the graph is already output, skip it
|
||||
if (exported.find(fg) != exported.end()) {
|
||||
continue;
|
||||
}
|
||||
// set node_cfg info for getting type
|
||||
node_cfg_ = node_cfg;
|
||||
tagged_cnodes_ = tagged_func_graphs[fg];
|
||||
ExportOneFuncGraph(ofs, fg);
|
||||
ofs << "\n\n";
|
||||
}
|
||||
|
||||
node_cfg_ = nullptr;
|
||||
tagged_cnodes_.clear();
|
||||
|
||||
// print seperator between function graphs on analyzed graph call stack and others
|
||||
ofs << "#===============================================================================\n\n\n";
|
||||
|
||||
// second output other graphs
|
||||
while (!func_graph_set.empty()) {
|
||||
FuncGraphPtr fg = *func_graph_set.begin();
|
||||
ExportOneFuncGraph(ofs, fg);
|
||||
ofs << "\n\n";
|
||||
(void)func_graph_set.erase(fg);
|
||||
}
|
||||
ofs << "# num of total function graphs: " << exported.size();
|
||||
|
||||
ofs.close();
|
||||
}
|
||||
|
||||
void GetInferStackInfo(std::ostringstream& oss) {
|
||||
MS_LOG(INFO) << "Get graph analysis information begin";
|
||||
auto& stack = GetCNodeDebugStack();
|
||||
auto stack = GetCNodeDebugStack();
|
||||
if (stack.empty()) {
|
||||
MS_LOG(INFO) << "Length of analysis information stack is empty.";
|
||||
return;
|
||||
}
|
||||
|
||||
OutputAnalysisGraphInfo();
|
||||
OutputAnalyzedGraphWithType();
|
||||
oss << "\nThe function call stack:\n";
|
||||
|
||||
int index = 0;
|
||||
|
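The restructured exporter above replaces the old index_map/tagged_graphs bookkeeping with a direct two-level map from each function graph to the stack positions of its nodes. A minimal sketch of that tagging idea, with placeholder Graph/Node types standing in for FuncGraphPtr/AnfNodePtr:

```cpp
// Hedged sketch of CalcTaggedFuncGraphs' grouping idea; Graph/Node are
// illustrative stand-ins, not the real MindSpore ANF types.
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

struct Node {};
struct Graph {};

// Maps: graph -> (node -> index on the failed-analysis stack).
using TaggedNodeMap = std::unordered_map<const Node*, size_t>;

std::unordered_map<const Graph*, TaggedNodeMap> CalcTagged(
    const std::vector<std::pair<const Graph*, const Node*>>& stack) {
  std::unordered_map<const Graph*, TaggedNodeMap> tagged;
  for (size_t i = 0; i < stack.size(); ++i) {
    // Group each stack frame under its owning graph, keyed by stack index;
    // a node seen twice keeps its latest position.
    tagged[stack[i].first][stack[i].second] = i;
  }
  return tagged;
}
```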
@@ -252,7 +332,7 @@ void GetInferStackInfo(std::ostringstream& oss) {
   MS_LOG(INFO) << "Get graph analysis information *end*";
 }
 
-// trace the graph evaluator statck
+// trace the graph evaluator stack
 static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
 // trace the cnode infer debug info
 static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};

@@ -36,6 +36,6 @@ std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) {
   } else if (debug_info()->trace_info() != nullptr) {
     return act_name + debug_info()->trace_info()->GetActionBetweenNode(info);
   }
-  return "not in the traced info";
+  return "Not in the traced info";
 }
 } // namespace mindspore
@@ -106,13 +106,13 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
     } else {
       auto shape_size = trans::ShapeSize(host_shape);
       auto host = std::vector<uint8_t>(size_);
-      const trans::TypeIdArgs type_args{ptr_, shape_size, type_id_, type};
-      sync_ok = trans::TransDataType(type_args, host.data());
+      SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
+      const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type};
+      sync_ok = trans::TransDataType(type_args, host_ptr);
       if (!sync_ok) {
         MS_LOG(ERROR) << "trans data type failed.";
         return false;
       }
-      SyncMemory(host_ptr, host.data(), size, RT_MEMCPY_DEVICE_TO_HOST);
     }
   } else if (format_ == kOpFormat_NC1HWC0 || format_ == kOpFormat_FRAC_Z || format_ == kOpFormat_FRAC_NZ) {
     sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr);
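The reordered SyncDeviceToHost path now copies device memory into a host staging buffer first and only then runs the dtype conversion, instead of converting from a pointer whose bytes had not yet landed on the host. A hedged sketch of the corrected order, with placeholder callbacks standing in for SyncMemory and trans::TransDataType:

```cpp
// Sketch of the copy-then-convert ordering; copy_d2h/convert_dtype are
// assumptions standing in for the real runtime calls.
#include <cstdint>
#include <vector>

bool SyncDeviceToHostSketch(const uint8_t* device_ptr, size_t device_size, void* host_ptr,
                            bool (*copy_d2h)(void* dst, const void* src, size_t n),
                            bool (*convert_dtype)(const void* src, void* dst, size_t n)) {
  std::vector<uint8_t> staging(device_size);
  // 1) Device -> host staging buffer; the device data must reach host memory
  //    before any CPU-side type conversion can read it.
  if (!copy_d2h(staging.data(), device_ptr, device_size)) {
    return false;
  }
  // 2) Convert the staged bytes into the caller's output buffer.
  return convert_dtype(staging.data(), host_ptr, device_size);
}
```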
@@ -150,9 +150,9 @@ void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path,
     auto output_size = AnfAlgo::GetOutputTensorNum(node);
     for (size_t j = 0; j < output_size; ++j) {
       auto addr = AnfAlgo::GetOutputAddr(node, j);
-      auto shape = AnfAlgo::GetOutputDeviceShape(node, j);
-      auto type = AnfAlgo::GetOutputDeviceDataType(node, j);
-      auto format = AnfAlgo::GetOutputFormat(node, j);
+      auto shape = AnfAlgo::GetOutputInferShape(node, j);
+      auto type = AnfAlgo::GetOutputInferDataType(node, j);
+      auto format = kOpFormat_DEFAULT;
       string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j);
       auto ascend_addr = dynamic_cast<const mindspore::device::ascend::AscendDeviceAddress *>(addr);
       std::vector<int> int_shapes;

@@ -181,9 +181,9 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p
       continue;
     }
     auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX);
-    auto shape = AnfAlgo::GetOutputDeviceShape(item, PRAMATER_OUTPUT_INDEX);
-    auto type = AnfAlgo::GetOutputDeviceDataType(item, PRAMATER_OUTPUT_INDEX);
-    auto format = AnfAlgo::GetOutputFormat(item, PRAMATER_OUTPUT_INDEX);
+    auto shape = AnfAlgo::GetOutputInferShape(item, PRAMATER_OUTPUT_INDEX);
+    auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX);
+    auto format = kOpFormat_DEFAULT;
     string filepath = dump_path + '/' + parameter_name + '_' + "output_0";
     auto ascend_addr = dynamic_cast<const mindspore::device::ascend::AscendDeviceAddress *>(addr);
     std::vector<int> int_shapes;
@@ -239,22 +239,11 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
   return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
 }
 
-void AscendKernelRuntime::MallocOpMemory(const DeviceAddressPtr address, size_t size, int flag) {
-  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
-  if (MsContext::GetInstance()->enable_dynamic_mem_pool()) {
-    auto device_ptr = AscendMemoryAllocator::GetInstance().AllocTensorMem(size);
-    MS_EXCEPTION_IF_NULL(device_ptr);
-    address->ptr_ = device_ptr;
-    address->mem_dynamic_alloc_ = true;
-    return;
-  }
-  if (flag == kStaticMem) {
-    address->ptr_ = MallocStaticMem(size, false);
-  } else if (flag == kDynamicMem) {
-    address->ptr_ = MallocDynamicMem(size, false);
-  } else {
-    MS_LOG(EXCEPTION) << "Unknown memory type!";
-  }
+void AscendKernelRuntime::MallocOpMemory(const DeviceAddressPtr address, size_t size, int) {
+  auto device_ptr = AscendMemoryAllocator::GetInstance().AllocTensorMem(size);
+  MS_EXCEPTION_IF_NULL(device_ptr);
+  address->ptr_ = device_ptr;
+  address->mem_dynamic_alloc_ = true;
 }
 
 bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
@@ -488,23 +477,18 @@ bool AscendKernelRuntime::DestroyHccl() {
 
 bool AscendKernelRuntime::MallocDeviceMemory() {
   device_mem_size_ = ASCEND_MEM_SIZE_BYTE;
-  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
-  if (MsContext::GetInstance()->enable_dynamic_mem_pool()) {
-    static_mem_offset_ = FloatToSize(device_mem_size_ * GRAPH_INIT_DAVINCI_MEM_RATIO);
-    device_mem_pool_size_ = FloatToSize(device_mem_size_ * (1 - GRAPH_INIT_DAVINCI_MEM_RATIO));
-    auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM);
-    if (ret != RT_ERROR_NONE) {
-      MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]";
-    }
-    AscendMemoryAllocator::GetInstance().set_device_mem_pool_base(device_mem_pool_base_);
-    AscendMemoryAllocator::GetInstance().set_device_mem_pool_size(device_mem_pool_size_);
-  } else {
-    static_mem_offset_ = device_mem_size_;
-  }
-  auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), device_mem_size_, RT_MEMORY_HBM);
+  static_mem_offset_ = FloatToSize(device_mem_size_ * GRAPH_INIT_ASCEND_MEM_RATIO);
+  auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), static_mem_offset_, RT_MEMORY_HBM);
   if (ret != RT_ERROR_NONE) {
-    MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]";
+    MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << static_mem_offset_ << "] fail, ret[" << ret << "]";
   }
+  device_mem_pool_size_ = FloatToSize(device_mem_size_ * (1 - GRAPH_INIT_ASCEND_MEM_RATIO));
+  ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]";
+  }
+  AscendMemoryAllocator::GetInstance().set_device_mem_pool_base(device_mem_pool_base_);
+  AscendMemoryAllocator::GetInstance().set_device_mem_pool_size(device_mem_pool_size_);
   return true;
 }
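With the dynamic-pool branch gone, device HBM is now always split by GRAPH_INIT_ASCEND_MEM_RATIO: 80% backs graph memory and the remaining 20% backs the best-fit dynamic pool, each from its own rtMalloc'd region. A small sketch of that arithmetic, with an illustrative device size:

```cpp
// Hedged sketch of the 80/20 split now applied unconditionally; the constant
// mirrors GRAPH_INIT_ASCEND_MEM_RATIO, the 20 GiB figure is an assumption.
#include <cstdint>
#include <cstdio>

int main() {
  const float kGraphMemRatio = 0.8f;             // GRAPH_INIT_ASCEND_MEM_RATIO
  const uint64_t device_mem_size = 20ULL << 30;  // e.g. 20 GiB of HBM
  const auto static_mem = static_cast<size_t>(device_mem_size * kGraphMemRatio);
  const auto pool_mem = static_cast<size_t>(device_mem_size * (1 - kGraphMemRatio));
  // static_mem backs graph planning; pool_mem backs the best-fit allocator.
  std::printf("static: %zu bytes, pool: %zu bytes\n", static_mem, pool_mem);
  return 0;
}
```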
@@ -26,7 +26,7 @@ const uint64_t MEM_SIZE_BYTE = (MEM_SIZE << 30);
 
 AscendMemoryAllocator::AscendMemoryAllocator() {
   hasMalloc_ = false;
-  free_mem_size_ = FloatToSize(MEM_SIZE_BYTE * (1 - GRAPH_INIT_DAVINCI_MEM_RATIO));
+  free_mem_size_ = FloatToSize(MEM_SIZE_BYTE * (1 - GRAPH_INIT_ASCEND_MEM_RATIO));
   total_mem_size_ = free_mem_size_;
 }

@@ -24,7 +24,7 @@ namespace mindspore {
 namespace device {
 namespace ascend {
 // The fraction of total ascend memory used to compute the graph.
-static const float GRAPH_INIT_DAVINCI_MEM_RATIO = 0.8;
+static const float GRAPH_INIT_ASCEND_MEM_RATIO = 0.8;
 
 class AscendMemoryAllocator : public DynamicMemPoolBestFit {
  public:
@@ -94,7 +94,7 @@ static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *ke
   return ret;
 }
 
-static vector<int> CalCleanZerosSize(const CNodePtr &pre_node) {
+static std::vector<int> CalCleanZerosSize(const CNodePtr &pre_node) {
   MS_EXCEPTION_IF_NULL(pre_node);
   std::vector<int> clean_size_list;
   // clean output

@@ -35,6 +35,7 @@ enum MatchCountPriority : int {
   MATCH_COUNT_PRIORITY_BEGIN = 0,
   MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
   MATCH_FORMAT_COUNT,
+  MATCH_SPECIAL_FORMAT_COUNT,
   MATCH_5D_FORMAT_COUNT,
   MATCH_OUTPUT_DTYPE_COUNT,
   MATCH_COUNT_PRIORITY_END
@@ -174,12 +175,11 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
         continue;
       }
     }
-    if (input_anf_node->isa<ValueNode>()) {
-      if (AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
-        continue;
-      }
-    }
     if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
+      if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index) &&
+          kSpecialFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kSpecialFormatSet.end()) {
+        (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT]++;
+      }
       (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
     }
     if (kernel_build_info.GetInputDeviceType(input_index) ==

@@ -203,7 +203,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
       (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
     }
   }
 }
 } // namespace
 
 void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
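Kernel selection scores each candidate build info into a count vector indexed by MatchCountPriority; the new MATCH_SPECIAL_FORMAT_COUNT slot rewards special-format matches on feature-map inputs. A sketch of the vector idea — the ranking rule shown here is illustrative, not the real selector's:

```cpp
// Hedged sketch: candidates accumulate per-priority match counts; one way to
// rank them is lexicographically from the highest-priority slot down.
#include <algorithm>
#include <array>

enum MatchCountPriority : int {
  MATCH_COUNT_PRIORITY_BEGIN = 0,
  MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
  MATCH_FORMAT_COUNT,
  MATCH_SPECIAL_FORMAT_COUNT,
  MATCH_5D_FORMAT_COUNT,
  MATCH_OUTPUT_DTYPE_COUNT,
  MATCH_COUNT_PRIORITY_END
};

using MatchCounts = std::array<int, MATCH_COUNT_PRIORITY_END>;

// True when candidate a beats candidate b, comparing the last (highest)
// priority slot first; an assumption about the tie-break order.
bool Better(const MatchCounts& a, const MatchCounts& b) {
  return std::lexicographical_compare(b.rbegin(), b.rend(), a.rbegin(), a.rend());
}
```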
@@ -27,6 +27,7 @@
 #include "utils/log_adapter.h"
 #include "utils/context/ms_context.h"
 #include "common/utils.h"
+#include "utils/convert_utils.h"
 
 using std::vector;
 using Json = nlohmann::json;

@@ -121,8 +121,8 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
     LaunchAddrCleanKernel(anf_node_ptr, &kernel_inputs);
   }
 
-  std::vector<TaskInfoPtr> task_info_ptrs =
-    kernel_mod->GenTask(kernel_inputs, kernel_workspaces, kernel_outputs, stream_id);
+  std::vector<TaskInfoPtr> task_info_ptrs = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod)
+                                              ->GenTask(kernel_inputs, kernel_workspaces, kernel_outputs, stream_id);
   task_info_list->insert(task_info_list->end(), task_info_ptrs.begin(), task_info_ptrs.end());
   return true;
 }
@@ -22,12 +22,9 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
-#include "cce/aicpu_engine_struct.h"
-#include "cce/taskdown_api.h"
-#include "cce/fwk_adpt_struct.h"
-#include "device/kernel_runtime.h"
 #include "ir/anf.h"
 #include "kernel/kernel.h"
+#include "kernel/ascend_kernel_mod.h"
 #include "framework/ge_runtime/task_info.h"
 
 namespace mindspore {
@@ -66,7 +66,7 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
       address->ptr_ = resource_manager_.MemMalloc(tensor_size);
       if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                      tensor->data_c(false))) {
-        MS_LOG(EXCEPTION) << "value node sync host to device failed!";
+        MS_LOG(EXCEPTION) << "Value node sync host to device failed!";
       }
     }
     address->ref_count_ = INIT_NODE_REF;

@@ -141,7 +141,7 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const AnfNodePtr &input_node, siz
   MS_EXCEPTION_IF_NULL(node);
   size_t output_size = AnfAlgo::GetOutputTensorNum(node);
   if (index >= output_size) {
-    MS_LOG(EXCEPTION) << "invalid input index " << index;
+    MS_LOG(EXCEPTION) << "Invalid input index " << index;
   }
   auto address = AnfAlgo::GetMutableOutputAddr(node, index);
   MS_EXCEPTION_IF_NULL(address);

@@ -157,7 +157,7 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const AnfNodePtr &input_node, siz
     type_id = kNumberTypeFloat32;
   }
   if (type_id != kNumberTypeInt32 && type_id != kNumberTypeFloat32) {
-    MS_LOG(EXCEPTION) << "check output type failed.";
+    MS_LOG(EXCEPTION) << "Check output type failed.";
   }
   tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
   MS_EXCEPTION_IF_NULL(tensor);

@@ -181,7 +181,7 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
   // bind input ptr
   auto &input_nodes = kernel_graph->inputs();
   if (input_nodes.size() != inputs.size()) {
-    MS_LOG(EXCEPTION) << "input size not equal to input node size!";
+    MS_LOG(EXCEPTION) << "Input size not equal to input node size!";
   }
 
   std::unordered_map<AnfNode *, tensor::TensorPtr> input_map;

@@ -203,7 +203,7 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
       address->ptr_ = resource_manager_.MemMalloc(tensor_size);
       if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                      tensor->data_c(false))) {
-        MS_LOG(EXCEPTION) << "parameter node sync host to device failed!";
+        MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!";
       }
       tensor->set_dirty(true);
     }
@@ -44,7 +44,7 @@ void CPUResourceManager::MemPlan(const session::KernelGraph *graph) {
     mem_size_ = graph_mem_size;
     dynamic_malloc_ = false;
   } else {
-    MS_LOG(INFO) << "switch to dynamic malloc";
+    MS_LOG(INFO) << "Switch to dynamic malloc";
     dynamic_malloc_ = true;
   }
 }

@@ -63,7 +63,7 @@ void *CPUResourceManager::MemMalloc(size_t mem_size) {
     dynamic_mem_[ptr] = mem_size;
     return ptr;
   } else {
-    MS_LOG(EXCEPTION) << "malloc memory failed: size " << mem_size;
+    MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size;
   }
 }
@@ -31,12 +31,12 @@ GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
   auto graph_id = graph_sum_;
   auto graph = ConstructKernelGraph(lst, outputs);
   MS_EXCEPTION_IF_NULL(graph);
-  MS_LOG(INFO) << "set kernel info";
+  MS_LOG(INFO) << "Set kernel info";
   SetKernelInfo(graph.get());
   predictmodel::StepConvertGraph(graph);
-  MS_LOG(INFO) << "build kernel";
+  MS_LOG(INFO) << "Build kernel";
   BuildKernel(graph.get());
-  MS_LOG(INFO) << "assign kernel address";
+  MS_LOG(INFO) << "Assign kernel address";
   runtime_.AssignKernelAddress(graph.get());
   return graph_id;
 }

@@ -44,18 +44,18 @@ GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
 void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
   auto &kernel_graph = graphs_[graph_id];
   MS_EXCEPTION_IF_NULL(kernel_graph);
-  MS_LOG(INFO) << "bind input output address";
+  MS_LOG(INFO) << "Bind input output address";
   runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
-  MS_LOG(INFO) << "run graph start";
+  MS_LOG(INFO) << "Run graph start";
   predictmodel::StepConvertWeight(inputs);
   auto execution_order = kernel_graph->execution_order();
   Reorder(&execution_order);
   kernel_graph->set_execution_order(execution_order);
   bool ret = runtime_.Run(kernel_graph.get());
   if (!ret) {
-    MS_LOG(EXCEPTION) << "run graph failed";
+    MS_LOG(EXCEPTION) << "Run graph failed";
   }
-  MS_LOG(INFO) << "run graph end";
+  MS_LOG(INFO) << "Run graph end";
 }
 
 void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) {
@@ -21,7 +21,6 @@
 #include "kernel/gpu/gpu_kernel_factory.h"
 #include "operator/ops.h"
 #include "pybind11/stl.h"
-#include "transform/convert.h"
 #include "session/anf_runtime_algorithm.h"
 namespace mindspore {
 namespace device {

@@ -91,7 +91,7 @@ std::string SupportedTypeList(const CNodePtr& kernel_node) {
   return supported_type_lists;
 }
 
-bool SelectAkgKernel(const CNodePtr& kernel_node, const shared_ptr<KernelBuildInfo>& selected_kernel_info) {
+bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   MS_EXCEPTION_IF_NULL(selected_kernel_info);
   std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
@@ -32,6 +32,7 @@
 #include "device/ascend/profiling/profiling_manager.h"
 #include "device/ascend/kernel_select_ascend.h"
 #include "device/kernel_info.h"
+#include "runtime/base.h"
 
 constexpr auto kLoopCountParamName = "loop_count";
 constexpr auto kIterLoopParamName = "iter_loop";

@@ -49,7 +50,7 @@ void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_g
   std::vector<CNodePtr> momentum_list;
   std::vector<CNodePtr> other_list;
   for (const auto &cnode : origin_cnode_list) {
-    if (kOptOpeatorSet.find(AnfAlgo::GetCNodeName(cnode)) != kOptOpeatorSet.end()) {
+    if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(cnode)) != kOptOperatorSet.end()) {
      momentum_list.emplace_back(cnode);
     } else {
       other_list.emplace_back(cnode);

@@ -118,7 +119,7 @@ void KernelAdjust::CreateSwitchOpParameters(const std::shared_ptr<session::Kerne
   MS_EXCEPTION_IF_NULL(tensor_ptr);
   mindspore::abstract::AbstractBasePtr paremeter_abstract_ptr = tensor_ptr->ToAbstract();
   if (paremeter_abstract_ptr == nullptr) {
-    MS_LOG(EXCEPTION) << "create abstract brfore insert switch op failed!";
+    MS_LOG(EXCEPTION) << "create abstract before insert switch op failed!";
   }
 
   ParameterPtr loop_count = std::make_shared<Parameter>(kernel_graph_ptr);

@@ -371,7 +372,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &c
     auto tensor = inputs[i];
     size_t deal_index = input_nodes.size() - input_ctrl_size + i;
     if (deal_index >= input_nodes.size()) {
-      MS_LOG(EXCEPTION) << "deak_index[" << deal_index << "] outof range";
+      MS_LOG(EXCEPTION) << "deal_index[" << deal_index << "] out of range";
     }
     auto input_node = input_nodes[deal_index];
     bool need_sync = false;

@@ -439,7 +440,7 @@ void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) {
 
 void KernelAdjust::Profiling(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
   if (!ascend::ProfilingManager::GetInstance().IsProfiling()) {
-    MS_LOG(INFO) << "no need to profiling";
+    MS_LOG(INFO) << "No need to profiling";
     return;
   }
   ProfilingTraceInfo profiling_trace_info;

@@ -452,10 +453,10 @@ void KernelAdjust::Profiling(const std::shared_ptr<session::KernelGraph> &kernel
 
 void KernelAdjust::InsertProfilingKernel(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
                                          const ProfilingTraceInfo &profiling_trace_info) {
-  MS_LOG(INFO) << "[profiling] insert profiling kernel start";
+  MS_LOG(INFO) << "[profiling] Insert profiling kernel start";
   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
   if (!profiling_trace_info.IsValid()) {
-    MS_LOG(WARNING) << "profiling trace point not found";
+    MS_LOG(WARNING) << "Profiling trace point not found";
     return;
   }
   std::vector<CNodePtr> new_cnode_list;
@@ -42,7 +42,6 @@ KernelRuntime::~KernelRuntime() {
 #ifdef ENABLE_DUMP_E2E
   dump_conf_ptr_ = nullptr;
 #endif
-  reuse_mem_base_ = nullptr;
   mem_reuse_util_ptr_ = nullptr;
 }

@@ -50,12 +49,19 @@ bool KernelRuntime::Run(session::KernelGraph *graph) {
   bool ret = false;
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
+  struct timeval start_time, end_time;
+  (void)gettimeofday(&start_time, nullptr);
   bool is_task_sink = context_ptr->enable_task_sink();
   if (is_task_sink) {
     ret = RunTask(graph);
   } else {
     ret = LaunchKernel(graph);
   }
+  (void)gettimeofday(&end_time, nullptr);
+  const uint64_t kUSecondInSecond = 1000000;
+  uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
+  cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
+  MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
   return ret;
 }
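The gettimeofday pair added around Run() reports wall-clock time for a whole graph execution in microseconds, combining the seconds and microseconds deltas into one count. The same arithmetic as a standalone sketch:

```cpp
// Minimal sketch of the elapsed-time measurement added above (POSIX only).
#include <sys/time.h>
#include <cstdint>
#include <cstdio>

int main() {
  struct timeval start_time, end_time;
  (void)gettimeofday(&start_time, nullptr);
  // ... work being timed ...
  (void)gettimeofday(&end_time, nullptr);
  const uint64_t kUSecondInSecond = 1000000;
  // Seconds delta scaled to microseconds, then the microseconds delta added;
  // unsigned wraparound makes the sum correct even when tv_usec goes backwards.
  uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  std::printf("elapsed: %llu us\n", static_cast<unsigned long long>(cost));
  return 0;
}
```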
@@ -235,7 +241,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
     auto output_size = AnfAlgo::GetOutputTensorNum(item);
     for (size_t index = 0; index < output_size; index++) {
       TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
-      // if graph output is a weight and doesn't link to any cnode,it's data type will be unkonwn
+      // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown
       if (output_type_id == kTypeUnknown) {
         MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
         output_type_id = AnfAlgo::GetOutputInferDataType(item, index);

@@ -366,7 +372,7 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
       continue;
     }
     if (AnfAlgo::OutputAddrExist(node, i)) {
-      MS_LOG(INFO) << "already malloc index:" << i;
+      MS_LOG(INFO) << "Already malloc index:" << i;
       continue;
     }
     auto ptr = CalDeviceMem(node, output_sizes[i], flag, i);

@@ -386,7 +392,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
   MS_EXCEPTION_IF_NULL(node_value);
   auto tensor = node_value->cast<TensorPtr>();
   if (tensor == nullptr) {
-    MS_LOG(WARNING) << "tensor is null";
+    MS_LOG(WARNING) << "Tensor is null";
     return;
   }
   size_t tensor_size = tensor->data().nbytes();
@@ -469,9 +475,9 @@ void KernelRuntime::ReuseAssignDynamicMemory(session::KernelGraph *graph) {
   bestfit_mem_reuse->Reuse(mem_reuse_util_ptr.get());
   size_t total_allocated_size = bestfit_mem_reuse->GetAllocatedSize();
   MS_LOG(INFO) << "TotalReuseDynamicSize [" << total_allocated_size << "]";
-  auto base_ptr = MallocDynamicMem(total_allocated_size, false);
-  reuse_mem_base_ = base_ptr;
   mem_reuse_util_ptr_ = mem_reuse_util_ptr;
+  auto base_ptr = MallocDynamicMem(total_allocated_size, false);
+  mem_reuse_util_ptr_->set_mem_base(base_ptr);
   auto &kernels = graph->execution_order();
   for (auto &kernel : kernels) {
     AssignNodeOutputMem(kReuseDynamicMem, kernel, kGetAllOuts);

@@ -481,22 +487,13 @@ void KernelRuntime::ReuseAssignDynamicMemory(session::KernelGraph *graph) {
 
 void KernelRuntime::AssignReuseWorkSpaceMem(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
-  auto key = node.get();
   auto kernel_mod = AnfAlgo::GetKernelMod(node);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   size_t index = 0;
-  auto iter = mem_reuse_util_ptr_->kernel_workspace_refs_.find(key);
   for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
-    if (iter != mem_reuse_util_ptr_->kernel_workspace_refs_.end()) {
-      if (index >= iter->second.size()) {
-        MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size()
-                          << "]";
-      }
-      auto wk_ref = iter->second[index];
-      auto wk_ptr = reuse_mem_base_ + wk_ref->offset_;
-      AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(wk_ptr, size, "", kTypeUnknown), index, node.get());
-      index++;
-    }
+    auto wk_ptr = mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index);
+    AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(wk_ptr, size, "", kTypeUnknown), index, node.get());
+    index++;
   }
 }

@@ -547,18 +544,7 @@ uint8_t *KernelRuntime::CalDeviceMem(const AnfNodePtr &node, size_t size, int fl
   } else if (flag == kDynamicMem) {
     ptr = MallocDynamicMem(size, false);
   } else if (flag == kReuseDynamicMem) {
-    auto key = node.get();
-    auto iter = mem_reuse_util_ptr_->kernel_output_refs_.find(key);
-    if (iter != mem_reuse_util_ptr_->kernel_output_refs_.end()) {
-      // private member form KernelRuntime
-      memreuse::KernelRefCountPtr kernel_ref_count_ptr = mem_reuse_util_ptr_->kernel_output_refs_[key][index];
-      if (kernel_ref_count_ptr == nullptr) {
-        return ptr;
-      }
-      ptr = reuse_mem_base_ + kernel_ref_count_ptr->offset_;
-    } else {
-      MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in kernel_output_refs";
-    }
+    ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
   }
   return ptr;
 }
@@ -609,7 +595,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
 
 void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) {
   if (cnode->inputs().size() != 2) {
-    MS_LOG(EXCEPTION) << "atomic Addr clean Node Input nodes not equal 2.";
+    MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2.";
   }
   auto pre_node = cnode->inputs()[1];
   // set clean output address

@@ -735,11 +721,11 @@ uint8_t *KernelRuntime::MallocDynamicMem(size_t size, bool communication_mem) {
 bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   if (!LaunchKernelMod(*graph)) {
-    MS_LOG(ERROR) << "LaunchKernelMod failed.";
+    MS_LOG(ERROR) << "LaunchKernelMod failed!";
     return false;
   }
   if (!SyncStream()) {
-    MS_LOG(ERROR) << "SyncStream failed.";
+    MS_LOG(ERROR) << "SyncStream failed!";
     return false;
   }
   return true;
@@ -128,9 +128,6 @@ class KernelRuntime {
   size_t total_static_size_ = 0;
   size_t total_dynamic_size_ = 0;
   MemReuseUtilPtr mem_reuse_util_ptr_{nullptr};
-
- private:
-  uint8_t *reuse_mem_base_{nullptr};
 };
 using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
 } // namespace device

@@ -67,7 +67,7 @@ KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_
     MS_EXCEPTION_IF_NULL(kernel_runtime);
     runtime_map_[runtime_key] = kernel_runtime;
   } else {
-    MS_LOG(EXCEPTION) << "no kernel runtime creator for " << device_name << " with device id " << device_id;
+    MS_LOG(EXCEPTION) << "No kernel runtime creator for " << device_name << " with device id " << device_id;
   }
 
   return kernel_runtime.get();
@@ -197,6 +197,23 @@ PrimitivePtr GetCNodePrimitive(const AnfNodePtr& node) {
   return nullptr;
 }
 
+std::string GetCNodeFuncName(const CNodePtr cnode) {
+  if (cnode->inputs().empty()) {
+    return "";
+  }
+
+  AnfNodePtr valuenode = cnode->input(0);
+  if (valuenode->isa<ValueNode>()) {
+    auto value = GetValueNode(valuenode);
+    // check whether the valuenode is primitive
+    if (value->isa<Primitive>()) {
+      return value->cast<PrimitivePtr>()->name();
+    }
+    return value->ToString();
+  }
+  return "";
+}
+
 bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) {
   if (IsValueNode<Primitive>(node)) {
     PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node);
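GetCNodeFuncName resolves a CNode's callee name: a Primitive first input reports its registered name, any other value node falls back to its printable form. A toy sketch of that dispatch, where Value and Primitive are stand-ins rather than the real ANF classes:

```cpp
// Hedged sketch of the callee-name dispatch in GetCNodeFuncName.
#include <memory>
#include <string>

struct Value {
  virtual ~Value() = default;
  virtual std::string ToString() const { return "<value>"; }
};
struct Primitive : Value {
  explicit Primitive(std::string n) : name(std::move(n)) {}
  std::string ToString() const override { return name; }
  std::string name;
};

std::string FuncNameOf(const std::shared_ptr<Value>& first_input) {
  if (first_input == nullptr) {
    return "";
  }
  // A primitive callee reports its registered name; anything else falls back
  // to its printable form, mirroring the helper above.
  if (auto prim = std::dynamic_pointer_cast<Primitive>(first_input)) {
    return prim->name;
  }
  return first_input->ToString();
}
```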
@@ -83,7 +83,7 @@ class AnfVisitor;
 // Methods:
 // func_graph: return FuncGraph that this AnfNode belongs to.
 // scope: return the scope namespace of this AnfNode. Set it using set_scope.
-// abstract: return the cached inferred abstract value. It cantains type, shape
+// abstract: return the cached inferred abstract value. It contains type, shape
 // value. Set New cache using set_abstract.
 // intermediate_abstract: return the cached inferring abstract value.
 // Type/Shape: return the related info of this AnfNode. When this AnfNode is an

@@ -284,7 +284,7 @@ class Parameter : public ANode {
 };
 using ParameterPtr = std::shared_ptr<Parameter>;
 
-// Value is used to represent the atomic expression metioned in BNF.
+// Value is used to represent the atomic expression mentioned in BNF.
 // It mainly be stored in ValueNode. Value and ValueNode is related definition.
 class Value : public Base {
  public:

@@ -313,7 +313,7 @@ using ValuePtr = std::shared_ptr<Value>;
 using ValuePtrList = std::vector<ValuePtr>;
 
 // ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode
-// do not belong to any particular function graph.
+// does not belong to any particular function graph.
 class ValueNode : public ANode {
  public:
   explicit ValueNode(const ValuePtr &value) : value_(value) {}

@@ -384,6 +384,8 @@ static S GetValue(const ValuePtr &value) {
   return v;
 }
 
+std::string GetCNodeFuncName(CNodePtr cnode);
+
 // used to check whether an AnfNode is a cnode with a kind of Primitive as first input
 bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value);
@@ -34,19 +34,19 @@ bool Number::operator==(const Type& other) const {
 
 Int::Int(const int nbits) : Number(IntBitsToTypeId(nbits), nbits, false) {
   if (nbits != 8 && nbits != 16 && nbits != 32 && nbits != 64) {
-    MS_LOG(EXCEPTION) << "wrong number of bits.";
+    MS_LOG(EXCEPTION) << "Wrong number of bits.";
   }
 }
 
 UInt::UInt(const int nbits) : Number(UIntBitsToTypeId(nbits), nbits, false) {
   if (nbits != 8 && nbits != 16 && nbits != 32 && nbits != 64) {
-    MS_LOG(EXCEPTION) << "wrong number of bits.";
+    MS_LOG(EXCEPTION) << "Wrong number of bits.";
   }
 }
 
 Float::Float(const int nbits) : Number(FloatBitsToTypeId(nbits), nbits, false) {
   if (nbits != 16 && nbits != 32 && nbits != 64) {
-    MS_LOG(EXCEPTION) << "wrong number of bits.";
+    MS_LOG(EXCEPTION) << "Wrong number of bits.";
   }
 }

@@ -37,7 +37,7 @@ TypeId IntBitsToTypeId(const int nbits) {
     case 64:
       return kNumberTypeInt64;
     default:
-      MS_LOG(EXCEPTION) << "wrong number of bits.";
+      MS_LOG(EXCEPTION) << "Wrong number of bits.";
   }
 }

@@ -52,7 +52,7 @@ TypeId UIntBitsToTypeId(const int nbits) {
     case 64:
       return kNumberTypeUInt64;
     default:
-      MS_LOG(EXCEPTION) << "wrong number of bits.";
+      MS_LOG(EXCEPTION) << "Wrong number of bits.";
   }
 }

@@ -65,7 +65,7 @@ TypeId FloatBitsToTypeId(const int nbits) {
     case 64:
       return kNumberTypeFloat64;
     default:
-      MS_LOG(EXCEPTION) << "wrong number of bits.";
+      MS_LOG(EXCEPTION) << "Wrong number of bits.";
   }
 }
@@ -64,14 +64,14 @@ AbstractFunctionPtr FuncGraph::abstract() {
   for (auto& p : parameters_) {
     MS_EXCEPTION_IF_NULL(p);
     if (p->abstract() == nullptr) {
-      MS_LOG(ERROR) << "error!!";
+      MS_LOG(ERROR) << "Error!!";
       return nullptr;
     }
     args_spec_list.push_back(p->abstract());
   }
 
   if (nullptr == output()) {
-    MS_LOG(ERROR) << "error func graph no output";
+    MS_LOG(ERROR) << "Error func graph no output";
     return nullptr;
   }

@@ -543,6 +543,7 @@ void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph,
     TraceManager::EndTrace();
   }
 }
+
 bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>& kwarg_list) {
   // if the function does not have any vararg/kwarg/kwonly/default value/kw args input
   // return the original graph

@@ -556,6 +557,7 @@ bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr>&
   }
   return true;
 }
+
 void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph,
                                      const std::vector<AnfNodePtr>& specialized_parameter_list,
                                      std::unordered_map<AnfNodePtr, AnfNodePtr>* repl_nodes) {

@@ -664,7 +666,7 @@ void FuncGraph::EraseUnusedNodeInOrder() {
   auto mng = manager_.lock();
   if (mng) {
     auto nodes = mng->nodes()[shared_from_base<FuncGraph>()];
-    // Erase unusued cnode.
+    // Erase unused cnode.
     for (auto it = order_.begin(); it != order_.end();) {
       if (nodes.count(*it)) {
         (void)it++;

@@ -695,7 +697,7 @@ void FuncGraph::CheckOrder() {
       if (found == it) {
         DumpCNodeList();
         MS_LOG(EXCEPTION) << "The cnode " << (*it)->DebugString() << " order in " << ToString()
-                          << " doesn't obey the input denpency, "
+                          << " doesn't obey the input dependency, "
                           << "as input " << input_node->DebugString() << " is not ahead of itself.";
       }
     }

@@ -842,5 +844,5 @@ void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr>& depend_inputs) {
 }
 
 const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
-const char kFuncGraphFlagUndetermin[] = "Undeterminate";
+const char kFuncGraphFlagUndetermined[] = "Undeterminate";
 } // namespace mindspore
@@ -96,7 +96,7 @@ class FuncGraphBase : public Value {
   MS_DECLARE_PARENT(FuncGraphBase, Value);
 };
 
-extern const char kFuncGraphFlagUndetermin[];
+extern const char kFuncGraphFlagUndetermined[];
 
 class FuncGraph : public FuncGraphBase {
  public:

@@ -174,7 +174,7 @@ class FuncGraph : public FuncGraphBase {
   GraphDebugInfoPtr debug_info();
   void set_debug_info(const GraphDebugInfoPtr &info) {
     if (info == nullptr) {
-      MS_LOG(EXCEPTION) << "graph set null debug info";
+      MS_LOG(EXCEPTION) << "Graph set null debug info";
     }
     this->debug_info_ = info;
   }

@@ -200,7 +200,7 @@ class FuncGraph : public FuncGraphBase {
   // get all func graphs directly used by this func graph
   const FuncGraphCounterMap &func_graphs_used();
 
-  // get all func graphs nestedly used by this func graph
+  // get all func graphs nested used by this func graph
   const FuncGraphSet &func_graphs_used_total();
 
   // get all users of this func graph
@@ -817,7 +817,7 @@ void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
 void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
   MS_EXCEPTION_IF_NULL(node);
   FuncGraphPtr fg1 = node->func_graph();
-  // possible chirld parent
+  // possible child parent
   if (IsValueNode<FuncGraph>(inp)) {
     FuncGraphPtr fg2 = GetValueNode<FuncGraphPtr>(inp);
     if (Mod(fg1, ParentProxy(fg2), direction)) {

@@ -1181,7 +1181,7 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPt
   }
   path->add(fg);
 
-  // checkg if func graphs used contains J(func_graph);
+  // check if func graphs used contains J(func_graph);
   auto& used = this->manager_->func_graphs_used();
   for (auto& item : used[fg]) {
     auto used_g = item.first;

@@ -650,7 +650,7 @@ class FuncGraphTransaction {
   explicit FuncGraphTransaction(FuncGraphManager* manager) : manager_(manager), changes_() {
     MS_EXCEPTION_IF_NULL(manager_);
     if (!manager_->IsManaged()) {
-      MS_LOG(DEBUG) << "the manager is not managed yet";
+      MS_LOG(DEBUG) << "The manager is not managed yet";
     }
   }
@@ -25,7 +25,6 @@
 #include "device/device_address.h"
 #include "pybind_api/api_register.h"
 #include "pybind_api/export_flags.h"
-#include "pynative/pynative_execute.h"
 #include "pipeline/static_analysis/abstract_value.h"
 
 namespace mindspore {

@@ -148,7 +148,7 @@ class MetaTensor : public Value {
   //
   // The constructed MetaTensor object has the same type and shape with meta_tensor.
   //
-  // param meta_tensor An exisiting MetaTensor object.
+  // param meta_tensor An existing MetaTensor object.
   virtual MetaTensor& operator=(const MetaTensor& meta_tensor);
 
   // brief Compares two MetaTensor objects.

@@ -166,7 +166,7 @@ class MetaTensor : public Value {
   TypeId data_type() const { return data_type_; }
   std::string ToString() const override;
   std::string DumpText() const override;
-  // bried Sets the data type of a tensor in its MetaTensor.
+  // brief Sets the data type of a tensor in its MetaTensor.
   //
   // param data_type The data type of the tensor to be set.
   virtual TypeId set_data_type(const TypeId data_type) {

@@ -314,7 +314,7 @@ class Tensor : public MetaTensor {
   //
   // The constructed Tensor object has the same type and shape with tensor.
   //
-  // param tensor An exisiting Tensor object.
+  // param tensor An existing Tensor object.
   Tensor& operator=(const Tensor& tensor);
 
   // brief Compares two Tensor objects.

@@ -383,7 +383,7 @@ class Tensor : public MetaTensor {
   // return The [TypeId] of the tensor data.
   TypeId GetDataType(const py::buffer_info& buf) const;
 
-  // bried Sets the data type of a tensor.
+  // brief Sets the data type of a tensor.
   //
   // param data_type The data type of the tensor to be set.
   //
@@ -106,6 +106,27 @@ void Primitive::set_signatures(
   }
 }
 
+std::string Primitive::GetAttrsText() const {
+  if (attrs_.empty()) {
+    return "";
+  }
+
+  std::ostringstream oss;
+  oss << "[";
+  bool is_first = true;
+  for (auto& attr : attrs_) {
+    if (is_first) {
+      is_first = false;
+    } else {
+      oss << ", ";
+    }
+    oss << attr.first << "=" << attr.second->DumpText();
+  }
+  oss << "]";
+
+  return oss.str();
+}
+
 py::function PrimitivePy::GetBpropFunction() {
   static const char* const get_bprop_func_name = "get_bprop";
   if (py::hasattr(python_obj_, get_bprop_func_name)) {

@@ -102,6 +102,7 @@ class Primitive : public Named {
 
   PrimType prim_type() const { return prim_type_; }
   std::string instance_name() const { return instance_name_; }
+  std::string GetAttrsText() const;
   bool operator==(const Value& other) const override;
   bool operator==(const Primitive& other) const;
   ~Primitive() override = default;
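GetAttrsText renders a primitive's attribute map as "[k1=v1, k2=v2]", using an is_first flag to comma-separate the entries. The same join pattern over a plain string map, so the output can be checked directly:

```cpp
// Self-contained version of the comma-join pattern used by GetAttrsText;
// a std::map of strings stands in for MindSpore's attrs_ value map.
#include <iostream>
#include <map>
#include <sstream>
#include <string>

std::string AttrsText(const std::map<std::string, std::string>& attrs) {
  if (attrs.empty()) {
    return "";
  }
  std::ostringstream oss;
  oss << "[";
  bool is_first = true;
  for (const auto& attr : attrs) {
    if (is_first) {
      is_first = false;
    } else {
      oss << ", ";
    }
    oss << attr.first << "=" << attr.second;
  }
  oss << "]";
  return oss.str();
}

int main() {
  // Prints: [axis=0, keep_dims=false]
  std::cout << AttrsText({{"axis", "0"}, {"keep_dims", "false"}}) << "\n";
  return 0;
}
```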
@@ -43,14 +43,13 @@ VisitFuncType AnfVisitor::Match(const PrimitivePtr &prim, const std::vector<opt:
     }
 
     auto &inputs = node->cast<CNodePtr>()->inputs();
-    // infact, funcs_size == inps_size - 1
     auto funcs_size = funcs.size();
-    auto inps_size = inputs.size();
+    auto inputs_size = inputs.size();
 
     // check the inputs are matched with the predicate functions
     if (funcs_size > 0) {
       // use the predicate function list to check the number of inputs
-      if (funcs_size != (inps_size - 1)) {
+      if (funcs_size != (inputs_size - 1)) {
         return;
       }

@@ -63,7 +62,7 @@ VisitFuncType AnfVisitor::Match(const PrimitivePtr &prim, const std::vector<opt:
     }
 
     // visit the inputs
-    for (size_t i = 1; i < inps_size; i++) {
+    for (size_t i = 1; i < inputs_size; i++) {
       this->Visit(inputs[i]);
     }
   };
@@ -25,11 +25,10 @@ endif()
 
 if(ENABLE_D)
     file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
-        "akg/cce/*.cc"
        "tbe/*.cc"
        "aicpu/*.cc"
        "mng/*.cc"
        "hccl/*.cc"
    )
    target_sources(_mindspore_kernel_obj PRIVATE ${_D_SRC_LIST})
 endif()
 endif()
@@ -18,11 +18,11 @@
 #include <vector>
 #include <memory>
 #include <string>
-#include "kernel/kernel.h"
+#include "kernel/ascend_kernel_mod.h"
 #include "kernel/aicpu/aicpu_util.h"
 namespace mindspore {
 namespace kernel {
-class AicpuOpKernelMod : public KernelMod {
+class AicpuOpKernelMod : public AscendKernelMod {
  public:
   AicpuOpKernelMod();
   ~AicpuOpKernelMod() override;

@@ -35,7 +35,6 @@
 #include "utils/convert_utils.h"
 #include "utils/any.h"
 #include "utils/utils.h"
-#include "transform/convert.h"
 #include "session/anf_runtime_algorithm.h"
 #include "kernel/akg/akg_kernel_attrs_process.h"
@@ -240,8 +239,8 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::
   return true;
 }
 
-void GetJson(const AnfNodePtr &anf_node, const vector<int> &dyn_input_sizes, const shared_ptr<OpAttr> &op_attr,
-             nlohmann::json *const attr_json, const ValuePtr &attr_value) {
+void GetJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes,
+             const std::shared_ptr<OpAttr> &op_attr, nlohmann::json *const attr_json, const ValuePtr &attr_value) {
   MS_EXCEPTION_IF_NULL(anf_node);
   MS_EXCEPTION_IF_NULL(op_attr);
   MS_EXCEPTION_IF_NULL(attr_json);
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_
+#define MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_
+
+#include <vector>
+#include <memory>
+#include "framework/ge_runtime/task_info.h"
+#include "kernel/kernel.h"
+
+using TaskInfoPtr = std::shared_ptr<ge::model_runner::TaskInfo>;
+namespace mindspore {
+namespace kernel {
+class AscendKernelMod : public KernelMod {
+ public:
+  virtual std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+                                           const std::vector<AddressPtr> &, uint32_t) = 0;
+};
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_
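With this new header, every Ascend kernel mod is expected to derive from AscendKernelMod and implement GenTask, which is what lets TaskGenerator::LaunchKernel downcast the generic KernelMod earlier in the diff. A hedged sketch of a conforming subclass; MyKernelMod and the stand-in types are illustrative, not real MindSpore kernels:

```cpp
// Self-contained sketch of plugging a kernel into the AscendKernelMod shape.
#include <memory>
#include <vector>

// Illustrative stand-ins for the real AddressPtr/TaskInfoPtr aliases.
struct Address {};
using AddressPtr = std::shared_ptr<Address>;
struct TaskInfo {};
using TaskInfoPtr = std::shared_ptr<TaskInfo>;

// Shape of the base class introduced in the diff (simplified: KernelMod's
// launch interface is omitted here).
class AscendKernelMod {
 public:
  virtual ~AscendKernelMod() = default;
  virtual std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs,
                                           const std::vector<AddressPtr> &workspaces,
                                           const std::vector<AddressPtr> &outputs, uint32_t stream_id) = 0;
};

class MyKernelMod : public AscendKernelMod {
 public:
  std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
                                   const std::vector<AddressPtr> &, uint32_t) override {
    // A real kernel would translate its launch arguments into GE runtime tasks.
    return {std::make_shared<TaskInfo>()};
  }
};
```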
@@ -19,7 +19,6 @@
 #include <map>
 #include <iostream>
 #include <fstream>
-#include "runtime/rt.h"
 #include "nlohmann/json.hpp"
 #include "session/anf_runtime_algorithm.h"
 #include "common/utils.h"

@@ -490,7 +489,7 @@ void SaveJsonInfo(const std::string &json_name, const std::string &info) {
   if (!filewrite.is_open()) {
     return;
   }
-  filewrite << info << endl;
+  filewrite << info << std::endl;
   filewrite.close();
   if (nullptr == realpath(path.c_str(), real_path)) {
     MS_LOG(DEBUG) << "dir " << path << " does not exit.";
@@ -226,12 +226,12 @@ class LstmGpuKernel : public GpuKernel {
   size_t reserved_size_;
 
   // input desc
-  unique_ptr<cudnnTensorDescriptor_t[]> x_desc_;
+  std::unique_ptr<cudnnTensorDescriptor_t[]> x_desc_;
   cudnnTensorDescriptor_t hx_desc_;
   cudnnTensorDescriptor_t cx_desc_;
   cudnnFilterDescriptor_t w_desc_;
   cudnnDropoutDescriptor_t dropout_desc_;
-  unique_ptr<cudnnTensorDescriptor_t[]> y_desc_;
+  std::unique_ptr<cudnnTensorDescriptor_t[]> y_desc_;
   cudnnTensorDescriptor_t hy_desc_;
   cudnnTensorDescriptor_t cy_desc_;
   cudnnRNNDescriptor_t rnn_desc_;

@@ -258,8 +258,8 @@ class LstmGradDataGpuKernel : public GpuKernel {
   cudnnRNNDescriptor_t rnn_desc_;
 
   // input desc
-  unique_ptr<cudnnTensorDescriptor_t[]> y_desc_;
-  unique_ptr<cudnnTensorDescriptor_t[]> dy_desc_;
+  std::unique_ptr<cudnnTensorDescriptor_t[]> y_desc_;
+  std::unique_ptr<cudnnTensorDescriptor_t[]> dy_desc_;
   cudnnTensorDescriptor_t dhy_desc_;
   cudnnTensorDescriptor_t dcy_desc_;
   cudnnFilterDescriptor_t w_desc_;

@@ -269,7 +269,7 @@ class LstmGradDataGpuKernel : public GpuKernel {
   cudnnDropoutDescriptor_t dropout_desc_;
 
   // output desc
-  unique_ptr<cudnnTensorDescriptor_t[]> dx_desc_;
+  std::unique_ptr<cudnnTensorDescriptor_t[]> dx_desc_;
   cudnnTensorDescriptor_t dhx_desc_;
   cudnnTensorDescriptor_t dcx_desc_;

@@ -214,9 +214,9 @@ class LstmGradWeightGpuKernel : public GpuKernel {
   cudnnDropoutDescriptor_t dropout_desc_;
 
   // input desc
-  unique_ptr<cudnnTensorDescriptor_t[]> x_desc_;
+  std::unique_ptr<cudnnTensorDescriptor_t[]> x_desc_;
   cudnnTensorDescriptor_t hx_desc_;
-  unique_ptr<cudnnTensorDescriptor_t[]> y_desc_;
+  std::unique_ptr<cudnnTensorDescriptor_t[]> y_desc_;
 
   // output desc
   cudnnFilterDescriptor_t dw_desc_;
Some files were not shown because too many files have changed in this diff.