Compare commits

...

99 Commits
master ... r1.4

Author SHA1 Message Date
i-robot 1f96911490 !26315 [Resnet] revert to PR 24604
Merge pull request !26315 from zhouneng/revert_pr_24604
2021-11-16 12:05:25 +00:00
zhouneng2 509e0e51fb revert PR 24604 2021-11-15 16:17:22 +08:00
i-robot 3e8e59569a !25420 Fix deepspeech link
Merge pull request !25420 from wanyiming/r1.4_dp2
2021-10-28 04:44:00 +00:00
wanyiming 809d8b3ce3 fix deepspeech link 2021-10-26 15:03:39 +08:00
jjfeing 4aa34a7473 !25360 fix release md r1.4
Merge pull request !25360 from jjfeing/r1.4
2021-10-25 02:21:42 +00:00
jjfeing 3e2e4571ea fix release md 2021-10-25 10:20:22 +08:00
jjfeing 5074a00911 !25307 update release r1.4
Merge pull request !25307 from jjfeing/r1.4
2021-10-22 06:33:37 +00:00
jjfeing 9a61ce075e updata release r1.4 2021-10-22 14:32:18 +08:00
i-robot 4026500fa9 !24604 rename keyword argument from 'acc_level' to 'boost_level' in class Model.__init__
Merge pull request !24604 from zhouneng/fix_issue_acc_level_rename_r1.4
2021-10-15 02:32:57 +00:00
zhouneng2 cb561323e1 rename keyword argument from 'acc_level' to 'boost_level' in class Model.__init__ 2021-10-11 16:37:19 +08:00
i-robot 31068de695 !24402 modify fasttext run scripts for r1.4
Merge pull request !24402 from zhanghuiyao/r1.4_fasttext
2021-09-30 08:58:26 +00:00
zhanghuiyao dae78c7430 modify fasttext run scripts 2021-09-29 15:09:49 +08:00
i-robot 683439f0dd !23047 23046 fix offline debugger read tensor and online debugger overflow and check watchpoint bug r1.4
Merge pull request !23047 from parastooashtari/overflow_bug_1.4
2021-09-16 14:04:25 +00:00
i-robot e5c0b80492 !22892 Fix convert async dump files problem in r1.4 branch
Merge pull request !22892 from TinaMengtingZhang/convert_async_bug_r1_4
2021-09-16 14:04:21 +00:00
Parastoo Ashtari ed837dea58 fixed overflow, offline debugger and online debugger check wp bug 2021-09-15 15:11:05 -04:00
TinaMengtingZhang 1379a829ce Fix convert async dump files failed issue and refactor convert_async.py
Comments addressed

Adapt async dump conversion tool with new run package (Sep03 version)
2021-09-14 14:51:58 -04:00
i-robot 9baffac2d3 !23247 [MSLITE] Fix bug of tf model‘s output names.
Merge pull request !23247 from wangshaocong/convert_r1.4
2021-09-11 01:44:43 +00:00
wang_shaocong 87174d9561 [MSLITE] Convert a tf model whose input is also an output. 2021-09-10 17:27:20 +08:00
i-robot 32658645b5 !22849 Fix issue for multi card check watchpoint r1.4
Merge pull request !22849 from parastooashtari/offline_dbg_bug_1.4
2021-09-08 11:43:10 +00:00
i-robot 9af75bb1e7 !23038 Profiler deal with the case that cpu utilization data may not have information for each pipeline op.
Merge pull request !23038 from casgj/r1.4
2021-09-08 09:24:49 +00:00
gaojing 4148cfc0c3 Profiler deal with the case that the cpu utilization data may not have information for each pipeline op. 2021-09-07 09:03:33 -04:00
Parastoo Ashtari a7e977a25e fix checkwatchpoints multi card issue 2021-09-03 12:16:44 -04:00
i-robot 9c570e316e !22641 fix yolov3 tiny multi root graph seg fault
Merge pull request !22641 from john_tzanakakis/1.4
2021-09-03 03:17:44 +00:00
i-robot 0fe6e20502 !22724 Fix for the set value of monitoring point is inconsistent with the actual value issue r1.4
Merge pull request !22724 from parastooashtari/offline_dbg_bug_1.4
2021-09-01 07:35:12 +00:00
Parastoo Ashtari d76d1df56d fix inconsistent wp value issue in offline dbg 2021-08-31 17:00:47 -04:00
John Tzanakakis e0c4f0bfb1 fix yolov3 tiny multi root graph seg fault 2021-08-30 18:33:31 -04:00
i-robot f290098e3d !22480 modify the limit of axis of reduce ops
Merge pull request !22480 from Simson/push-to-r1.4
2021-08-27 13:10:13 +00:00
simson 52c5f9befe modify the limit of axis of reduce ops 2021-08-27 16:46:39 +08:00
i-robot 43f5ec7241 !22415 fix bug on maskrcnn/fasterrcnn
Merge pull request !22415 from gengdongjie/r1.4
2021-08-26 13:39:50 +00:00
gengdongjie 5d7a31fc9d Fix the execution sequence problem of the load in maketuple 2021-08-26 15:36:13 +08:00
i-robot cec2fea2ee !22234 [MSLITE] Fix ci test.
Merge pull request !22234 from wangshaocong/bugfix_r1.4
2021-08-24 09:49:46 +00:00
wang_shaocong 907c2553d0 [MSLITE]fix ci test 2021-08-23 15:41:50 +08:00
i-robot 3e9b8f23d1 !22119 [MS][LITE][r1.4] fix entrance gurad problem
Merge pull request !22119 from XianglongZeng/r1.4
2021-08-23 07:24:48 +00:00
zengxianglong fa3a07cd8b fix entrance guard problem 2021-08-21 16:51:21 +08:00
i-robot fe435e296e !22125 fix uint8 overflow bug
Merge pull request !22125 from yuchaojie/r1.4
2021-08-21 06:15:56 +00:00
i-robot 215d258411 !21866 Ascend control use vm r1.4
Merge pull request !21866 from chenfei_mindspore/ascend-control-use-vm-r1.4
2021-08-21 06:04:28 +00:00
i-robot 42429c5ecb !22103 Add the download link of vgg19-0-97_5004.ckpt to openpose for r1.4
Merge pull request !22103 from zhanghuiyao/fix_openpose_readme_r1.4
2021-08-20 07:56:35 +00:00
yuchaojie 0b778d936f fix uint8 overflow bug 2021-08-20 15:28:32 +08:00
i-robot 0713e47db1 !22095 redirect model input bin
Merge pull request !22095 from hangq/graph_inout
2021-08-20 06:21:19 +00:00
chenfei 5b1639ad56 ascend control use vm 2021-08-20 12:57:52 +08:00
zhanghuiyao 5d8bf2b56f Add vgg download link to openpose readme.md for r1.4 2021-08-20 09:44:03 +08:00
hangangqiang 9e099f8497 redirect model input bin 2021-08-20 09:05:24 +08:00
i-robot aecb280e17 !22064 fix graph input
Merge pull request !22064 from hangq/graph_inout
2021-08-19 11:11:15 +00:00
i-robot b7fd96b182 !21876 reduce func graph number by elimiate incorporate getitem
Merge pull request !21876 from xychow/reduce-graph-number-by-eliminate-incorporate-getitem
2021-08-19 09:29:08 +00:00
hangangqiang 898e63ea6f fix graph input 2021-08-19 16:24:43 +08:00
i-robot cd14d06971 !21968 use temporary dir as dump dir in testcase
Merge pull request !21968 from yelihua/ylh_r14
2021-08-19 03:40:27 +00:00
yelihua b4c82be639 use temporary dir as dump dir 2021-08-17 20:53:07 +08:00
i-robot c35d32d45a !21867 update the method for get rank_id
Merge pull request !21867 from yelihua/ylh_r14
2021-08-17 11:25:37 +00:00
yelihua 425b1168ad get rank id when set hccl env for single card train 2021-08-17 12:10:45 +08:00
i-robot f40c33ff18 !21856 Fix comments in LossBase.get_loss()
Merge pull request !21856 from chenhaozhe/code_docs_change_lossBase_getloss_1.4
2021-08-17 03:13:48 +00:00
i-robot 780f40395a !21885 [MSLITE] Fix bug of tf model‘s output names.
Merge pull request !21885 from wangshaocong/convert_name_r1.4
2021-08-16 14:40:17 +00:00
wang_shaocong 8f60482bfd [MSLITE] Fix bug of tf model parser. 2021-08-16 20:31:18 +08:00
i-robot c1a90b0870 !21844 [EmbeddingLookup] Add a check for the case of scalar indices
Merge pull request !21844 from Xiaoda/83-add-a-check-for-embeddinglookup-scalar-indices-r.14
2021-08-16 10:57:09 +00:00
zhousiyi b798586988 keep tuple_getitem as much as possible to reduce the number of func graphs
shrink output of func_graph other than set unused to dead node
2021-08-16 07:08:07 +00:00
chenhaozhe 209b2f1e04 fix comments for LossBase.get_loss 2021-08-16 11:28:29 +08:00
Xiaoda Zhang 9a27864e42 add a check for scalar indices for embeddinglookup 2021-08-16 10:30:43 +08:00
i-robot 44268e9ad8 !21810 [MSLITE] Modify output tensor names of ms model.
Merge pull request !21810 from wangshaocong/convert_name_r1.4
2021-08-14 06:54:29 +00:00
wang_shaocong cdf9989211 [MSLITE] Modify output tensor names of ms model 2021-08-14 01:09:20 +08:00
i-robot b181342f9d !21666 fix yolov3_darknet53 310 infer fail bug in r1.4
Merge pull request !21666 from zhanghuiyao/r1.4
2021-08-12 07:08:53 +00:00
i-robot 0f298d22f9 !21711 update Matmul docs for its wrong format
Merge pull request !21711 from dinglinhe/code_docs_dlh_r1.4_I440D3_1
2021-08-12 07:05:30 +00:00
dinglinhe 6a4743441e update Matmul docs for its wrong format 2021-08-12 11:42:45 +08:00
i-robot eb23047515 !21692 nnie branch adapts r1.4
Merge pull request !21692 from jianghui58/r1.4_fix
2021-08-12 03:23:12 +00:00
i-robot eb1472d129 !21510 update md5 for json
Merge pull request !21510 from liubuyu/r1.4
2021-08-12 01:54:35 +00:00
jianghui58 6b12f40d3e nnie branch adapts r1.4 2021-08-12 09:53:19 +08:00
i-robot 6efa88c486 !21652 fix bugs: DestroyHccl must be called before FreeDeviceMemory
Merge pull request !21652 from jjfeing/r1.4
2021-08-11 11:59:51 +00:00
zhanghuiyao 80caa2a715 fix yolov3_darknet53 310 eval fail bug in r1.4 2021-08-11 16:45:26 +08:00
jjfeing 8192c0034a DestroyHccl must be called before FreeDeviceMemory 2021-08-11 12:45:04 +08:00
lby 20fb63b4a5 update md5 for json when download from gitee 2021-08-11 11:42:48 +08:00
zhangzhenghai fb2b4f0617 !21577 fix event insertion bug
Merge pull request !21577 from fix-stream-1.4
2021-08-10 06:57:38 +00:00
i-robot 76616c83ff !21526 optional init hccl for graph mode and pynative mode
Merge pull request !21526 from zhoufeng/optional-init-hccl-r1.4
2021-08-10 02:03:54 +00:00
i-robot 75e8783495 !21568 parallel_train_and_eval_fix
Merge pull request !21568 from yao_yf/parallel_train_and_eval_fix
2021-08-10 01:44:23 +00:00
hwjiaorui dd3191026f fix stream find target 2021-08-09 20:11:02 +08:00
zhoufeng 68ffbc7c55 optional init hccl for graph mode and pynative mode
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
2021-08-09 19:10:31 +08:00
yao_yf b39f02ed2d pangu train and eval 2021-08-09 17:41:53 +08:00
i-robot 7ab5ba3a97 !21384 dynamic shape interface adapter
Merge pull request !21384 from liubuyu/r1.4
2021-08-06 12:40:14 +00:00
i-robot 0d1ab208b2 !21466 api docs for doxygen
Merge pull request !21466 from zhoufeng/code_docs_api_note_1.4
2021-08-06 08:08:24 +00:00
i-robot a1d1b691fd !21425 Add skip decorator and gpu testcase
Merge pull request !21425 from gaoyong10/runtime1.4_001
2021-08-06 07:53:26 +00:00
i-robot 736db236c4 !21436 Recover the environment variable MS_DEBUGGER_HOST
Merge pull request !21436 from maning202007/r1.4
2021-08-06 07:21:51 +00:00
zhoufeng e2111f9fdd api docs
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
2021-08-06 11:32:37 +08:00
lby 2cdf231575 dynamic shape interface adapter 2021-08-06 11:24:29 +08:00
maning202007 4be1d3489f revert 9cdd3bbf
Recover the environment variable MS_DEBUGGER_HOST

Update the version number to 1.4.0 in debugger
2021-08-06 10:51:15 +08:00
gaoyong10 47efd5f9d1 add skip decorator and gpu testcase 2021-08-05 17:51:12 +08:00
i-robot c1d65a3f76 !21385 fix same node is used by two comm op
Merge pull request !21385 from zhoufeng/xiu-ba-ge
2021-08-05 02:17:14 +00:00
i-robot 10f8936f77 !21304 fix parse aicpu file error
Merge pull request !21304 from yanghaitao/yht_fix_aicpu_parser_error
2021-08-04 07:29:05 +00:00
i-robot 7383457756 !21205 [MS][LITE][TOD] memory optimization
Merge pull request !21205 from yonibaehr/export_yoni
2021-08-04 07:19:25 +00:00
i-robot c77fbcddc3 !21319 [MSLITE] Expanding dimension of pad from 4 to 6.
Merge pull request !21319 from wangshaocong/expand_pad
2021-08-04 06:30:51 +00:00
i-robot 01639b95a1 !21034 Add to TB-Net to model_zoo/official/recommend
Merge pull request !21034 from lihaoyang/tbnet
2021-08-04 06:06:12 +00:00
i-robot d260f38936 !15722 [assistant][AdjustGamma]
Merge pull request !15722 from QingfengLi/adjust_gamma
2021-08-04 03:58:09 +00:00
i-robot a3155d16ed !21128 Print erro info while file open failed
Merge pull request !21128 from zhangzhaoju/master_file_open_error
2021-08-04 03:41:54 +00:00
i-robot 770d2a994e !20767 add CPU L2Loss op
Merge pull request !20767 from 范吉斌/master_l2loss
2021-08-04 03:26:29 +00:00
lihaoyang 9b9dc5ee94 add tbnet for open source in model_zoo/official/recommend 2021-08-04 11:05:41 +08:00
i-robot 1245eed4d1 !21283 VM bug fix and test_cont_case add gpu
Merge pull request !21283 from chenfei_mindspore/vm-bug-fix
2021-08-04 02:44:04 +00:00
chenx2ovo 4191dde45f [feat][assistant][I3CEGM] add op adjust gamma 2021-08-04 02:32:13 +08:00
yoni 208c620cea memory optimization 2021-08-03 16:42:58 +03:00
fanjibin 9e5618a5b8 add CPU l2loss op 2021-08-03 20:19:15 +08:00
zhangzhaoju 610ee46ebe Add file open failure info while file open failed 2021-08-03 19:52:38 +08:00
chenfei 709b1e80db add control test cases and vm bug fix 2021-08-03 18:43:22 +08:00
wang_shaocong cb560406ea [MSLITE] expanding dimension of pad from 4 to 6. 2021-08-03 18:29:17 +08:00
yanghaitao1 c1428a8af1 fix aicpu parser error 2021-08-03 03:02:14 -04:00
217 changed files with 6547 additions and 1651 deletions

View File

@@ -1,3 +1,107 @@
# MindSpore 1.4.0
## MindSpore 1.4.0 Release Notes
### Major Features and Improvements
#### NewModels
#### FrontEnd
#### Auto Parallel
- Add distributed operators: Conv2D/Conv2DTranspose/Conv2DBackpropInput/MaxPool/AvgPool/BatchNorm/GatherD
- Support to configure shard strategy for dataset
#### Executor
#### DataSet
- Add SlicePatchesOperation for Remote Sensing feature ([!18179](https://e.gitee.com/mind_spore/repos/mindspore/mindspore/pulls/18179))
#### FederatedLearning
#### Running Data Recorder
#### GraphKernel Fusion
#### Profiler
- [STABLE] Support MS_DIAGNOSTIC_DATA_PATH for profiler feature. (Ascend/GPU)
#### Dump
- [STABLE] Support MS_DIAGNOSTIC_DATA_PATH for dump feature. (Ascend/GPU/CPU)
### API Change
#### Backwards Incompatible Change
##### Python API
##### Command Line Interface
###### Dump Config
Previously, we needed to set the dump path in the dump config file. To make the dump feature easier to use on cloud, we support the new environment variable `MS_DIAGNOSTIC_DATA_PATH`.
| 1.3.0 | 1.4.0 |
| ------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------- |
| `path` is a mandatory field. | The `path` field is optional. If `path` is not provided or is an empty string, `MS_DIAGNOSTIC_DATA_PATH` should be set in the environment. |
### Bug fixes
#### FrontEnd
#### Executor
#### Dataset
- Fix module 'signal' has no attribute 'SIGCHLD' problem under windows platform. ([!21232](https://gitee.com/mindspore/mindspore/pulls/21232))
## MindSpore Lite
### Major Features and Improvements
#### Converter and runtime
#### x86 backend optimization
#### ARM backend optimization
#### Cuda backend optimization
#### OpenCL backend
#### Post quantization
#### Training on Device
#### Codegen
### API Change
#### API Incompatible Change
##### C++ API
#### New features
##### Java API
### Bug fixes
#### Deprecations
### Contributors
Thanks goes to these wonderful people:
Adel, AGroupofProbiotocs, anthonyaje, anzhengqi, askmiao, baihuawei, baiyangfan, bai-yangfan, bingyaweng, BowenK, buxue, caifubi, CaoJian, caojian05, caozhou, Cathy, changzherui, chenbo116, chenfei, chengxianbin, chenhaozhe, chenjianping, chenzomi, chenzupeng, chujinjin, cj, cjh9368, Corleone, damon0626, danish, Danish, davidmc, dayschan, doitH, dong-li001, eric, Eric, fary86, fuzhiye, Gaoxiong, GAO_HYP_XYJ, gengdongjie, Gogery, gongdaguo, gray0v0, gukecai, guoqi, gzhcv, hangq, hanhuifeng2020, Harshvardhan, He, heleiwang, hexia, Hoai, HuangBingjian, huangdongrun, huanghui, huangxinjing, huqi, huzhifeng, hwjiaorui, Islam Amin, Jesse, , Jiabin Liu, jianghui58, jiangzhiwen, Jiaqi, jin-xiulang, jinyaohui, jjfeing, John, Jonathan, jonyguo, JulyAi, jzg, kai00, kingfo, kingxian, kpy, kswang, laiyongqiang, leonwanghui, Li, liangchenghui, liangzelang, lichen_101010, lichenever, lihongkang, lilei, limingqi107, ling, linqingke, Lin Xh, liubuyu, liuwenhao4, liuxiao78, liuxiao93, liuyang_655, liuzhongkai, Lixia, lixian, liyanliu, liyong, lizhenyu, luopengting, luoyang, lvchangquan, lvliang, lz, mahdi, Mahdi, maning202007, Margaret_wangrui, mayang, mengyuanli, Ming_blue, nhussain, ougongchang, panfengfeng, panyifeng, Payne, Peilin, peixu_ren, Pengyongrong, qianlong, qianjiahong, r1chardf1d0, riemann_penn, rmdyh, Sheng, shenwei41, simson, Simson, Su, sunsuodong, tao_yunhao, tinazhang, VectorSL, , Wan, wandongdong, wangdongxu, wangmin, wangnan39@huawei.com, wangyue01, wangzhe, wanyiming, Wei, wenchunjiang, wilfChen, WilliamLian, wsc, wudenggang, wukesong, wuweikang, wuxuejian, Xiao Tianci, Xiaoda, xiefangqi, xinyunfan, xuanyue, xulei2020, Xun, xuyongfei, yanghaitao, yanghaitao1, yanghaoran, YangLuo, yangruoqi713, yankai, yanzhenxiang2020, yao_yf, yepei6, yeyunpeng, Yi, yoni, yoonlee666, yuchaojie, yujianfeng, yuximiao, zengzitao, Zhang, zhanghaibo5@huawei.com, zhanghuiyao, zhanghui_china, zhangxinfeng3, zhangyihui, zhangz0911gm, zhanke, zhanyuan, zhaodezan, zhaojichen, zhaoting, zhaozhenlong, zhengjun10, Zhenglong Li, zhiqwang, zhoufeng, zhousiyi, zhouyaqiang, zhouyifengCode, Zichun, Zirui, Ziyan, zjun, ZPaC, wangfengwfwf, zymaa, gerayking.
Contributions of any kind are welcome!
# MindSpore 1.3.0
## MindSpore 1.3.0 Release Notes

View File

@@ -9,7 +9,7 @@ endif()
if(ENABLE_GITEE)
    set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip")
-   set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7")
+   set(MD5 "36ea0d9a709c6667b2798a62f6b197ae")
    set(INCLUDE "./include")
else()
    set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip")

View File

@@ -38,12 +38,19 @@ class Allocator;
class Delegate;
class DeviceInfoContext;
/// \brief Context is used to store environment variables during execution.
class MS_API Context {
 public:
  Context();
  ~Context() = default;
  /// \brief Set the number of threads at runtime. This option is only valid for MindSpore Lite.
  ///
  /// \param[in] thread_num the number of threads at runtime.
  void SetThreadNum(int32_t thread_num);
  /// \brief Get the current thread number setting.
  ///
  /// \return The current thread number setting.
  int32_t GetThreadNum() const;
  /// \brief Set the thread affinity to CPU cores.
@@ -60,6 +67,10 @@ class MS_API Context {
  void SetDelegate(const std::shared_ptr<Delegate> &delegate);
  std::shared_ptr<Delegate> GetDelegate() const;
  /// \brief Get a mutable reference of DeviceInfoContext vector in this context. Only MindSpore Lite supports
  /// heterogeneous scenarios with multiple members in the vector.
  ///
  /// \return Mutable reference of DeviceInfoContext vector in this context.
  std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
 private:
@@ -67,14 +78,24 @@ class MS_API Context {
  std::shared_ptr<Data> data_;
};
/// \brief DeviceInfoContext defines different device contexts.
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
 public:
  struct Data;
  DeviceInfoContext();
  virtual ~DeviceInfoContext() = default;
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  virtual enum DeviceType GetDeviceType() const = 0;
  /// \brief A similar function to RTTI is provided when the -fno-rtti compilation option is turned on, which converts
  /// DeviceInfoContext to a shared pointer of type T, and returns nullptr if the conversion fails.
  ///
  /// \param T Type
  /// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr.
  template <class T>
  std::shared_ptr<T> Cast() {
    static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
@@ -98,27 +119,60 @@ class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoC
  std::shared_ptr<Data> data_;
};
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the CPU. This option is only valid
/// for MindSpore Lite.
class MS_API CPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
  /// \brief Set enables to perform the float16 inference
  ///
  /// \param[in] is_fp16 Enable float16 inference or not.
  void SetEnableFP16(bool is_fp16);
  /// \brief Get enables to perform the float16 inference
  ///
  /// \return Whether enable float16 inference.
  bool GetEnableFP16() const;
};
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the NPU. This option is only valid
/// for MindSpore Lite.
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
  /// \brief Set the NPU frequency.
  ///
  /// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme
  /// performance), default as 3.
  void SetFrequency(int frequency);
  /// \brief Get the NPU frequency.
  ///
  /// \return NPU frequency
  int GetFrequency() const;
};
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the GPU.
class MS_API GPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };
  /// \brief Set device id.
  ///
  /// \param[in] device_id The device id.
  void SetDeviceID(uint32_t device_id);
  /// \brief Get the device id.
  ///
  /// \return The device id.
  uint32_t GetDeviceID() const;
  void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
@@ -127,8 +181,15 @@ class MS_API GPUDeviceInfo : public DeviceInfoContext {
  inline void SetPrecisionMode(const std::string &precison_mode);
  inline std::string GetPrecisionMode() const;
  /// \brief Set enables to perform the float16 inference
  ///
  /// \param[in] is_fp16 Enable float16 inference or not.
  void SetEnableFP16(bool is_fp16);
  /// \brief Get enables to perform the float16 inference
  ///
  /// \return Whether enable float16 inference.
  bool GetEnableFP16() const;
 private:
  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;
@@ -139,52 +200,113 @@ void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
}
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend910. This option is
/// invalid for MindSpore Lite.
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
  /// \brief Set device id.
  ///
  /// \param[in] device_id The device id.
  void SetDeviceID(uint32_t device_id);
  /// \brief Get the device id.
  ///
  /// \return The device id.
  uint32_t GetDeviceID() const;
};
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
/// invalid for MindSpore Lite.
class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
  /// \brief Set device id.
  ///
  /// \param[in] device_id The device id.
  void SetDeviceID(uint32_t device_id);
  /// \brief Get the device id.
  ///
  /// \return The device id.
  uint32_t GetDeviceID() const;
  inline void SetDumpConfigPath(const std::string &cfg_path);
  inline std::string GetDumpConfigPath() const;
-  // aipp config file
+  /// \brief Set AIPP configuration file path.
  ///
  /// \param[in] cfg_path AIPP configuration file path.
  inline void SetInsertOpConfigPath(const std::string &cfg_path);
  /// \brief Get AIPP configuration file path.
  ///
  /// \return AIPP configuration file path.
  inline std::string GetInsertOpConfigPath() const;
-  // nchw or nhwc
+  /// \brief Set format of model inputs.
  ///
  /// \param[in] format Optional "NCHW", "NHWC", etc.
  inline void SetInputFormat(const std::string &format);
  /// \brief Get format of model inputs.
  ///
  /// \return The format of model inputs.
  inline std::string GetInputFormat() const;
-  // Mandatory while dynamic batch: e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1"
+  /// \brief Set shape of model inputs.
  ///
  /// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1".
  inline void SetInputShape(const std::string &shape);
  /// \brief Get shape of model inputs.
  ///
  /// \return The shape of model inputs.
  inline std::string GetInputShape() const;
  /// \brief Set shape of model inputs.
  ///
  /// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input shape 1,2,3,4 and the second input
  /// shape 4,3,2,1.
  void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
  /// \brief Get shape of model inputs.
  ///
  /// \return The shape of model inputs.
  std::map<int, std::vector<int>> GetInputShapeMap() const;
  void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
  inline std::string GetDynamicBatchSize() const;
-  // FP32, UINT8 or FP16, default as FP32
+  /// \brief Set type of model outputs.
  ///
  /// \param[in] output_type FP32, UINT8 or FP16, default as FP32.
  void SetOutputType(enum DataType output_type);
  /// \brief Get type of model outputs.
  ///
  /// \return The set type of model outputs.
  enum DataType GetOutputType() const;
-  // "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
+  /// \brief Set precision mode of model.
  ///
  /// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" and
  /// "allow_mix_precision", "force_fp16" is set as default
  inline void SetPrecisionMode(const std::string &precision_mode);
  /// \brief Get precision mode of model.
  ///
  /// \return The set type of model outputs
  inline std::string GetPrecisionMode() const;
-  // Optional "high_performance" and "high_precision", "high_performance" is set as default
+  /// \brief Set op select implementation mode.
  ///
  /// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision", "high_performance" is set as
  /// default.
  inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
  /// \brief Get op select implementation mode.
  ///
  /// \return The set op select implementation mode.
  inline std::string GetOpSelectImplMode() const;
  inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
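
A minimal usage sketch of the `Context` API documented above (editor's illustration, not part of the diff; values are hypothetical and error handling is omitted):

```cpp
#include <memory>
#include "include/api/context.h"

int main() {
  // Create a context and set runtime options; SetThreadNum is only
  // meaningful for MindSpore Lite, per the doc comment above.
  auto context = std::make_shared<mindspore::Context>();
  context->SetThreadNum(2);

  // Attach a CPU device description. MutableDeviceInfo() returns a mutable
  // vector; only Lite supports multiple entries (heterogeneous scenarios).
  auto cpu_info = std::make_shared<mindspore::CPUDeviceInfo>();
  cpu_info->SetEnableFP16(false);
  context->MutableDeviceInfo().push_back(cpu_info);
  return 0;
}
```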

View File

@@ -37,32 +37,75 @@ class Metrics;
namespace dataset {
class Dataset;
}  // namespace dataset
/// \brief The Model class is used to define a MindSpore model, facilitating computational graph management.
class MS_API Model {
 public:
  Model();
  ~Model();
  Model(const Model &) = delete;
  void operator=(const Model &) = delete;
  /// \brief Builds a model so that it can run on a device.
  ///
  /// \param[in] graph GraphCell is a derivative of Cell. Cell is not available currently. GraphCell can be constructed
  /// from Graph, for example, model.Build(GraphCell(graph), context).
  /// \param[in] model_context A context used to store options during execution.
  /// \param[in] train_cfg A config used by training.
  ///
  /// \return Status.
  Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr,
               const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
  /// \brief Resizes the shapes of inputs.
  ///
  /// \param[in] inputs A vector that includes all input tensors in order.
  /// \param[in] dims Defines the new shapes of inputs, should be consistent with inputs.
  ///
  /// \return Status.
  Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
  /// \brief Inference model.
  ///
  /// \param[in] inputs A vector where model inputs are arranged in sequence.
  /// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence.
  /// \param[in] before CallBack before predict.
  /// \param[in] after CallBack after predict.
  ///
  /// \return Status.
  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                 const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
  /// \brief Obtains all input tensors of the model.
  ///
  /// \return The vector that includes all input tensors.
  std::vector<MSTensor> GetInputs();
  /// \brief Obtains the input tensor of the model by name.
  ///
  /// \return The input tensor with the given name, if the name is not found, an invalid tensor is returned.
  inline MSTensor GetInputByTensorName(const std::string &tensor_name);
  Status InitMetrics(std::vector<Metrics *> metrics);
  std::vector<Metrics *> GetMetrics();
  /// \brief Obtains all output tensors of the model.
  ///
  /// \return The vector that includes all output tensors.
  std::vector<MSTensor> GetOutputs();
  /// \brief Obtains names of all output tensors of the model.
  ///
  /// \return A vector that includes names of all output tensors.
  inline std::vector<std::string> GetOutputTensorNames();
  /// \brief Obtains the output tensor of the model by name.
  ///
  /// \return The output tensor with the given name, if the name is not found, an invalid tensor is returned.
  inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
  inline std::vector<MSTensor> GetOutputsByNodeName(const std::string &tensor_name);
  /// \brief Inference model.
  ///
  /// \param[in] device_type Device type, options are kGPU, kAscend910, etc.
  /// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM.
  ///
  /// \return Is supported or not.
  static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
  Status SetTrainMode(bool train);
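
Putting the documented methods together, a minimal end-to-end sketch (editor's illustration; assumes a hypothetical MindIR file `net.mindir` plus the `Context` and `Serialization` APIs from this same changeset):

```cpp
#include <memory>
#include <vector>
#include "include/api/context.h"
#include "include/api/model.h"
#include "include/api/serialization.h"

int main() {
  // Load a MindIR graph from disk (this overload is not supported on MindSpore Lite).
  mindspore::Graph graph;
  if (mindspore::Serialization::Load("net.mindir", mindspore::ModelType::kMindIR, &graph) != mindspore::kSuccess) {
    return 1;
  }

  // Build the model on CPU, following the pattern named in the Build() doc
  // comment: model.Build(GraphCell(graph), context).
  auto context = std::make_shared<mindspore::Context>();
  context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
  mindspore::Model model;
  if (model.Build(mindspore::GraphCell(graph), context) != mindspore::kSuccess) {
    return 1;
  }

  // Run inference; real code would fill the input tensors' data first.
  std::vector<mindspore::MSTensor> inputs = model.GetInputs();
  std::vector<mindspore::MSTensor> outputs;
  return model.Predict(inputs, &outputs) == mindspore::kSuccess ? 0 : 1;
}
```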

View File

@@ -27,13 +27,43 @@
#include "include/api/dual_abi_helper.h"
namespace mindspore {
/// \brief The Serialization class is used to summarize methods for reading and writing model files.
class MS_API Serialization {
 public:
  /// \brief Loads a model file from memory buffer.
  ///
  /// \param[in] model_data A buffer filled by model file.
  /// \param[in] data_size The size of the buffer.
  /// \param[in] model_type The Type of model file, options are ModelType::kMindIR, ModelType::kOM.
  /// \param[out] graph The output parameter, an object saves graph data.
  /// \param[in] dec_key The decryption key, key length is 16, 24, or 32.
  /// \param[in] dec_mode The decryption mode, optional options are AES-GCM, AES-CBC.
  ///
  /// \return Status.
  inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
                            const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
  /// \brief Loads a model file from path, is not supported on MindSpore Lite.
  ///
  /// \param[in] file The path of model file.
  /// \param[in] model_type The Type of model file, options are ModelType::kMindIR, ModelType::kOM.
  /// \param[out] graph The output parameter, an object saves graph data.
  /// \param[in] dec_key The decryption key, key length is 16, 24, or 32.
  /// \param[in] dec_mode The decryption mode, optional options are AES-GCM, AES-CBC.
  ///
  /// \return Status.
  inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key = {},
                            const std::string &dec_mode = kDecModeAesGcm);
  /// \brief Load multiple models from multiple files, MindSpore Lite does not provide this feature.
  ///
  /// \param[in] files The path of model files.
  /// \param[in] model_type The Type of model file, options are ModelType::kMindIR, ModelType::kOM.
  /// \param[out] graph The output parameter, an object saves graph data.
  /// \param[in] dec_key The decryption key, key length is 16, 24, or 32.
  /// \param[in] dec_mode The decryption mode, optional options are AES-GCM, AES-CBC.
  ///
  /// \return Status.
  inline static Status Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
                            const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
  static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
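
A minimal sketch of the buffer-based `Load` overload (editor's illustration; hypothetical file name, and the `dec_key`/`dec_mode` defaults imply no decryption):

```cpp
#include <fstream>
#include <iterator>
#include <vector>
#include "include/api/serialization.h"

int main() {
  // Read the serialized model into memory (hypothetical path).
  std::ifstream ifs("net.mindir", std::ios::binary);
  std::vector<char> buf((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());

  // Parse the graph from the in-memory buffer.
  mindspore::Graph graph;
  mindspore::Status status =
      mindspore::Serialization::Load(buf.data(), buf.size(), mindspore::ModelType::kMindIR, &graph);
  return status == mindspore::kSuccess ? 0 : 1;
}
```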

View File

@@ -64,18 +64,64 @@ struct QuantParam {
};
class Allocator;
/// \brief The MSTensor class defines a tensor in MindSpore.
class MS_API MSTensor {
 public:
  class Impl;
  /// \brief Creates a MSTensor object, whose data need to be copied before accessed by Model, must be used in pairs
  /// with DestroyTensorPtr.
  ///
  /// \param[in] name The name of the MSTensor.
  /// \param[in] type The data type of the MSTensor.
  /// \param[in] shape The shape of the MSTensor.
  /// \param[in] data The data pointer that points to allocated memory.
  /// \param[in] data_len The length of the memory, in bytes.
  ///
  /// \return A pointer of MSTensor.
  static inline MSTensor *CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
                                       const void *data, size_t data_len) noexcept;
  /// \brief Creates a MSTensor object, whose data can be directly accessed by Model, must be used in pairs with
  /// DestroyTensorPtr.
  ///
  /// \param[in] name The name of the MSTensor.
  /// \param[in] type The data type of the MSTensor.
  /// \param[in] shape The shape of the MSTensor.
  /// \param[in] data The data pointer that points to allocated memory.
  /// \param[in] data_len The length of the memory, in bytes.
  ///
  /// \return A pointer of MSTensor.
  static inline MSTensor *CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
                                          const void *data, size_t data_len) noexcept;
  /// \brief Creates a MSTensor object, whose device data can be directly accessed by Model, must be used in pairs with
  /// DestroyTensorPtr.
  ///
  /// \param[in] name The name of the MSTensor.
  /// \param[in] type The data type of the MSTensor.
  /// \param[in] shape The shape of the MSTensor.
  /// \param[in] data The data pointer that points to device memory.
  /// \param[in] data_len The length of the memory, in bytes.
  ///
  /// \return A pointer of MSTensor.
  static inline MSTensor *CreateDevTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
                                          const void *data, size_t data_len) noexcept;
  /// \brief Create a string type MSTensor object whose data can be accessed by Model only after being copied, must be
  /// used in pair with DestroyTensorPtr.
  ///
  /// \param[in] name The name of the MSTensor.
  /// \param[in] str A vector container containing several strings.
  ///
  /// \return A pointer of MSTensor.
  static inline MSTensor *StringsToTensor(const std::string &name, const std::vector<std::string> &str);
  /// \brief Parse the string type MSTensor object into strings.
  ///
  /// \param[in] tensor A MSTensor object.
  ///
  /// \return A vector container containing several strings.
  static inline std::vector<std::string> TensorToStrings(const MSTensor &tensor);
  /// \brief Destroy an object created by Clone, StringsToTensor, CreateRefTensor, CreateDevTensor or CreateTensor. Do
  /// not use it to destroy MSTensor from other sources.
  ///
  /// \param[in] tensor A MSTensor object.
  static void DestroyTensorPtr(MSTensor *tensor) noexcept;
  MSTensor();
@@ -85,19 +131,51 @@ class MS_API MSTensor {
  explicit MSTensor(std::nullptr_t);
  ~MSTensor();
  /// \brief Obtains the name of the MSTensor.
  ///
  /// \return The name of the MSTensor.
  inline std::string Name() const;
  /// \brief Obtains the data type of the MSTensor.
  ///
  /// \return The data type of the MSTensor.
  enum DataType DataType() const;
  /// \brief Obtains the shape of the MSTensor.
  ///
  /// \return The shape of the MSTensor.
  const std::vector<int64_t> &Shape() const;
  /// \brief Obtains the number of elements of the MSTensor.
  ///
  /// \return The number of elements of the MSTensor.
  int64_t ElementNum() const;
  /// \brief Obtains a shared pointer to the copy of data of the MSTensor. The data can be read on host.
  ///
  /// \return A shared pointer to the copy of data of the MSTensor.
  std::shared_ptr<const void> Data() const;
  /// \brief Obtains the pointer to the data of the MSTensor. If the MSTensor is a device tensor, the data cannot be
  /// accessed directly on host.
  ///
  /// \return A pointer to the data of the MSTensor.
  void *MutableData();
  /// \brief Obtains the length of the data of the MSTensor, in bytes.
  ///
  /// \return The length of the data of the MSTensor, in bytes.
  size_t DataSize() const;
  /// \brief Gets the boolean value that indicates whether the memory of MSTensor is on device.
  ///
  /// \return The boolean value that indicates whether the memory of MSTensor is on device.
  bool IsDevice() const;
  /// \brief Gets a deep copy of the MSTensor, must be used in pair with DestroyTensorPtr.
  ///
  /// \return A pointer points to a deep copy of the MSTensor.
  MSTensor *Clone() const;
  /// \brief Gets the boolean value that indicates whether the MSTensor is valid.
  ///
  /// \return The boolean value that indicates whether the MSTensor is valid.
  bool operator==(std::nullptr_t) const;
  /// \brief Gets the boolean value that indicates whether the MSTensor is valid.
  ///
  /// \return The boolean value that indicates whether the MSTensor is valid.
  bool operator!=(std::nullptr_t) const;
  bool operator==(const MSTensor &tensor) const;
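
A small sketch of the documented tensor lifecycle (editor's illustration with a hypothetical name, shape, and data; `CreateTensor` copies the buffer, so the result must be released with `DestroyTensorPtr`):

```cpp
#include <vector>
#include "include/api/types.h"

int main() {
  std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
  mindspore::MSTensor *tensor = mindspore::MSTensor::CreateTensor(
      "x", mindspore::DataType::kNumberTypeFloat32, {2, 2}, data.data(), data.size() * sizeof(float));
  if (tensor == nullptr) {
    return 1;
  }
  // For this tensor: ElementNum() == 4 and DataSize() == 16 bytes.
  mindspore::MSTensor::DestroyTensorPtr(tensor);
  return 0;
}
```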

View File

@@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/l2loss_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
template <typename T>
void L2LossCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
std::vector<size_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (const size_t &d : x_shape) {
tensor_size_ *= d;
}
}
template <typename T>
bool L2LossCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto result_addr = reinterpret_cast<T *>(outputs[0]->addr);
*result_addr = (T)0;
for (size_t i = 0; i < tensor_size_; i++) {
*result_addr += input_addr[i] * input_addr[i];
}
*result_addr = *result_addr / 2;
return true;
}
template <typename T>
void L2LossCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but L2LossCPUKernel needs 1 input.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but L2LossCPUKernel needs 1 output.";
}
}
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2_LOSS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2_LOSS_CPU_KERNEL_H_

#include <memory>
#include <unordered_map>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T>
class L2LossCPUKernel : public CPUKernel {
 public:
  L2LossCPUKernel() = default;
  ~L2LossCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  void CheckParam(const CNodePtr &kernel_node);
  size_t tensor_size_{1};
};

MS_REG_CPU_KERNEL_T(L2Loss, KernelAttr(), L2LossCPUKernel, float16);
MS_REG_CPU_KERNEL_T(L2Loss, KernelAttr(), L2LossCPUKernel, float);
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2_LOSS_CPU_KERNEL_H_

View File

@@ -16,15 +16,19 @@
#include "nnacl/common_func.h"
-int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3) {
+int Offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3) {
  return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3] + dim3;
}
-int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2) {
+int OffsetComm(const int *shape, const int dim0, const int dim1, const int dim2) {
  return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3];
}
-int offset4d(const int *shape, const int *dims) { return offset(shape, dims[0], dims[1], dims[2], dims[3]); }
+int Offset4d(const int *shape, const int *dims) { return Offset(shape, dims[0], dims[1], dims[2], dims[3]); }
+int Offset6d(const int *shape, const int *dims) {
+  return ((OffsetComm(shape, dims[0], dims[1], dims[2]) + dims[3]) * shape[4] + dims[4]) * shape[5];
+}
int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); }
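
For reference, the new `Offset6d` extends the row-major offset helpers to six dimensions; for a shape $S$ and index vector $d$ it computes

$$\mathrm{Offset6d}(S, d) = \Big(\big(((d_0 S_1 + d_1) S_2 + d_2) S_3 + d_3\big) S_4 + d_4\Big) S_5,$$

leaving the innermost index to be added by the caller, as the pad kernels below do via `paddings[10]` and the `memcpy` length.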

View File

@@ -36,9 +36,10 @@ void ReluFp32C8(float *data, float *dst, int ele_num);
void Relu6Fp32C8(float *data, float *dst, int ele_num);
#endif
#endif
-int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3);
-int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2);
-int offset4d(const int *shape, const int *dims);
+int Offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3);
+int OffsetComm(const int *shape, const int dim0, const int dim1, const int dim2);
+int Offset4d(const int *shape, const int *dims);
+int Offset6d(const int *shape, const int *dims);
static inline bool isAddOverflow(int32_t x, int32_t y) {
  int32_t sum = x + y;

View File

@@ -19,16 +19,22 @@
void PadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape,
             const int *paddings, const int tid, const int thread_num) {
-  int in[4], out[4];
+  int in[DEFAULT_PAD_NDIMS], out[DEFAULT_PAD_NDIMS];
  for (in[0] = 0; in[0] < input_shape[0]; in[0]++) {
    out[0] = in[0] + paddings[0];
    for (in[1] = tid; in[1] < input_shape[1]; in[1] += thread_num) {
      out[1] = in[1] + paddings[2];
      for (in[2] = 0; in[2] < input_shape[2]; in[2]++) {
        out[2] = in[2] + paddings[4];
-        float16_t *dst = output_data + offset(output_shape, out[0], out[1], out[2], paddings[6]);
-        const float16_t *src = input_data + offset(input_shape, in[0], in[1], in[2], 0);
-        memcpy(dst, src, input_shape[3] * sizeof(float16_t));
+        for (in[3] = 0; in[3] < input_shape[3]; in[3]++) {
+          out[3] = in[3] + paddings[6];
+          for (in[4] = 0; in[4] < input_shape[4]; in[4]++) {
+            out[4] = in[4] + paddings[8];
+            float16_t *dst = output_data + Offset6d(output_shape, out) + paddings[10];
+            const float16_t *src = input_data + Offset6d(input_shape, in);
+            memcpy(dst, src, input_shape[5] * sizeof(float16_t));
+          }
+        }
      }
    }
  }

View File

@@ -23,16 +23,22 @@ void Pad(const float *input_data, float *output_data, const int *input_shape, co
  if (thread_num == 0) {
    return;
  }
-  int in[4], out[4];
+  int in[DEFAULT_PAD_NDIMS], out[DEFAULT_PAD_NDIMS];
  for (in[0] = 0; in[0] < input_shape[0]; in[0]++) {
    out[0] = in[0] + paddings[0];
    for (in[1] = tid; in[1] < input_shape[1]; in[1] += thread_num) {
      out[1] = in[1] + paddings[2];
      for (in[2] = 0; in[2] < input_shape[2]; in[2]++) {
        out[2] = in[2] + paddings[4];
-        float *dst = output_data + offset(output_shape, out[0], out[1], out[2], paddings[6]);
-        const float *src = input_data + offset(input_shape, in[0], in[1], in[2], 0);
-        memcpy(dst, src, input_shape[3] * sizeof(float));
+        for (in[3] = 0; in[3] < input_shape[3]; in[3]++) {
+          out[3] = in[3] + paddings[6];
+          for (in[4] = 0; in[4] < input_shape[4]; in[4]++) {
+            out[4] = in[4] + paddings[8];
+            float *dst = output_data + Offset6d(output_shape, out) + paddings[10];
+            const float *src = input_data + Offset6d(input_shape, in);
+            memcpy(dst, src, input_shape[5] * sizeof(float));
+          }
+        }
      }
    }
  }
@@ -57,8 +63,7 @@ int TransOut2InputDimIndex(int out_dim_index, int left_pad, int in_dim, int offs
int GetInputFlattenIndex(int out_flatten_index, const int *input_shape, const PadParameter *pad_param) {
  int in_flatten_index = 0;
-  int i;
-  for (i = 0; i < COMM_SHAPE_SIZE; ++i) {
+  for (int i = 0; i < DEFAULT_PAD_NDIMS; ++i) {
    int left_pad = pad_param->paddings_[i * 2];
    NNACL_CHECK_ZERO_RETURN_ERR(pad_param->out_strides[i])
    int out_dim_index = out_flatten_index / pad_param->out_strides[i];

View File

@@ -510,8 +510,8 @@ int ResizeNearestNeighbor(const float *input_data, float *output_data, const int
      } else {
        input_x = (int)(floorf(actual_x));
      }
-      int in_offset = offset(input_shape, batch, input_y, input_x, 0);
-      int out_offset = offset(output_shape, batch, y, x, 0);
+      int in_offset = Offset(input_shape, batch, input_y, input_x, 0);
+      int out_offset = Offset(output_shape, batch, y, x, 0);
      memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(float));
    }
  }

View File

@@ -32,7 +32,7 @@ int PadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **ou
    return NNACL_INFER_INVALID;
  }
-  if (input->shape_size_ > 4) {
+  if (input->shape_size_ > DEFAULT_PAD_NDIMS) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  const TensorC *paddings = inputs[1];
@@ -48,7 +48,7 @@
    param->paddings_[i] = ((int *)paddings->data_)[i];
  }
-  int output_shape[MAX_SHAPE_SIZE] = {0};
+  int output_shape[DEFAULT_PAD_NDIMS] = {0};
  size_t output_shape_size = 0;
  for (size_t i = 0; i < input->shape_size_; i++) {
    int shape = input->shape_[i] + param->paddings_[2 * i] + param->paddings_[2 * i + 1];

View File

@@ -24,8 +24,8 @@ int PadConstant4D(const int8_t *in_data, int8_t *out_data, const int32_t *in_dim
  for (int n = 0; n < in_dims[0]; n++) {
    for (int h = tid; h < in_dims[1]; h += thread_num) {
      for (int w = 0; w < in_dims[2]; w++) {
-        const int8_t *in = in_data + offset(in_dims, n, h, w, 0);
-        int8_t *out = out_data + offset(out_dims, n + paddings[0], h + paddings[2], w + paddings[4], paddings[6]);
+        const int8_t *in = in_data + Offset(in_dims, n, h, w, 0);
+        int8_t *out = out_data + Offset(out_dims, n + paddings[0], h + paddings[2], w + paddings[4], paddings[6]);
        memcpy(out, in, copy_size * sizeof(int8_t));
      }
    }

View File

@@ -173,8 +173,8 @@ int ResizeNearestNeighborInt8Simple(const int8_t *input_data, int8_t *output_dat
    for (x = 0; x < output_shape[2]; x++) {
      int input_x = 0;
      ComputeNearestNeighborInt(x, in_w, new_width, align_corners, &input_x);
-      int in_offset = offset(input_shape, batch, input_y, input_x, 0);
-      int out_offset = offset(output_shape, batch, y, x, 0);
+      int in_offset = Offset(input_shape, batch, input_y, input_x, 0);
+      int out_offset = Offset(output_shape, batch, y, x, 0);
      memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(int8_t));
    }
  }
@@ -214,8 +214,8 @@ int ResizeNearestNeighborInt8(const int8_t *input_data, int8_t *output_data, con
    int input_x = 0;
    ComputeNearestNeighborInt(x, in_w, new_width, align_corners, &input_x);
    for (c = 0; c < output_shape[3]; c++) {
-      int in_offset = offset(input_shape, batch, input_y, input_x, c);
-      int out_offset = offset(output_shape, batch, y, x, c);
+      int in_offset = Offset(input_shape, batch, input_y, input_x, c);
+      int out_offset = Offset(output_shape, batch, y, x, c);
      int32_t out_value = MultiplyByQuantizedMultiplier(
        input_data[in_offset] - quant_in->zp_, multiplier->multiplier_,

View File

@@ -18,8 +18,8 @@
 #include "nnacl/op_base.h"
-#define MAX_PAD_SIZE 8
-#define DEFAULT_PAD_NDIMS 4
+#define MAX_PAD_SIZE 12
+#define DEFAULT_PAD_NDIMS 6
 typedef struct PadQuantArg {
   QuantArg *in_quant_args_;
@@ -30,13 +30,13 @@ typedef struct PadQuantArg {
 typedef struct PadParameter {
   // Primitive parameter
   OpParameter op_parameter_;
-  int paddings_[MAX_SHAPE_SIZE];
+  int paddings_[MAX_PAD_SIZE];
   int pad_mode_;
   float constant_value_;
   // shape correlative
   int padding_length;
   // other parameter
-  int in_strides[COMM_SHAPE_SIZE];
+  int in_strides[DEFAULT_PAD_NDIMS];
   int out_strides[DEFAULT_PAD_NDIMS];
   int mirror_offset_;
   PadQuantArg pad_quant_arg_;
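
Note: the two constants above stay coupled: every padded dimension needs a (before, after) pair, so MAX_PAD_SIZE must be twice DEFAULT_PAD_NDIMS. A compile-time guard of that invariant could look like this (a hypothetical sketch; nnacl itself does not ship this assert):

  // Hypothetical guard, mirroring the values introduced by this patch.
  #define MAX_PAD_SIZE 12
  #define DEFAULT_PAD_NDIMS 6
  static_assert(MAX_PAD_SIZE == 2 * DEFAULT_PAD_NDIMS,
                "each padded dimension needs a (before, after) pair");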

View File

@@ -197,7 +197,8 @@ const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
   MS_EXCEPTION_IF_NULL(context_ptr);
   bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
   auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
-  if (!workspace_size_list_.empty() || hccl_data_type_list_.empty() || (!is_task_sink && mode == kGraphMode)) {
+  if (!workspace_size_list_.empty() || hccl_data_type_list_.empty() || (!is_task_sink && mode == kGraphMode) ||
+      mode == kPynativeMode) {
     return workspace_size_list_;
   }
   workspace_size_list_.emplace_back(

View File

@@ -70,11 +70,11 @@ CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::v
   MS_EXCEPTION_IF_NULL(fusion_op);
   std::vector<std::string> input_names;
-  for (uint8_t i = 0; i < inputs_list.size(); i++) {
+  for (size_t i = 0; i < inputs_list.size(); i++) {
     (void)input_names.emplace_back("input" + std::to_string(i));
   }
   std::vector<std::string> output_names;
-  for (uint8_t i = 0; i < outputs_list.size(); i++) {
+  for (size_t i = 0; i < outputs_list.size(); i++) {
     (void)output_names.emplace_back("output" + std::to_string(i));
   }
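
Note: the index-type change above matters because a uint8_t counter wraps around at 256, so a fusion op with more inputs than that would never satisfy the loop bound; size_t cannot wrap against a container's size(). A small demonstration of the hazard (illustrative only):

  #include <cstdint>
  #include <iostream>
  #include <vector>

  int main() {
    std::vector<int> inputs(300);  // more elements than fit in uint8_t
    // With uint8_t the condition i < inputs.size() stays true forever once i
    // wraps from 255 back to 0; with size_t the loop terminates normally.
    std::size_t count = 0;
    for (std::size_t i = 0; i < inputs.size(); i++) {
      ++count;
    }
    std::cout << count << " iterations\n";  // 300
    return 0;
  }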

View File

@@ -27,7 +27,7 @@ namespace opt {
 namespace {
 // insert tensormove for some cnode even if not a Ref cnode
 const std::set<std::string> kNeedInsertTensorMoveOpSet = {kLambNextMVOpName, kLambNextMVWithDecayOpName,
-                                                          kLambUpdateWithLROpName};
+                                                          kLambUpdateWithLROpName, kGetNextOpName};
 bool IsParameterOrValueNode(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
@@ -86,9 +86,10 @@ bool InsertTensorMoveForHcclOp::NeedInsertTensorMove(const FuncGraphPtr &graph,
   if (kernel_query_->IsTbeRef(input)) {
     return true;
   }
+  auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input, 0, true);
+  auto real_node = kernel_with_index.first;
   // when input is some special cnodes
-  if (kNeedInsertTensorMoveOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertTensorMoveOpSet.end()) {
+  if (kNeedInsertTensorMoveOpSet.find(AnfAlgo::GetCNodeName(real_node)) != kNeedInsertTensorMoveOpSet.end()) {
     return true;
   }

View File

@@ -53,6 +53,15 @@ void SafeCheckFunction(const CNodePtr &cnode, const std::vector<int64_t> &reduce
   }
 }
+void DynamicAttrUpdate(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto primitive = AnfAlgo::GetCNodePrimitive(node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto axis_attr = primitive->GetAttr(kAttrAxis);
+  AnfAlgo::SetNodeAttr(kAttrAxes, axis_attr, node);
+  AnfAlgo::EraseNodeAttr(kAttrAxis, node);
+}
 void ConvertReduceAttrFraczAnd6HD(const CNodePtr &cnode) {
   auto axis = kernel::GetReduceAttrAxis(cnode);
   std::vector<int64_t> convert_axis;
@@ -95,9 +104,15 @@ const AnfNodePtr ChangeAxisOfReduceKernel::Process(const FuncGraphPtr &, const A
   }
   auto convert_map = kReduceConvertMap.find(AnfAlgo::GetInputFormat(node, 0));
   if (convert_map == kReduceConvertMap.end()) {
+    if (AnfAlgo::IsDynamicShape(node)) {
+      DynamicAttrUpdate(node);
+    }
     return nullptr;
   }
   convert_map->second(node->cast<CNodePtr>());
+  if (AnfAlgo::IsDynamicShape(node)) {
+    DynamicAttrUpdate(node);
+  }
   return nullptr;
 }
 } // namespace opt
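
Note: DynamicAttrUpdate above simply moves the reduce-axis value from the 'axis' attribute key to the 'axes' key that dynamic-shape kernels read. The same move on a plain attribute map looks like this (generic sketch, not the AnfAlgo API):

  #include <cstdint>
  #include <iostream>
  #include <map>
  #include <string>
  #include <vector>

  int main() {
    // Stand-in for a node's attribute dictionary.
    std::map<std::string, std::vector<int64_t>> attrs{{"axis", {0, 2}}};
    // Re-key "axis" as "axes", then drop the old entry.
    attrs["axes"] = attrs.at("axis");
    attrs.erase("axis");
    std::cout << "axes has " << attrs.at("axes").size() << " entries\n";
    return 0;
  }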

View File

@@ -122,7 +122,7 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf
     return nullptr;
   }
   auto bn_infer = CreateBNInfer(graph, batchnorm, node);
-  TransferDepend(batchnorm, graph, bn_infer);
+  TransferDependOrUpdateState(batchnorm, graph, bn_infer);
   return bn_infer;
 }
 } // namespace opt

View File

@@ -125,7 +125,7 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c
     return nullptr;
   }
   auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node);
-  TransferDepend(batchnorm_grad, graph, bn_infer_grad);
+  TransferDependOrUpdateState(batchnorm_grad, graph, bn_infer_grad);
   return bn_infer_grad;
 }
 } // namespace opt

View File

@@ -915,21 +915,34 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
   return new_value_node;
 }
-void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
+void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
   MS_EXCEPTION_IF_NULL(old_node);
   MS_EXCEPTION_IF_NULL(graph);
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
   // Find BatchNorm's output which is a Depend or UpdateState.
-  for (const auto &node_index : manager->node_users()[old_node]) {
+  auto node_users = manager->node_users()[old_node];
+  for (const auto &node_index : node_users) {
     AnfNodePtr output = node_index.first;
-    size_t index = IntToSize(node_index.second);
     MS_EXCEPTION_IF_NULL(output);
     if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
         AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
-      auto depend = output->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(depend);
-      depend->set_input(index, new_node);
+      auto output_cnode = output->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(output_cnode);
+      auto inputs = output_cnode->inputs();
+      std::vector<AnfNodePtr> new_inputs{output_cnode->input(0)};
+      for (size_t i = 1; i < inputs.size(); i++) {
+        auto input = inputs[i];
+        if (input == old_node) {
+          new_inputs.emplace_back(new_node);
+        } else {
+          new_inputs.emplace_back(input);
+        }
+      }
+      auto new_output = graph->NewCNode(new_inputs);
+      new_output->set_abstract(output->abstract());
+      new_output->set_scope(output->scope());
+      manager->Replace(output, new_output);
     }
   }
 }
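
Note: the rewrite above replaces every occurrence of the old node in the user's input list instead of patching a single recorded index, which is safer when the same node feeds a Depend/UpdateState more than once. The core substitution is the standard replace-by-value pattern (generic sketch with ints standing in for node pointers):

  #include <algorithm>
  #include <iostream>
  #include <vector>

  int main() {
    // Replace every use of old_node with new_node in an input list.
    int old_node = 7, new_node = 9;
    std::vector<int> inputs{1, 7, 3, 7};
    std::replace(inputs.begin(), inputs.end(), old_node, new_node);
    for (int v : inputs) std::cout << v << ' ';  // 1 9 3 9
    std::cout << '\n';
    return 0;
  }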

View File

@@ -213,8 +213,8 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
 // Create a new value node of func graph,not kernel graph
 ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
-// Transfer depend to the new node
-void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
+// Transfer depend or updatestate to the new node
+void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
 AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);

View File

@@ -381,7 +381,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
     MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
   }
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
-      AnfAlgo::IsParameterWeight(input_param)) {
+      AnfAlgo::IsParameterWeight(input_param) || kernel_graph->IsUpdatedParameter(input_param)) {
     tensor->set_device_address(device_address);
   }
   if (kernel_graph->IsUpdatedParameter(input_param)) {
@@ -523,30 +523,14 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) {
   InitRuntimeResource();
   // multiple graph handle
   if (graph_id == final_graph_id_) {
-    if (!graph->executable()) {
-      return;
-    }
-    SetFinalGraphSummaryFlag(graph);
-    // OptChildGraphs
-    auto graph_order = GetGraphOrder(final_graph_id_);
-    auto &graph_type = GetGraphOrderType(final_graph_id_);
-    for (size_t i = 0; i < graph_order.size(); i++) {
-      if (!(graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START)) {
-        auto child_graph = GetGraph(graph_order[i]);
-        CompileChildGraph(child_graph);
-      }
-    }
-    SetSummaryNodes(graph.get());
-    // merge child graph
-    MergeGraphExecOrder();
-  } else {
-    auto single_graph = GetGraph(graph_id);
-    MS_EXCEPTION_IF_NULL(single_graph);
-    CompileChildGraph(single_graph);
-    // set the distinction label of single graph
-    single_graph->set_stream_distinction_label(graph_id);
-    single_graph->UpdateExecuteKernelStreamLabel();
+    MS_LOG(EXCEPTION) << "Unexpected graph id:" << graph_id << ", final_graph_id_:" << final_graph_id_;
   }
+  auto single_graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(single_graph);
+  CompileChildGraph(single_graph);
+  // set the distinction label of single graph
+  single_graph->set_stream_distinction_label(graph_id);
+  single_graph->UpdateExecuteKernelStreamLabel();
   // adjust execution order because merge child graph and other special operations
   AdjustKernel(graph);
 #if ENABLE_CPU && ENABLE_D
@@ -863,9 +847,10 @@ void AscendSession::InitRuntimeResource() {
   if (!runtime_instance->Init()) {
     MS_LOG(EXCEPTION) << "Kernel runtime init error.";
   }
-  auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
   auto env_rank_id = common::GetEnv("RANK_ID");
-  if (!(env_table_file.empty() || env_rank_id.empty())) {
+  if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
     // get actual rank id if it's distribution training case.
     rank_id_ = GetRankId();
   }

View File

@@ -114,12 +114,12 @@ void GPUSession::Init(uint32_t device_id) {
   MS_EXCEPTION_IF_NULL(ms_context);
   ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
   if (collective_inited) {
-    rank_id_ = GetRankId();
     if (collective_handle_ != nullptr) {
       auto init_nccl_comm_funcptr =
         reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
       MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
       (*init_nccl_comm_funcptr)();
+      rank_id_ = GetRankId();
     }
   }

View File

@@ -470,7 +470,7 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
   }
 }
-void KernelGraph::ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const {
+void KernelGraph::ResetAssignInputFeatureMapFlag(const CNodePtr &cnode) const {
   if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) {
     MS_LOG(EXCEPTION) << "Only supported to change the node [Assign , AssignSub, AssignAdd] node's input feature map "
                          "flag but got the node :"
@@ -493,7 +493,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
   node->set_kernel_info(kernel_info);
   if (node->isa<CNode>()) {
     if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) {
-      ResetAssignInputFeaatureMapFlag(node->cast<CNodePtr>());
+      ResetAssignInputFeatureMapFlag(node->cast<CNodePtr>());
     }
 #if defined(__APPLE__)
     std::vector<int> feature_map_input_indexs;
@@ -1362,7 +1362,9 @@ void KernelGraph::SetOptimizerFlag() {
       continue;
     }
     auto param = real_node->cast<ParameterPtr>();
-    if (AnfAlgo::IsParameterWeight(param)) {
+    auto abstract = param->abstract();
+    MS_EXCEPTION_IF_NULL(abstract);
+    if (abstract->isa<abstract::AbstractRef>()) {
       has_optimizer_ = true;
       (void)updated_parameters_.insert(param);
     }

View File

@@ -111,7 +111,7 @@ class KernelGraph : public FuncGraph {
   CNodePtr NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, const CNodePtr &ori_cnode = nullptr);
   void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
   CNodePtr NewCNode(const CNodePtr &cnode);
-  void ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const;
+  void ResetAssignInputFeatureMapFlag(const CNodePtr &cnode) const;
   ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr);
   ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract);
   ValueNodePtr NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value);

View File

@@ -446,6 +446,45 @@ void UpdateGraphAquireGilAttr(const NotNull<KernelGraphPtr> &root_graph) {
   }
   return;
 }
+void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
+    constexpr auto kReturnInputIdx = 1;
+    auto return_node = node->cast<CNodePtr>();
+    graph->set_return(return_node);
+    auto graph_output = return_node->input(kReturnInputIdx);
+    MS_EXCEPTION_IF_NULL(graph_output);
+    // If return's input is value node, then the graph has no kernel, and the pass 'trans tuple to make_tuple' cannot
+    // match this pattern because that pass begin with output node but return node. So we add transform value tuple
+    // to make_tuple here.
+    if (AnfAlgo::IsTupleOutput(graph_output) && graph_output->isa<ValueNode>()) {
+      return_node->set_input(kReturnInputIdx, graph->TransTupleToMakeTuple(graph_output));
+    }
+  }
+}
+bool NoPartialInPartialGraph(const AnfNodePtr &partial_node) {
+  MS_EXCEPTION_IF_NULL(partial_node);
+  auto partial_cnode = partial_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(partial_cnode);
+  auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
+  MS_EXCEPTION_IF_NULL(partial_graph);
+  auto graph_nodes = TopoSort(partial_graph->get_return());
+  for (auto &node : graph_nodes) {
+    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) ||
+        AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) {
+      return false;
+    }
+    if (node->isa<CNode>() && IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(kAnfPrimitiveIndex))) {
+      return false;
+    }
+  }
+  return true;
+}
 } // namespace
 GraphId SessionBasic::graph_sum_ = 0;
@@ -1463,9 +1502,7 @@ bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph
   new_cnode->set_fullname_with_scope(fullname);
   new_cnode->set_scope(cnode->scope());
   graph->FrontBackendlMapAdd(node, new_cnode);
-  if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
-    graph->set_return(new_cnode);
-  }
+  SetReturnNode(new_cnode, graph);
   return true;
 }
@@ -1957,7 +1994,8 @@ void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, cons
   if (internal_output) {
     auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
     for (auto &user : users) {
-      if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice) {
+      if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
+          NoPartialInPartialGraph(user)) {
         auto partial_target = AddPartialParametersMap(front_func_graph_manager, user);
         if (partial_target != kNoTarget && partial_target != kernel_target) {
           unique_target = false;
@@ -2659,6 +2697,7 @@ uint32_t GetRankId() {
     world_group = kNcclWorldGroup;
   } else {
    MS_LOG(ERROR) << "Invalid backend: " << backend;
+    return rank_id;
   }
   if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
     MS_LOG(INFO) << "Failed to get rank id.";

View File

@@ -596,7 +596,8 @@ void DumpIR(const std::string &filename, const FuncGraphPtr &graph, bool dump_fu
   std::ofstream fout(realpath.value());
   std::ostringstream buffer;
   if (!fout.is_open()) {
-    MS_LOG(ERROR) << "Open dump file '" << realpath.value() << "' failed!";
+    MS_LOG(ERROR) << "Open dump file '" << realpath.value() << "' failed!"
+                  << " Errno:" << errno << " ErrInfo:" << strerror(errno);
     return;
   }
@@ -638,7 +639,8 @@ void DumpIRForRDR(const std::string &filename, const FuncGraphPtr &graph, bool d
   std::ofstream fout(realpath.value());
   std::ostringstream buffer;
   if (!fout.is_open()) {
-    MS_LOG(ERROR) << "Open dump file '" << realpath.value() << "' failed!";
+    MS_LOG(ERROR) << "Open dump file '" << realpath.value() << "' failed!"
+                  << " Errno:" << errno << " ErrInfo:" << strerror(errno);
     return;
   }
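
Note: the file-open hunks in this comparison all follow one pattern: when a stream fails to open, append errno and strerror(errno) to the log so the OS-level cause (permissions, missing directory, fd limits) is visible. Whether an ofstream failure leaves errno set is technically implementation-defined, but on POSIX it normally reflects the underlying open(2) error. A standalone version of the pattern (sketch; the path is hypothetical):

  #include <cerrno>
  #include <cstring>
  #include <fstream>
  #include <iostream>

  int main() {
    // Hypothetical path; opening a file under a nonexistent directory fails.
    std::ofstream fout("/nonexistent_dir/dump.ir");
    if (!fout.is_open()) {
      std::cerr << "Open dump file failed!"
                << " Errno:" << errno << " ErrInfo:" << std::strerror(errno) << '\n';
      return 1;
    }
    return 0;
  }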

View File

@@ -606,7 +606,8 @@ void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPt
   std::ofstream ofs(filename);
   if (!ofs.is_open()) {
-    MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
+    MS_LOG(ERROR) << "Open file '" << filename << "' failed!"
+                  << " Errno:" << errno << " ErrInfo:" << strerror(errno);
     return;
   }

View File

@@ -274,7 +274,8 @@ bool Common::SaveStringToFile(const std::string filename, const std::string stri
   ofs.open(real_path.value());
   if (!ofs.is_open()) {
-    MS_LOG(ERROR) << "Open dump file '" << real_path.value() << "' failed!";
+    MS_LOG(ERROR) << "Open dump file '" << real_path.value() << "' failed!"
+                  << " Errno:" << errno << " ErrInfo:" << strerror(errno);
     return false;
   }
   ofs << string_info << std::endl;

View File

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -92,7 +92,8 @@ void DumpJsonParser::Parse() {
   std::ifstream json_file(dump_config_file.value());
   if (!json_file.is_open()) {
-    MS_LOG(EXCEPTION) << "Dump file:" << dump_config_file.value() << " open failed.";
+    MS_LOG(EXCEPTION) << "Dump file:" << dump_config_file.value() << " open failed."
+                      << " Errno:" << errno << " ErrInfo:" << strerror(errno);
   }
   nlohmann::json j;
@@ -176,7 +177,7 @@ void DumpJsonParser::CopyMSCfgJsonToDir(uint32_t rank_id) {
     auto context = MsContext::GetInstance();
     MS_EXCEPTION_IF_NULL(context);
     ms_info["device_target"] = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-    ms_info["ms_version"] = "1.3.0";
+    ms_info["ms_version"] = "1.4.0";
     const std::string file_path = realpath.value();
     ChangeFileMode(file_path, S_IWUSR);
     std::ofstream json_create(file_path);
@@ -204,7 +205,8 @@ bool DumpJsonParser::DumpToFile(const std::string &filename, const void *data, s
   ChangeFileMode(file_path, S_IWUSR);
   std::ofstream fd(file_path, std::ios::out | std::ios::trunc | std::ios::binary);
   if (!fd.is_open()) {
-    MS_LOG(ERROR) << "Open file " << file_path << " failed.";
+    MS_LOG(ERROR) << "Open file " << file_path << " failed."
+                  << " Errno:" << errno << " ErrInfo:" << strerror(errno);
     return false;
   }
   std::string npy_header = GenerateNpyHeader(shape, type);
@@ -349,7 +351,7 @@ void DumpJsonParser::ParseIteration(const nlohmann::json &content) {
      MS_LOG(EXCEPTION) << "iteration only supports digits, {'-', '|'}, or just \"all\" but got: " << iteration_;
    }
  } else if (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kCPUDevice) {
-    MS_LOG(WARNING) << "Dump not enabled. ";
+    MS_LOG(WARNING) << "Dump is not enabled. ";
  } else {
    MS_LOG(EXCEPTION) << "Dump Json Parse Failed. Async or E2E should be enabled. ";
  }
@@ -484,14 +486,14 @@ void DumpJsonParser::JudgeDumpEnabled() {
  }
  if (!async_dump_enabled_ && !e2e_dump_enabled_) {
-    MS_LOG(WARNING) << "Dump json parse failed. Dump not enabled";
+    MS_LOG(WARNING) << "Dump json parse failed. Dump is not enabled";
  }
  if (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kCPUDevice) {
    auto device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
    if (support_devices_.find(device_id) == support_devices_.end()) {
      async_dump_enabled_ = false;
      e2e_dump_enabled_ = false;
-      MS_LOG(WARNING) << "Dump not enabled. device_id:" << device_id << " not support";
+      MS_LOG(WARNING) << "Dump is not enabled. device_id:" << device_id << " not support";
    }
  }
  JsonConfigToString();
@@ -532,9 +534,10 @@ std::string DumpJsonParser::GetOpOverflowBinPath(uint32_t graph_id) const {
   bin_path.append("rank_");
   uint32_t rank_id = 0;
-  auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
   auto env_rank_id = common::GetEnv("RANK_ID");
-  if (!(env_table_file.empty() || env_rank_id.empty())) {
+  if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
     // get actual rank id if it's distribution training case.
     if (!CommManager::GetInstance().GetRankID(kHcclWorldGroup, &rank_id)) {
       MS_LOG(INFO) << "Failed to get rank id.";

View File

@@ -26,6 +26,7 @@
 #include <unordered_set>
 #include <utility>
 #include "pybind11/embed.h"
+#include "pybind11/stl.h"
 #ifdef ONLINE_DBG_MODE
 #include "debug/common.h"
 #include "debug/debugger/debugger.h"
@@ -80,45 +81,57 @@ void DebugServices::RemoveWatchpoint(unsigned int id) {
 std::unique_ptr<ITensorSummary> GetSummaryPtr(const std::shared_ptr<TensorData> &tensor,
                                               void *const previous_tensor_ptr, uint32_t num_elements,
-                                              int tensor_dtype) {
+                                              uint32_t prev_num_elements, int tensor_dtype) {
   switch (tensor_dtype) {
     case DbgDataType::DT_UINT8: {
-      return std::make_unique<TensorSummary<uint8_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
+      return std::make_unique<TensorSummary<uint8_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements,
+                                                      prev_num_elements);
     }
     case DbgDataType::DT_INT8: {
-      return std::make_unique<TensorSummary<int8_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
+      return std::make_unique<TensorSummary<int8_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements,
+                                                     prev_num_elements);
    }
    case DbgDataType::DT_UINT16: {
-      return std::make_unique<TensorSummary<uint16_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
+      return std::make_unique<TensorSummary<uint16_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements,
+                                                       prev_num_elements);
    }
    case DbgDataType::DT_INT16: {
-      return std::make_unique<TensorSummary<int16_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
+      return std::make_unique<TensorSummary<int16_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements,
+                                                      prev_num_elements);
    }
    case DbgDataType::DT_UINT32: {
-      return std::make_unique<TensorSummary<uint32_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
+      return std::make_unique<TensorSummary<uint32_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements,
+                                                       prev_num_elements);
    }
    case DbgDataType::DT_INT32:
    case DbgDataType::DT_BASE_INT: {
-      return std::make_unique<TensorSummary<int32_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
+      return std::make_unique<TensorSummary<int32_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements,
+                                                      prev_num_elements);
    }
    case DbgDataType::DT_UINT64: {
-      return std::make_unique<TensorSummary<uint64_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
+      return std::make_unique<TensorSummary<uint64_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements,
+                                                       prev_num_elements);
    }
    case DbgDataType::DT_INT64: {
-      return std::make_unique<TensorSummary<int64_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
+      return std::make_unique<TensorSummary<int64_t>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements,
+                                                      prev_num_elements);
    }
    case DbgDataType::DT_FLOAT16: {
-      return std::make_unique<TensorSummary<float16>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
+      return std::make_unique<TensorSummary<float16>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements,
+                                                      prev_num_elements);
    }
    case DbgDataType::DT_FLOAT32:
    case DbgDataType::DT_BASE_FLOAT: {
-      return std::make_unique<TensorSummary<float>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
+      return std::make_unique<TensorSummary<float>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements,
+                                                    prev_num_elements);
    }
    case DbgDataType::DT_FLOAT64: {
-      return std::make_unique<TensorSummary<double>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
+      return std::make_unique<TensorSummary<double>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements,
+                                                     prev_num_elements);
    }
    case DbgDataType::DT_BOOL: {
-      return std::make_unique<TensorSummary<bool>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements);
+      return std::make_unique<TensorSummary<bool>>(tensor->GetDataPtr(), previous_tensor_ptr, num_elements,
+                                                   prev_num_elements);
    }
    default:
      MS_LOG(INFO) << "Unsupported tensor type";
@@ -128,7 +141,8 @@ std::unique_ptr<ITensorSummary> GetSummaryPtr(const std::shared_ptr<TensorData>
 }
 #ifdef OFFLINE_DBG_MODE
-void *DebugServices::GetPrevTensor(const std::shared_ptr<TensorData> &tensor, bool previous_iter_tensor_needed) {
+void *DebugServices::GetPrevTensor(const std::shared_ptr<TensorData> &tensor, bool previous_iter_tensor_needed,
+                                   uint32_t *prev_num_elements) {
   void *previous_tensor_ptr = nullptr;
   std::shared_ptr<TensorData> tensor_prev;
   if (previous_iter_tensor_needed && tensor->GetIteration() >= 1) {
@@ -151,6 +165,7 @@ void *DebugServices::GetPrevTensor(const std::shared_ptr<TensorData> &tensor, bo
       tensor_prev.reset();
     } else {
       previous_tensor_ptr = tensor_prev->GetDataPtr();
+      *prev_num_elements = tensor_prev->GetNumElements();
     }
   }
   return previous_tensor_ptr;
@@ -158,9 +173,17 @@ void *DebugServices::GetPrevTensor(const std::shared_ptr<TensorData> &tensor, bo
 #endif
 void DebugServices::AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, bool recheck,
-                                          const std::string &tensor_name, const std::string &tensor_name_no_slot,
-                                          bool *previous_iter_tensor_needed, std::string *const qualified_tensor_name,
+                                          const std::shared_ptr<TensorData> &tensor, bool *previous_iter_tensor_needed,
+                                          std::string *const qualified_tensor_name,
                                           std::vector<watchpoint_t> *const watchpoints_to_check) {
+  if (tensor == nullptr) {
+    MS_LOG(DEBUG) << "tensor is nullptr.";
+    return;
+  }
+  const auto tensor_name = tensor->GetName();
+  const auto tensor_name_no_slot = tensor_name.substr(0, tensor_name.find_first_of(':'));
+  const auto tensor_device_id = tensor->GetDeviceId();
+  const auto tensor_root_graph_id = tensor->GetRootGraphId();
   for (auto w_table_item : watchpoint_table) {
     auto wp = std::get<1>(w_table_item);
     // check ONLY init conditions on initial suspended state.
@@ -178,7 +201,7 @@ void DebugServices::AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end,
       wp_lock_.unlock();
       if (wp_cache_hit) continue;
     }
-    std::string found = wp.FindQualifiedTensorName(tensor_name_no_slot);
+    std::string found = wp.FindQualifiedTensorName(tensor_name_no_slot, tensor_device_id, tensor_root_graph_id);
     if (!found.empty()) {
       *qualified_tensor_name = found;
       watchpoints_to_check->push_back(w_table_item.second);
@@ -238,21 +261,26 @@ void DebugServices::CheckWatchpointsForTensor(
     bool previous_iter_tensor_needed = false;
     // Add do nothing line in case offline debug is off, prevent unused var warning
     (void)previous_iter_tensor_needed;
-    AddWatchPointsToCheck(init_dbg_suspend, step_end, recheck, tensor_name, tensor_name_no_slot,
-                          &previous_iter_tensor_needed, &qualified_tensor_name, &watchpoints_to_check);
+    AddWatchPointsToCheck(init_dbg_suspend, step_end, recheck, tensor, &previous_iter_tensor_needed,
+                          &qualified_tensor_name, &watchpoints_to_check);
     // no wp set on current tensor
     if (watchpoints_to_check.empty()) continue;
     uint32_t num_elements = tensor->GetNumElements();
+    uint32_t prev_num_elements = 0;
+    void *previous_tensor_ptr = nullptr;
 #ifdef OFFLINE_DBG_MODE
-    void *previous_tensor_ptr = GetPrevTensor(tensor, previous_iter_tensor_needed);
+    previous_tensor_ptr = GetPrevTensor(tensor, previous_iter_tensor_needed, &prev_num_elements);
 #else
-    void *previous_tensor_ptr =
-      tensor_loader_->GetPrevTensor(tensor_name) ? tensor_loader_->GetPrevTensor(tensor_name)->GetDataPtr() : nullptr;
+    std::shared_ptr<TensorData> prev_tensor_data = tensor_loader_->GetPrevTensor(tensor_name);
+    if (prev_tensor_data) {
+      previous_tensor_ptr = prev_tensor_data->GetDataPtr();
+      prev_num_elements = prev_tensor_data->GetNumElements();
+    }
 #endif
     std::unique_ptr<ITensorSummary> base_summary_ptr;
     if (!(watchpoints_to_check.size() == 1 && watchpoints_to_check[0].condition.type == IS_OVERFLOW)) {
-      base_summary_ptr = GetSummaryPtr(tensor, previous_tensor_ptr, num_elements, tensor_dtype);
+      base_summary_ptr = GetSummaryPtr(tensor, previous_tensor_ptr, num_elements, prev_num_elements, tensor_dtype);
       if (base_summary_ptr != nullptr) {
        base_summary_ptr->SummarizeTensor(watchpoints_to_check);
      }
@@ -399,7 +427,8 @@ void DebugServices::ReadTensorFromNpy(const std::string &file_name, std::string
   MS_LOG(INFO) << "Reading in file: " << file_path;
   infile.open(file_path.c_str(), std::ios::ate | std::ios::binary | std::ios::in);
   if (!infile.is_open()) {
-    MS_LOG(ERROR) << "Failed to open file (In ReadTensorFromNpy) " << file_path;
+    MS_LOG(ERROR) << "Failed to open file (In ReadTensorFromNpy) " << file_path << " Errno:" << errno
+                  << " ErrInfo:" << strerror(errno);
     return;
   }
   uint64_t file_size = infile.tellg();
@@ -439,6 +468,7 @@ void DebugServices::ConvertToHostFormat(const std::map<std::string, std::vector<
   std::string file_format = "npy";
   for (auto const &d : dir_to_files_map) {
     std::vector<std::string> files_to_convert_in_dir;
+    std::vector<std::string> files_after_convert_in_dir;
     std::string dump_key = d.first;
     for (auto const &file_name : d.second) {
       bool already_converted = false;
@@ -457,26 +487,19 @@ void DebugServices::ConvertToHostFormat(const std::map<std::string, std::vector<
       }
       if (!already_converted) {
         files_to_convert_in_dir.push_back(dump_key + "/" + file_name);
+        files_after_convert_in_dir.push_back(dump_key + "/" + file_name_without_scope);
       }
     }
-    std::ostringstream input_file_o;
-    const char *const delim = " ";
-    std::copy(files_to_convert_in_dir.begin(), files_to_convert_in_dir.end(),
-              std::ostream_iterator<std::string>(input_file_o, delim));
-    std::string input_files = input_file_o.str();
-    MS_LOG(INFO) << "Ops to convert: " << input_files;
-    if (input_files != "") {
+    MS_LOG(INFO) << "Number of files to convert: " << files_to_convert_in_dir.size();
+    if (!files_to_convert_in_dir.empty()) {
       // Look for the installation path to the conver_async package. If not found, throw exception and terminate the
       // later task.
       try {
         auto pkg = pybind11::module::import("mindspore.offline_debug.convert_async");
-        std::string convert_pkg_path = pkg.attr("__file__").cast<std::string>();
-        MS_LOG(INFO) << "The file for converting async dump data is in " << convert_pkg_path;
-        std::string convert_command = "python " + convert_pkg_path + " -out " + dump_key + " -t " + file_format +
-                                      " -d " + dump_key + " -f NCHW -l " + input_files;
-        (void)(system(convert_command.c_str()) + 1);
+        auto convert_obj = pkg.attr("AsyncDumpConverter")(pybind11::cast(files_to_convert_in_dir), dump_key);
+        (void)convert_obj.attr("convert_files")();
       } catch (pybind11::error_already_set &e) {
-        MS_LOG(EXCEPTION) << "Can't find package mindspore.offline_debug.convert_async";
+        MS_LOG(EXCEPTION) << "Failed to convert async dump data: " << e.what();
       }
       DIR *d_handle = opendir(dump_key.c_str());
@@ -485,7 +508,7 @@ void DebugServices::ConvertToHostFormat(const std::map<std::string, std::vector<
       while ((dir = readdir(d_handle)) != NULL) {
         if (dir->d_type == DT_REG) {
           std::string candidate = dir->d_name;
-          for (const std::string &file_to_find : files_to_convert_in_dir) {
+          for (const std::string &file_to_find : files_after_convert_in_dir) {
            std::string file_n = file_to_find.substr(file_to_find.find_last_of("\\/") + 1);
            if (candidate.find(file_n) != std::string::npos && candidate.rfind(file_format) != std::string::npos) {
              // we found a converted file for this op
@@ -736,7 +759,7 @@ void DebugServices::ReadDumpedTensor(std::vector<std::string> backend_name, std:
     std::string slot_string_to_check;
     std::string prefix_dump_file_name;
     SetPrefixToCheck(&prefix_dump_file_name, &slot_string_to_check, &dump_style_kernel_name, slot[i], is_output[i]);
-    std::string prefix_dump_to_check = dump_style_kernel_name;
+    std::string prefix_dump_to_check = dump_style_kernel_name + '.';
     GetNodeNameWithoutScope(&prefix_dump_to_check);
     std::string specific_dump_dir = dump_dir + "/rank_" + std::to_string(device_id[i]) + "/" + net_name + "/" +
@@ -1069,12 +1092,12 @@ bool DebugServices::CheckOpOverflow(std::string node_name_to_find, unsigned int
 #ifdef ONLINE_DBG_MODE
   auto debugger = Debugger::GetInstance();
   overflow_bin_path = DumpJsonParser::GetInstance().GetOpOverflowBinPath(debugger->GetGraphPtr()->graph_id());
-  auto realpath = Common::GetRealPath(overflow_bin_path);
-  if (!realpath.has_value()) {
-    MS_LOG(ERROR) << "Get real path failed for overflow_bin_path.";
+  std::string check_overflow_bin_path = RealPath(overflow_bin_path);
+  if (check_overflow_bin_path.empty()) {
+    MS_LOG(WARNING) << "Get real path failed for overflow_bin_path.";
     return false;
   }
-  overflow_bin_path = realpath.value();
+  overflow_bin_path = check_overflow_bin_path;
 #else
   overflow_bin_path = dump_dir + "/rank_" + std::to_string(device_id) + "/" + net_name + "/" +
                      std::to_string(root_graph_id) + "/" + IterationString(iteration) + "/";
@@ -1108,8 +1131,8 @@ bool DebugServices::CheckOpOverflow(std::string node_name_to_find, unsigned int
     std::ifstream infile;
     infile.open(file_path.c_str(), std::ios::ate | std::ios::binary | std::ios::in);
     if (!infile.is_open()) {
-      MS_LOG(ERROR) << "Failed to open overflow bin file " << file_name;
-      MS_LOG(ERROR) << "Error: " << strerror(errno);
+      MS_LOG(ERROR) << "Failed to open overflow bin file " << file_name << " Errno:" << errno
+                    << " ErrInfo:" << strerror(errno);
       continue;
     }
@@ -1256,7 +1279,7 @@ std::string DebugServices::RealPath(const std::string &input_path) {
     MS_LOG(EXCEPTION) << "The length of file name : " << file_name.length() << " exceeds limit: " << NAME_MAX;
   }
   if (realpath(prefix_path.c_str(), real_path) == nullptr) {
-    MS_LOG(ERROR) << "The dir " << prefix_path << " does not exist.";
+    MS_LOG(WARNING) << "The dir " << prefix_path << " does not exist.";
     return "";
   }
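
Note: the ConvertToHostFormat change above replaces a shell call to the convert_async script with an in-process pybind11 call that constructs AsyncDumpConverter and invokes convert_files, so failures surface as Python exceptions instead of an ignored system() return code. A minimal sketch of that embedding pattern (the module and class names below are placeholders, not the MindSpore package):

  #include <pybind11/embed.h>
  #include <pybind11/stl.h>  // converts std::vector to a Python list
  #include <iostream>
  #include <string>
  #include <vector>

  namespace py = pybind11;

  int main() {
    py::scoped_interpreter guard{};  // start an embedded interpreter
    std::vector<std::string> files{"a.bin", "b.bin"};
    try {
      // Placeholder module/class; the real code imports
      // mindspore.offline_debug.convert_async and calls AsyncDumpConverter.
      auto mod = py::module::import("some_pkg.converter");
      auto obj = mod.attr("Converter")(py::cast(files), "/tmp/out");
      (void)obj.attr("convert_files")();
    } catch (py::error_already_set &e) {
      std::cerr << "conversion failed: " << e.what() << '\n';
      return 1;
    }
    return 0;
  }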

View File

@@ -125,16 +125,30 @@ class DebugServices {
     std::vector<parameter_t> parameter_list;
     size_t location = 0;
-    std::string FindQualifiedTensorName(const std::string &tensor_name) const {
+    std::string FindQualifiedTensorName(const std::string &tensor_name, unsigned const int &tensor_device_id,
+                                        unsigned const int &tensor_root_graph_id) const {
       std::string node_name = tensor_name.substr(0, tensor_name.find_first_of(':'));
+      int indx = 0;
       for (auto check_node : check_node_list) {
         std::string w_name = std::get<0>(check_node);
         bool w_type = std::get<1>(check_node);
         auto found = w_name.find_last_of('/');
-        if (found != std::string::npos && w_name.substr(found + 1) == tensor_name) return w_name;
-        if ((w_type && (tensor_name.find(w_name) == location || w_name == "*")) || (!w_type && node_name == w_name)) {
+        bool check_tensor_name = found != std::string::npos && w_name.substr(found + 1) == tensor_name;
+        bool check_node_name = (w_type && (node_name == w_name || w_name == "*")) || (!w_type && node_name == w_name);
+        if (check_tensor_name || check_node_name) {
+          // online debugger only support single card
+          if (check_node_device_list.empty()) {
            return w_name;
          }
+          auto device_vec = std::get<1>(check_node_device_list[indx]);
+          auto root_graph_vec = std::get<1>(check_node_graph_list[indx]);
+          auto iter1 = std::find(device_vec.begin(), device_vec.end(), tensor_device_id);
+          auto iter2 = std::find(root_graph_vec.begin(), root_graph_vec.end(), tensor_root_graph_id);
+          if (iter1 != device_vec.end() && iter2 != root_graph_vec.end()) {
+            return w_name;
+          }
+        }
+        indx++;
       }
       return {};
     }
@@ -214,8 +228,8 @@ class DebugServices {
                         const bool step_end, const bool recheck, std::vector<unsigned int> *device_id = nullptr,
                         std::vector<unsigned int> *root_graph_id = nullptr);
-  void AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, bool recheck, const std::string &tensor_name,
-                             const std::string &tensor_name_no_slot, bool *previous_iter_tensor_needed,
+  void AddWatchPointsToCheck(bool init_dbg_suspend, bool step_end, bool recheck,
+                             const std::shared_ptr<TensorData> &tensor, bool *previous_iter_tensor_needed,
                              std::string *qualified_tensor_name, std::vector<watchpoint_t> *watchpoints_to_check);
 #ifdef OFFLINE_DBG_MODE
@@ -236,7 +250,8 @@ class DebugServices {
   std::vector<std::shared_ptr<TensorData>> ReadNeededDumpedTensors(unsigned int iteration,
                                                                    std::vector<std::string> *async_file_pool);
-  void *GetPrevTensor(const std::shared_ptr<TensorData> &tensor, bool previous_iter_tensor_needed);
+  void *GetPrevTensor(const std::shared_ptr<TensorData> &tensor, bool previous_iter_tensor_needed,
+                      uint32_t *prev_num_elements);
   void ReadTensorFromNpy(const std::string &file_name, std::string *tensor_type, std::size_t *size,
                          std::vector<int64_t> *shape, std::vector<char> **data_buffer);
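
Note: FindQualifiedTensorName now also gates a match on whether the tensor's device id and root graph id appear in the per-watchpoint lists, which is what makes multi-card watchpoint checking work. The membership test is a plain std::find over both vectors (generic sketch, not the DebugServices API):

  #include <algorithm>
  #include <cstdint>
  #include <iostream>
  #include <vector>

  // Illustrative stand-in for the device/root-graph gate.
  bool MatchesCard(const std::vector<std::uint32_t> &device_vec, const std::vector<std::uint32_t> &root_graph_vec,
                   std::uint32_t tensor_device_id, std::uint32_t tensor_root_graph_id) {
    auto iter1 = std::find(device_vec.begin(), device_vec.end(), tensor_device_id);
    auto iter2 = std::find(root_graph_vec.begin(), root_graph_vec.end(), tensor_root_graph_id);
    return iter1 != device_vec.end() && iter2 != root_graph_vec.end();
  }

  int main() {
    std::cout << std::boolalpha << MatchesCard({0, 1}, {0}, 1, 0) << '\n';  // true
    std::cout << std::boolalpha << MatchesCard({0, 1}, {0}, 3, 0) << '\n';  // false
    return 0;
  }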

View File

@@ -113,7 +113,7 @@ void Debugger::Init(const uint32_t device_id, const std::string device_target) {
   device_id_ = device_id;
   MS_LOG(INFO) << "Debugger got device_target: " << device_target;
   device_target_ = device_target;
-  version_ = "1.3.0";
+  version_ = "1.4.0";
 }
 bool IsTypeDebuggerSupported(TypeId type) {
@@ -147,8 +147,22 @@ void Debugger::EnableDebugger() {
   }
   if (debugger_enabled_) {
-    std::string host = "localhost";
+    // configure grpc host
+    std::string env_host_str = common::GetEnv("MS_DEBUGGER_HOST");
+    std::string host;
+    if (!env_host_str.empty()) {
+      if (CheckIp(env_host_str)) {
+        MS_LOG(INFO) << "Getenv MS_DEBUGGER_HOST: " << env_host_str;
+        host = env_host_str;
+      } else {
+        debugger_enabled_ = false;
+        MS_EXCEPTION(ValueError) << "Environment variable MS_DEBUGGER_HOST isn't a valid IP address. "
+                                    "Please set environment variable MS_DEBUGGER_HOST=x.x.x.x to a valid IP";
+      }
+    } else {
+      MS_LOG(INFO) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost";
+      host = "localhost";
+    }
     // configure grpc port
     std::string env_port_str = common::GetEnv("MS_DEBUGGER_PORT");
     std::string port;
@@ -1120,6 +1134,17 @@ bool Debugger::CheckPort(const std::string &port) const {
   return true;
 }
+bool Debugger::CheckIp(const std::string &host) const {
+  std::regex reg_ip(
+    "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
+    "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
+    "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
+    "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
+  std::smatch smat;
+  std::string host_str = host;
+  return std::regex_match(host_str, smat, reg_ip);
+}
 uint32_t Debugger::GetFirstRunGraphId() const { return rungraph_id_list_.front(); }
 void Debugger::LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index) {
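
Note: CheckIp validates MS_DEBUGGER_HOST as a dotted-quad IPv4 address with std::regex before the host is handed to gRPC. A quick standalone check of the same regex (sketch for experimentation; the octet pattern is copied from the diff, where the first and last octets may not be 0 or 255):

  #include <iostream>
  #include <regex>
  #include <string>

  int main() {
    const std::regex reg_ip(
      "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"
      "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
      "[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])"
      "[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])");
    for (const std::string host : {"192.168.0.1", "localhost", "256.1.1.1"}) {
      std::cout << host << " -> " << std::boolalpha << std::regex_match(host, reg_ip) << '\n';
    }
    return 0;  // true, false, false
  }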

View File

@@ -235,6 +235,9 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   // Check if the port is valid
   bool CheckPort(const std::string &port) const;
+  // Check if the IP is valid
+  bool CheckIp(const std::string &host) const;
   void LoadSingleAnfnode(const AnfNodePtr &anf_node, const size_t output_index);
   // class members
// class members // class members

View File

@@ -48,7 +48,7 @@ DbgServices::~DbgServices() {
 std::string DbgServices::GetVersion() {
   MS_LOG(INFO) << "get version is called";
-  return "1.3.0";
+  return "1.4.0";
 }
 int32_t DbgServices::Initialize(std::string net_name, std::string dump_folder_path, bool is_sync_mode) {

View File

@@ -573,7 +573,8 @@ void DumpIRProtoWithSrcInfo(const FuncGraphPtr &func_graph, const std::string &s
   // write to pb file
   std::ofstream ofs(realpath.value());
   if (!ofs.is_open()) {
-    MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!";
+    MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!"
+                  << " Errno:" << errno << " ErrInfo:" << strerror(errno);
     return;
   }
   ofs << graph_proto;

View File

@@ -91,10 +91,12 @@ double VarianceAndMeanCalculator::GetVariance() const {
 double VarianceAndMeanCalculator::GetStandardDeviation() { return sqrt(GetVariance()); }
 template <typename T>
-TensorSummary<T>::TensorSummary(void *current_tensor_ptr, void *const previous_tensor_ptr, uint32_t num_elements)
+TensorSummary<T>::TensorSummary(void *current_tensor_ptr, void *const previous_tensor_ptr, uint32_t num_elements,
+                                uint32_t prev_num_elements)
     : current_tensor_ptr(reinterpret_cast<T *>(current_tensor_ptr)),
       prev_tensor_ptr(reinterpret_cast<T *>(previous_tensor_ptr)),
       num_elements(num_elements),
+      prev_num_elements_(prev_num_elements),
      min(std::numeric_limits<double>::max()),
      max(std::numeric_limits<double>::lowest()),
      inf_count(0),
@@ -108,8 +110,14 @@ void TensorSummary<T>::SummarizeTensor(const std::vector<DebugServices::watchpoi
   InitCalculators(wps);
   for (size_t i = 0; i < num_elements; ++i) {
     auto current_value = static_cast<double>(current_tensor_ptr[i]);
-    double previous_value =
-      prev_tensor_ptr ? static_cast<double>(prev_tensor_ptr[i]) : std::numeric_limits<double>::quiet_NaN();
+    double previous_value = std::numeric_limits<double>::quiet_NaN();
+    if (prev_tensor_ptr) {
+      if (num_elements == prev_num_elements_) {
+        previous_value = static_cast<double>(prev_tensor_ptr[i]);
+      } else {
+        MS_LOG(DEBUG) << "Current and previous tensor are not the same size.";
+      }
+    }
     inf_count += std::isinf(current_value);
     nan_count += std::isnan(current_value);
     zero_count += (current_value == 0);
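
Note: the guard above is the core of the read-tensor fix: when a node's shape changes between iterations, indexing the previous buffer with the current element count would read out of bounds, so the element-wise comparison falls back to NaN unless the sizes match. In isolation the rule looks like this (illustrative sketch):

  #include <cmath>
  #include <cstddef>
  #include <iostream>
  #include <limits>
  #include <vector>

  // Returns prev[i] only when prev is present and the same size as the current
  // tensor; otherwise NaN, mirroring the TensorSummary guard.
  double PreviousValue(const std::vector<float> *prev, std::size_t cur_size, std::size_t i) {
    if (prev != nullptr && prev->size() == cur_size) {
      return static_cast<double>((*prev)[i]);
    }
    return std::numeric_limits<double>::quiet_NaN();
  }

  int main() {
    std::vector<float> prev{1.0f, 2.0f};
    std::cout << PreviousValue(&prev, 2, 1) << '\n';              // 2
    std::cout << std::isnan(PreviousValue(&prev, 3, 1)) << '\n';  // 1 (size mismatch)
    return 0;
  }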

View File

@@ -99,7 +99,7 @@ class TensorSummary : public ITensorSummary {
  public:
   TensorSummary() = default;
   ~TensorSummary() override = default;
-  TensorSummary(void *, void *, uint32_t);
+  TensorSummary(void *, void *, uint32_t, uint32_t);
   void SummarizeTensor(const std::vector<DebugServices::watchpoint_t> &) override;
   // returns hit, error_code, parameter_list
   std::tuple<bool, int, std::vector<DebugServices::parameter_t>> IsWatchpointHit(DebugServices::watchpoint_t) override;
@@ -108,6 +108,7 @@ class TensorSummary : public ITensorSummary {
   T *current_tensor_ptr;
   T *prev_tensor_ptr;
   uint32_t num_elements;
+  uint32_t prev_num_elements_;
   double min;
   double max;
   uint32_t inf_count;

View File

@@ -555,7 +555,8 @@ void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
   // write to pb file
   std::ofstream ofs(file_path);
   if (!ofs.is_open()) {
-    MS_LOG(ERROR) << "Open file '" << file_path << "' failed!";
+    MS_LOG(ERROR) << "Open file '" << file_path << "' failed!"
+                  << " Errno:" << errno << " ErrInfo:" << strerror(errno);
     return;
   }
   ofs << GetFuncGraphProtoString(func_graph);

View File

@@ -122,7 +122,8 @@ void EnvConfigParser::ParseFromFile() {
   std::ifstream json_file(config_file_);
   if (!json_file.is_open()) {
     MS_LOG(WARNING) << "Env config file:" << config_file_ << " open failed."
-                    << " Please check the config file '" << config_file_ << "' set by 'env_config_path' in context.";
+                    << " Please check the config file '" << config_file_ << "' set by 'env_config_path' in context."
+                    << " Errno:" << errno << " ErrInfo:" << strerror(errno);
     return;
   }

View File

@@ -348,7 +348,8 @@ bool AnalyzeFailExporter::ExportFuncGraph(const std::string &filename, const Tra
   }
   std::ofstream ofs(filename);
   if (!ofs.is_open()) {
-    MS_LOG(ERROR) << "Open file '" << filename << "' failed!";
+    MS_LOG(ERROR) << "Open file '" << filename << "' failed!"
+                  << " Errno:" << errno << " ErrInfo:" << strerror(errno);
     return false;
   }

View File

@ -19,6 +19,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <set>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <utility> #include <utility>
@ -128,6 +129,178 @@ class GetItemTransformACrossGraph {
private: private:
std::unordered_map<FuncGraphPtr, std::unordered_map<int64_t, FuncGraphPtr>> cache_; std::unordered_map<FuncGraphPtr, std::unordered_map<int64_t, FuncGraphPtr>> cache_;
}; };
bool HasMoreJ(const OptimizerPtr &optimizer) {
bool more_j = false;
auto res = optimizer->resource();
auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(res);
if (resource_ptr != nullptr) {
const auto &manager = optimizer->manager();
MS_EXCEPTION_IF_NULL(manager);
more_j = manager->func_graph_j_total(resource_ptr->func_graph());
}
return more_j;
}
bool IsOutputShrinkable(const AnfNodePtr &output) {
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
return true;
}
if (GetValueNode<ValueTuplePtr>(output)) {
return true;
}
return false;
}
size_t GetOutputSize(const AnfNodePtr &output) {
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
const auto &output_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
return output_cnode->size() - 1;
}
const auto &value_tuple = GetValueNode<ValueTuplePtr>(output);
if (value_tuple == nullptr) {
MS_LOG(EXCEPTION) << "fg output is not MakeTuple or ValueTuple, but: " << output->DebugString();
}
return value_tuple->size();
}
struct TpCNodeAndIndex {
// CNode {TupleGetItem, call, index}
CNodePtr tp_cnode;
int64_t index;
};
int64_t UpdateUserNodeIndex(const CNodePtr &fg_call_cnode, const int64_t current_index,
const std::vector<TpCNodeAndIndex> &tp_cnodes_and_index) {
const auto &manager = fg_call_cnode->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
int64_t new_index = current_index;
auto txn = manager->Transact();
for (int64_t i = 0; i < SizeToLong(tp_cnodes_and_index.size()); ++i) {
const auto &cnode_and_index = tp_cnodes_and_index[i];
if (cnode_and_index.index != i) {
constexpr auto kInputIndex = 2;
txn.SetEdge(cnode_and_index.tp_cnode, kInputIndex, NewValueNode(i));
}
if (cnode_and_index.index == current_index) {
new_index = i;
}
}
txn.Commit();
return new_index;
}
AbstractBasePtr ShrinkAbstract(const AbstractBasePtr &original_abstract,
const std::vector<TpCNodeAndIndex> &tp_cnodes_and_index) {
if (original_abstract != nullptr && original_abstract->isa<abstract::AbstractTuple>()) {
const auto &abs_tuple = original_abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abs_tuple);
const auto &abs_tuple_elements = abs_tuple->elements();
const int64_t before_shrink_tuple_size = SizeToLong(abs_tuple_elements.size());
AbstractBasePtrList shrunk_abstract_elements;
std::transform(tp_cnodes_and_index.cbegin(), tp_cnodes_and_index.cend(),
std::back_inserter(shrunk_abstract_elements),
[abs_tuple_elements, before_shrink_tuple_size](const auto &node_and_index) {
if (node_and_index.index >= before_shrink_tuple_size) {
MS_LOG(EXCEPTION) << "index should less than inputs size, index: " << node_and_index.index
<< ", abstract tuple size: " << before_shrink_tuple_size;
}
return abs_tuple_elements[node_and_index.index];
});
return std::make_shared<abstract::AbstractTuple>(shrunk_abstract_elements);
}
return nullptr;
}
FuncGraphPtr ShrinkUnsedOutput(const FuncGraphPtr &fg, const std::vector<TpCNodeAndIndex> &tp_cnodes_and_index) {
const auto &manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("tp_use"));
auto new_fg_output = new_fg->output();
AnfNodePtr shrunk_output = nullptr;
int64_t before_shrink_inputs_size = 0;
if (IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) {
// Shrink output;
auto new_fg_output_cnode = new_fg_output->cast<CNodePtr>();
const auto &new_fg_output_inputs = new_fg_output_cnode->inputs();
constexpr auto kMinimalSize = 2;
if (new_fg_output_inputs.size() <= kMinimalSize) {
MS_LOG(EXCEPTION) << "New fg output should at least 2 elements, but: " << new_fg_output->DebugString();
}
before_shrink_inputs_size = SizeToLong(new_fg_output_inputs.size() - 1);
AnfNodePtrList shrunk_inputs{NewValueNode({prim::kPrimMakeTuple})};
// Bypass maketuple primitive in new_fg_output_inputs;
std::transform(tp_cnodes_and_index.cbegin(), tp_cnodes_and_index.cend(), std::back_inserter(shrunk_inputs),
[new_fg_output, new_fg_output_inputs, before_shrink_inputs_size](const auto &node_and_index) {
if (node_and_index.index >= before_shrink_inputs_size) {
MS_LOG(EXCEPTION) << "index should less than inputs size, index: " << node_and_index.index
<< ", output: " << new_fg_output->DebugString();
}
return new_fg_output_inputs[node_and_index.index + 1];
});
shrunk_output = new_fg->NewCNode(shrunk_inputs);
} else {
auto value_tuple = GetValueNode<ValueTuplePtr>(new_fg_output);
if (value_tuple == nullptr) {
MS_LOG(EXCEPTION) << "New fg output is not MakeTuple or ValueTuple, but " << new_fg_output->DebugString();
}
ValuePtrList shrunk_inputs;
before_shrink_inputs_size = value_tuple->size();
std::transform(tp_cnodes_and_index.cbegin(), tp_cnodes_and_index.cend(), std::back_inserter(shrunk_inputs),
[new_fg_output, value_tuple, before_shrink_inputs_size](const auto &node_and_index) {
if (node_and_index.index >= before_shrink_inputs_size) {
MS_LOG(EXCEPTION) << "index should less than inputs size, index: " << node_and_index.index
<< ", output: " << new_fg_output->DebugString();
}
return (*value_tuple)[node_and_index.index];
});
shrunk_output = NewValueNode(std::make_shared<ValueTuple>(shrunk_inputs));
}
auto shrunk_abstract = ShrinkAbstract(new_fg_output->abstract(), tp_cnodes_and_index);
MS_EXCEPTION_IF_NULL(shrunk_abstract);
shrunk_output->set_abstract(shrunk_abstract);
new_fg->set_output(shrunk_output);
MS_LOG(DEBUG) << "Partly item used; original size: " << before_shrink_inputs_size
<< ", new size: " << tp_cnodes_and_index.size() << ", fg: " << fg->ToString() << ", new graph"
<< new_fg->ToString();
return new_fg;
}
struct FuncGraphIntVectorPairHasher {
std::size_t Int64VectorHash(const std::vector<int64_t> &int_vector) const {
std::size_t hash_value = 0;
constexpr auto kMaxElementsNum = 4;
for (size_t i = 0; (i < int_vector.size()) && (i < kMaxElementsNum); ++i) {
hash_value = hash_combine(hash_value, std::hash<int64_t>{}(int_vector[i]));
}
return hash_value;
}
std::size_t operator()(const std::pair<FuncGraphPtr, std::vector<int64_t>> &p) const {
auto h1 = std::hash<FuncGraphPtr>{}(p.first);
auto h2 = Int64VectorHash(p.second);
return hash_combine(h1, h2);
}
};
bool ShouldTransform(const AnfNodePtr &node, const std::vector<TpCNodeAndIndex> &tp_cnodes_and_index) {
if (node->abstract() && node->abstract()->isa<abstract::AbstractTuple>()) {
const auto &abs_tuple = *(node->abstract()->cast<abstract::AbstractTuplePtr>());
if (tp_cnodes_and_index[0].index == 0 && abs_tuple.size() > 0) {
if (abs_tuple[0]->isa<abstract::AbstractScalar>() && abs_tuple[0]->GetTypeTrack()->isa<EnvType>()) {
return true;
}
}
// fprop_fg will return MakeTuple(xx, bprop_fg).
if (tp_cnodes_and_index.size() > 1 && tp_cnodes_and_index[1].index == 1 && abs_tuple.size() > 1 &&
abs_tuple[1]->isa<abstract::AbstractFunction>()) {
return true;
}
}
return false;
}
} // namespace internal

// {prim::kPrimTupleGetItem, {G, Xs}, C}
@@ -136,23 +309,52 @@ class IncorporateGetitem : public AnfVisitor {
IncorporateGetitem() : getitem_transform_() {}
~IncorporateGetitem() override = default;
-AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int64Imm>})(node);
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr || fg_->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) ||
fg_->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
return nullptr;
}
// This node had been substituted.
if (processed_nodes_.find(fg_call_cnode_) != processed_nodes_.end()) {
MS_LOG(DEBUG) << "fg call with same cnode is already replaced, node: " << node->DebugString()
<< ", fg_call: " << fg_call_cnode_->DebugString();
return nullptr;
}
const auto &manager = fg_->manager();
MS_EXCEPTION_IF_NULL(manager);
bool output_is_shrinkable = internal::IsOutputShrinkable(fg_->output());
std::vector<internal::TpCNodeAndIndex> tp_cnodes_and_index;
auto fg_call_cnode_users_counter = MultipleUse(fg_call_cnode_, fg_, &tp_cnodes_and_index);
bool multiple_use = (tp_cnodes_and_index.size() > 1);
if (output_is_shrinkable && multiple_use && (tp_cnodes_and_index.size() == fg_call_cnode_users_counter)) {
if (!internal::ShouldTransform(fg_call_cnode_, tp_cnodes_and_index) && !internal::HasMoreJ(optimizer)) {
MS_LOG(DEBUG) << "No more j and multiple use, will shrink, node: " << node->DebugString()
<< ", fg_call: " << fg_call_cnode_->DebugString();
const auto output_size = internal::GetOutputSize(fg_->output());
if (fg_call_cnode_users_counter == output_size) {
processed_nodes_.emplace(fg_call_cnode_);
MS_LOG(DEBUG) << "All elements in output is used, no need to transform, node: " << node->DebugString()
<< ", fg_call: " << fg_call_cnode_->DebugString();
return nullptr;
}
auto new_node = ShrinkFuncGraphOutput(node, tp_cnodes_and_index);
if (new_node != nullptr) {
return new_node;
}
}
}
MS_LOG(DEBUG) << "Cannot shrink, transform_getitem, node: " << node->DebugString()
<< ", fg_call: " << fg_call_cnode_->DebugString();
auto new_fg = getitem_transform_(node, fg_, idx_);
MS_LOG(DEBUG) << "Original fg: " << fg_->ToString() << ", new fg: " << new_fg->ToString();
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
auto new_node = node->func_graph()->NewCNode(args_);
// Check if the only other usage of {G, Xs} is UpdateState{s, {G, Xs}}; if yes, replace
// UpdateState{s, {G, Xs}} with UpdateState{s, new_node};
-const auto &manager = fg_->manager();
-MS_EXCEPTION_IF_NULL(manager);
auto &node_users_map = manager->node_users();
-auto it = node_users_map.find(fg_cnode_);
auto it = node_users_map.find(fg_call_cnode_);
if (it != node_users_map.end()) {
AnfNodePtr update_state_node = nullptr;
auto &node_users = it->second;
@@ -166,7 +368,7 @@ class IncorporateGetitem : public AnfVisitor {
if (update_state_node != nullptr) {
auto update_state_cnode = update_state_node->cast<CNodePtr>();
// double check;
-if (update_state_cnode->input(2) == fg_cnode_) {
if (update_state_cnode->input(2) == fg_call_cnode_) {
MS_LOG(DEBUG) << "Replace UpdateState node: " << update_state_cnode->DebugString(2)
<< ", input 2 with: " << new_node->DebugString();
manager->SetEdge(update_state_cnode, 2, new_node);
@@ -177,12 +379,98 @@ class IncorporateGetitem : public AnfVisitor {
return new_node;
}
size_t MultipleUse(const CNodePtr &fg_call, const FuncGraphPtr &fg,
std::vector<internal::TpCNodeAndIndex> *cnodes_and_index) const {
const auto &manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &cnode_and_index_vector = *cnodes_and_index;
std::set<int64_t> index_set;
std::size_t total_usage = 0;
const auto &node_users_map = manager->node_users();
const auto &it = node_users_map.find(fg_call);
if (it == node_users_map.end()) {
return 0;
}
const auto &node_users = it->second;
for (const auto &user : node_users) {
if (IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
const auto &cnode = user.first->cast<CNodePtr>();
if (cnode->input(2)->isa<ValueNode>()) {
auto idx = GetValue<int64_t>(cnode->input(2)->cast<ValueNodePtr>()->value());
cnode_and_index_vector.push_back({cnode, idx});
index_set.insert(idx);
total_usage++;
} else {
MS_LOG(EXCEPTION) << "tuple_getitem index is not valuenode, but: " << user.first->DebugString();
}
} else {
MS_LOG(DEBUG) << "fg_call usre is not tuple_getitem, user: " << user.first->DebugString();
}
}
if (index_set.size() != total_usage) {
MS_LOG(DEBUG) << "some index usage is duplicated, total_usage: " << total_usage;
MS_LOG(DEBUG) << "index_set:";
for (auto idx : index_set) {
MS_LOG(DEBUG) << " " << idx;
}
}
// sort by index;
std::sort(cnode_and_index_vector.begin(), cnode_and_index_vector.end(),
[](const auto &tp1, const auto &tp2) { return tp1.index < tp2.index; });
return node_users.size();
}
AnfNodePtr ShrinkFuncGraphOutput(const AnfNodePtr &node,
const std::vector<internal::TpCNodeAndIndex> &tp_cnodes_and_index) {
const auto &manager = fg_->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<int64_t> index_vector;
(void)std::transform(tp_cnodes_and_index.begin(), tp_cnodes_and_index.end(), std::back_inserter(index_vector),
[](const auto &cnode_and_index) { return cnode_and_index.index; });
auto iter = processed_fgs_.find(std::make_pair(fg_, index_vector));
if (iter != processed_fgs_.end()) {
MS_LOG(DEBUG) << "fg is already processed, just update caller index, node: " << node->DebugString()
<< ", fg_call: " << fg_call_cnode_->DebugString();
MS_LOG(DEBUG) << "original fg: " << fg_->ToString() << ", processed_fg: " << iter->second->ToString();
processed_nodes_.emplace(fg_call_cnode_);
manager->SetEdge(fg_call_cnode_, 0, NewValueNode(iter->second));
auto shrunk_abstract = internal::ShrinkAbstract(fg_call_cnode_->abstract(), tp_cnodes_and_index);
if (shrunk_abstract != nullptr) {
fg_call_cnode_->set_abstract(shrunk_abstract);
}
auto new_idx = internal::UpdateUserNodeIndex(fg_call_cnode_, idx_, tp_cnodes_and_index);
auto new_node =
node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), fg_call_cnode_, NewValueNode(new_idx)});
new_node->set_abstract(node->abstract());
return new_node;
}
const auto new_fg = internal::ShrinkUnsedOutput(fg_, tp_cnodes_and_index);
if (new_fg != nullptr) {
MS_LOG(DEBUG) << "fg output is shrunk, original fg: " << fg_->ToString() << ", new fg: " << new_fg->ToString();
processed_nodes_.emplace(fg_call_cnode_);
processed_fgs_.emplace(std::make_pair(fg_, index_vector), new_fg);
manager->SetEdge(fg_call_cnode_, 0, NewValueNode(new_fg));
auto shrunk_abstract = internal::ShrinkAbstract(fg_call_cnode_->abstract(), tp_cnodes_and_index);
if (shrunk_abstract != nullptr) {
fg_call_cnode_->set_abstract(shrunk_abstract);
}
auto new_idx = internal::UpdateUserNodeIndex(fg_call_cnode_, idx_, tp_cnodes_and_index);
auto new_node =
node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), fg_call_cnode_, NewValueNode(new_idx)});
new_node->set_abstract(node->abstract());
return new_node;
}
MS_LOG(DEBUG) << "Shrink failed. node: " << node->DebugString()
<< ", switch_call: " << fg_call_cnode_->DebugString();
return nullptr;
}
void Visit(const CNodePtr &cnode) override {
if (cnode->size() == 0 || !IsValueNode<FuncGraph>(cnode->input(0))) {
return;
}
-fg_cnode_ = cnode;
fg_call_cnode_ = cnode;
auto &inputs = cnode->inputs();
fg_ = GetValueNode<FuncGraphPtr>(inputs[0]);
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_));
@@ -193,15 +481,19 @@ class IncorporateGetitem : public AnfVisitor {
void Reset() {
idx_ = -1;
fg_ = nullptr;
-fg_cnode_ = nullptr;
fg_call_cnode_ = nullptr;
args_.clear();
}

private:
int64_t idx_{-1};
FuncGraphPtr fg_{nullptr};
-AnfNodePtr fg_cnode_{nullptr};
CNodePtr fg_call_cnode_{nullptr};
std::vector<AnfNodePtr> args_{};
std::set<AnfNodePtr> processed_nodes_;
std::unordered_map<std::pair<FuncGraphPtr, std::vector<int64_t>>, FuncGraphPtr,
internal::FuncGraphIntVectorPairHasher>
processed_fgs_;
internal::GetitemTransform getitem_transform_;
};

@@ -298,7 +590,7 @@ class IncorporateGetitemSwitch : public AnfVisitor {
IncorporateGetitemSwitch() : getitem_transform_() {}
~IncorporateGetitemSwitch() override = default;
-AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
Reset();
is_in_get_ = true;
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int64Imm>})(node);
@@ -316,33 +608,57 @@ class IncorporateGetitemSwitch : public AnfVisitor {
if (g2_ == nullptr) {
return nullptr;
}
-auto tuple_getitem = node->cast<CNodePtr>();
-MS_EXCEPTION_IF_NULL(tuple_getitem);
-bool has_env_type = false;
-if (tuple_getitem->input(1)->abstract() && tuple_getitem->input(1)->abstract()->isa<abstract::AbstractTuple>()) {
-const auto &abs_tuple = *(tuple_getitem->input(1)->abstract()->cast<abstract::AbstractTuplePtr>());
-// eliminate (envinstance, value1, value2, ...) built by bprop func_graph()
-if (abs_tuple.size() >= 1) {
-// Value maybe kAnyValue, so check the type track;
-if (abs_tuple[0]->isa<abstract::AbstractScalar>() && abs_tuple[0]->GetTypeTrack()->isa<EnvType>()) {
-has_env_type = true;
-}
-}
-// eliminate (value, bprop_func) built by fprop func_graph
-if (abs_tuple.size() >= 2) {
-if (abs_tuple[1]->isa<abstract::AbstractFunction>()) {
-has_env_type = true;
-}
-}
-}
-// If exist env_getitem/env_setitem in this funcgraph or
-// if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem;
-if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) &&
-!ExistEnvNodeInTupleItem(g2_) && !has_env_type) {
if (processed_nodes_.find(switch_) != processed_nodes_.end()) {
MS_LOG(DEBUG) << "fg in switch node has been replaced. node: " << node->DebugString()
<< ", switch: " << switch_->DebugString();
return nullptr;
}
bool g1_output_is_shrinkable = internal::IsOutputShrinkable(g1_->output());
bool g2_output_is_shrinkable = internal::IsOutputShrinkable(g2_->output());
auto tuple_getitem = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem);
const auto &switch_call = tuple_getitem->input(1);
MS_EXCEPTION_IF_NULL(switch_call);
const auto &switch_call_cnode = switch_call->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_call_cnode);
// If exist env_getitem/env_setitem in this funcgraph or
// if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem;
std::vector<internal::TpCNodeAndIndex> tp_cnodes_and_index;
auto switch_call_users_counter = MultipleUseOfSwitch(switch_call, fg, &tp_cnodes_and_index);
bool multiple_use = (tp_cnodes_and_index.size() > 1);
if (g1_output_is_shrinkable && g2_output_is_shrinkable && multiple_use &&
(tp_cnodes_and_index.size() == switch_call_users_counter)) {
if (!internal::HasMoreJ(optimizer) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) &&
!ExistEnvNodeInTupleItem(g2_) && !internal::ShouldTransform(switch_call, tp_cnodes_and_index)) {
MS_LOG(DEBUG) << "No more j, will shrink. Node: " << node->DebugString()
<< ", switch: " << switch_->DebugString();
const auto g1_output_size = internal::GetOutputSize(g1_->output());
const auto g2_output_size = internal::GetOutputSize(g2_->output());
if (g1_output_size != g2_output_size) {
MS_LOG(EXCEPTION) << "output of g1 and g2 should have same tuple size, but g1 output: "
<< g1_->output()->DebugString() << ", g2 output: " << g2_->output()->DebugString();
}
if (switch_call_users_counter == g1_output_size) {
processed_nodes_.emplace(switch_call);
MS_LOG(DEBUG) << "All elements in output is used, no need to transform, node: " << node->DebugString()
<< ", switch: " << switch_->DebugString();
return nullptr;
}
auto new_node = ShrinkFuncGraphOutput(node, switch_call_cnode, tp_cnodes_and_index);
if (new_node != nullptr) {
return new_node;
}
}
}
MS_LOG(DEBUG) << "Cannot shrink output, transform_getitem_switch, node: " << node->DebugString()
<< ", switch: " << switch_->DebugString();
auto new_g1 = getitem_transform_(node, g1_, idx_);
auto new_g2 = getitem_transform_(node, g2_, idx_);
MS_LOG(DEBUG) << "Original fg1: " << g1_->ToString() << ", new_fg1: " << new_g1->ToString();
MS_LOG(DEBUG) << "Original fg2: " << g2_->ToString() << ", new_fg2: " << new_g2->ToString();
auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)});
(void)args_.insert(args_.begin(), sw_node);
@@ -350,7 +666,60 @@ class IncorporateGetitemSwitch : public AnfVisitor {
new_node->set_abstract(node->abstract());
return new_node;
}
AnfNodePtr ShrinkFuncGraphOutput(const AnfNodePtr &node, const CNodePtr &switch_call_cnode,
const std::vector<internal::TpCNodeAndIndex> &tp_cnodes_and_index) {
const auto &manager = node->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
auto switch_cnode = switch_->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_cnode);
std::vector<int64_t> index_vector;
(void)std::transform(tp_cnodes_and_index.begin(), tp_cnodes_and_index.end(), std::back_inserter(index_vector),
[](const auto &cnode_and_index) { return cnode_and_index.index; });
const auto &iter1 = processed_fgs_.find(std::make_pair(g1_, index_vector));
const auto &iter2 = processed_fgs_.find(std::make_pair(g2_, index_vector));
if (iter1 != processed_fgs_.end() && iter2 != processed_fgs_.end()) {
MS_LOG(DEBUG) << "fg output had been processed, no need to transform, node: " << node->DebugString()
<< ", switch: " << switch_->DebugString();
MS_LOG(DEBUG) << "Original fg1: " << g1_->ToString() << ", new_fg1: " << iter1->second->ToString();
MS_LOG(DEBUG) << "Original fg2: " << g2_->ToString() << ", new_fg2: " << iter2->second->ToString();
processed_nodes_.emplace(switch_);
manager->SetEdge(switch_cnode, 2, NewValueNode(iter1->second));
manager->SetEdge(switch_cnode, 3, NewValueNode(iter2->second));
auto shrunk_abstract = internal::ShrinkAbstract(switch_call_cnode->abstract(), tp_cnodes_and_index);
if (shrunk_abstract != nullptr) {
switch_call_cnode->set_abstract(shrunk_abstract);
}
auto new_idx = internal::UpdateUserNodeIndex(switch_call_cnode, idx_, tp_cnodes_and_index);
auto new_node =
node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), switch_call_cnode, NewValueNode(new_idx)});
new_node->set_abstract(node->abstract());
return new_node;
}
const auto &new_g1 = internal::ShrinkUnsedOutput(g1_, tp_cnodes_and_index);
const auto &new_g2 = internal::ShrinkUnsedOutput(g2_, tp_cnodes_and_index);
if (new_g1 != nullptr && new_g2 != nullptr) {
MS_LOG(DEBUG) << "Shrink output. node: " << node->DebugString() << ", switch: " << switch_->DebugString();
MS_LOG(DEBUG) << "Original fg1: " << g1_->ToString() << ", new_fg1: " << new_g1->ToString();
MS_LOG(DEBUG) << "Original fg2: " << g2_->ToString() << ", new_fg2: " << new_g2->ToString();
processed_nodes_.emplace(switch_);
processed_fgs_.emplace(std::make_pair(g1_, index_vector), new_g1);
processed_fgs_.emplace(std::make_pair(g2_, index_vector), new_g2);
manager->SetEdge(switch_cnode, 2, NewValueNode(new_g1));
manager->SetEdge(switch_cnode, 3, NewValueNode(new_g2));
auto shrunk_abstract = internal::ShrinkAbstract(switch_call_cnode->abstract(), tp_cnodes_and_index);
if (shrunk_abstract != nullptr) {
switch_call_cnode->set_abstract(shrunk_abstract);
}
auto new_idx = internal::UpdateUserNodeIndex(switch_call_cnode, idx_, tp_cnodes_and_index);
auto new_node =
node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), switch_call_cnode, NewValueNode(new_idx)});
new_node->set_abstract(node->abstract());
return new_node;
}
MS_LOG(DEBUG) << "Shrink failed. node: " << node->DebugString()
<< ", switch_call: " << switch_call_cnode->DebugString();
return nullptr;
}
void Visit(const AnfNodePtr &node) override {
if (is_in_switch_ && x_ == nullptr) {
x_ = node;
@@ -393,22 +762,51 @@ class IncorporateGetitemSwitch : public AnfVisitor {
}

private:
-bool MultipleUseOfSwitch(const AnfNodePtr &switch_call, const FuncGraphPtr &fg) const {
size_t MultipleUseOfSwitch(const AnfNodePtr &switch_call, const FuncGraphPtr &fg,
std::vector<internal::TpCNodeAndIndex> *cnodes_and_index) const {
auto switch_call_cnode = switch_call->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_call_cnode);
auto manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &cnode_and_index_vector = *cnodes_and_index;
std::set<int64_t> index_set;
std::size_t total_usage = 0;
auto &node_users_map = manager->node_users();
auto it = node_users_map.find(switch_call);
if (it == node_users_map.end()) {
-return false;
return 0;
}
auto &node_users = it->second;
-// If switch was used by more than 1 tuple_getitem nodes, this pass shouldn't be execute.s
-auto tuple_getitem_num = std::count_if(node_users.begin(), node_users.end(), [](std::pair<AnfNodePtr, int> &user) {
-return IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem);
-});
-return tuple_getitem_num > 1;
// If switch was used by more than 1 tuple_getitem node, this pass shouldn't be executed.
for (auto user : node_users) {
if (IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
auto cnode = user.first->cast<CNodePtr>();
constexpr auto kInputIndex = 2;
if (cnode->input(kInputIndex)->isa<ValueNode>()) {
const auto &idx_node = cnode->input(kInputIndex)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(idx_node);
auto idx = GetValue<int64_t>(idx_node->value());
cnode_and_index_vector.push_back({cnode, idx});
index_set.insert(idx);
total_usage++;
} else {
MS_LOG(EXCEPTION) << "Tuple_getitem index is not valuenode, but: " << user.first->DebugString(2);
}
} else {
MS_LOG(DEBUG) << "switch_call user is not tuple_getitem, user: " << user.first->DebugString(2);
}
}
if (index_set.size() != total_usage) {
MS_LOG(DEBUG) << "some index is duplicated, total_usage: " << total_usage;
MS_LOG(DEBUG) << "index_set: ";
for (auto idx : index_set) {
MS_LOG(DEBUG) << " " << idx;
}
}
// sort by index;
std::sort(cnode_and_index_vector.begin(), cnode_and_index_vector.end(),
[](const auto &tp1, const auto &tp2) { return tp1.index < tp2.index; });
return node_users.size();
}

static bool inline ExistEnvNode(const FuncGraphPtr &fg) {
@@ -441,6 +839,10 @@ class IncorporateGetitemSwitch : public AnfVisitor {
FuncGraphPtr g1_{nullptr}, g2_{nullptr};
bool is_in_get_{false}, is_in_switch_{false};
std::vector<AnfNodePtr> args_{};
std::set<AnfNodePtr> processed_nodes_;
std::unordered_map<std::pair<FuncGraphPtr, std::vector<int64_t>>, FuncGraphPtr,
internal::FuncGraphIntVectorPairHasher>
processed_fgs_;
internal::GetitemTransform getitem_transform_;
};
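Both visitors cache already-shrunk graphs in processed_fgs_, keyed by the pair (original graph, sorted list of used indices), with FuncGraphIntVectorPairHasher combining the two hashes. A standalone sketch of that keying scheme (HashCombine here is the usual Boost-style mixer and is only assumed to behave like MindSpore's hash_combine; Graph is a stand-in type):

    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <memory>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    inline std::size_t HashCombine(std::size_t seed, std::size_t v) {
      return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    }

    struct Graph {};  // stand-in for FuncGraph
    using GraphPtr = std::shared_ptr<Graph>;

    // Hash a (graph, kept-index-list) pair; like the kMaxElementsNum = 4 cap in
    // FuncGraphIntVectorPairHasher, only a short prefix of the vector is hashed.
    struct PairHasher {
      std::size_t operator()(const std::pair<GraphPtr, std::vector<int64_t>> &p) const {
        std::size_t h = std::hash<GraphPtr>{}(p.first);
        for (std::size_t i = 0; i < p.second.size() && i < 4; ++i) {
          h = HashCombine(h, std::hash<int64_t>{}(p.second[i]));
        }
        return h;
      }
    };

    // The cache then maps (graph, index list) -> shrunk graph, as processed_fgs_ does.
    using ShrunkCache = std::unordered_map<std::pair<GraphPtr, std::vector<int64_t>>, GraphPtr, PairHasher>;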

View File

@@ -64,8 +64,14 @@ Status VirtualOutputInfo::GenerateStrategies(int64_t stage_id) {
}
for (auto &shape : inputs_shape_) {
Shape temp;
if (!shape.empty()) {
if (shape[0] % total_dev_num == 0) {
temp.emplace_back(SizeToLong(total_dev_num));
} else {
temp.emplace_back(1);
}
(void)temp.insert(temp.end(), shape.size() - 1, 1);
}
strategy.push_back(temp);
}
sp = std::make_shared<Strategy>(stage_id, strategy);
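The rule in this hunk: shard the first axis across all devices only when it divides evenly, otherwise fall back to no sharding on that axis; every remaining axis always gets strategy 1. A small sketch of the same computation in isolation:

    #include <cstdint>
    #include <vector>

    // Build a per-tensor strategy: axis 0 gets total_dev_num only when evenly
    // divisible, otherwise 1; all other axes are left unsharded.
    std::vector<int64_t> MakeStrategy(const std::vector<int64_t> &shape, int64_t total_dev_num) {
      std::vector<int64_t> strategy;
      if (!shape.empty()) {
        strategy.push_back(shape[0] % total_dev_num == 0 ? total_dev_num : 1);
        strategy.insert(strategy.end(), shape.size() - 1, 1);
      }
      return strategy;
    }
    // e.g. MakeStrategy({32, 224, 224, 3}, 8) -> {8, 1, 1, 1}
    //      MakeStrategy({30, 10}, 8)          -> {1, 1}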

View File

@@ -2038,7 +2038,12 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
if (shape_list[0][i].empty()) {
MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero";
}
-Dimensions input_strategy = {dev_num};
Dimensions input_strategy;
if (!shape_list[0][i].empty() && shape_list[0][i][0] % dev_num == 0) {
input_strategy.push_back(dev_num);
} else if (!shape_list[0][i].empty()) {
input_strategy.push_back(1);
}
for (size_t j = 1; j < shape_list[0][i].size(); j++) {
input_strategy.push_back(1);
}
@@ -3222,12 +3227,9 @@ void MarkForwardCNode(const FuncGraphPtr &root) {
}
}

-Status ParallelInit() {
-MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
CommInfo GetCommInfo() {
int64_t device_num = ParallelContext::GetInstance()->device_num();
int64_t global_rank = ParallelContext::GetInstance()->global_rank();
-int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
-std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
@@ -3240,15 +3242,8 @@
world_group = NCCL_WORLD_GROUP;
communication_backend = NCCL_BACKEND;
} else {
-MS_LOG(ERROR) << "Invalid communication backend: " << backend;
-return FAILED;
MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
}
-if (split_stage_num <= 0) {
-MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << ", expected a positive stage number";
-return FAILED;
-}
uint32_t world_rank_size = 0;
if (!ParallelContext::GetInstance()->device_num_is_set()) {
if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
@@ -3266,7 +3261,21 @@
global_rank = UintToInt(rank_id);
MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
}
CommInfo comm_info{device_num, global_rank, world_group, communication_backend};
return comm_info;
}
Status ParallelInit() {
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
if (split_stage_num <= 0) {
MS_LOG(ERROR) << "Invalid stage num " << split_stage_num << ", expected a positive stage number";
return FAILED;
}
auto comm_info = GetCommInfo();
int64_t device_num = comm_info.device_num;
int64_t global_rank = comm_info.global_rank;
if ((device_num <= 0) || (device_num > MAX_DEVICE_NUM)) {
MS_LOG(ERROR) << "Invalid device num " << device_num;
return FAILED;
@@ -3293,13 +3302,14 @@
return FAILED;
}
-if (!InitDevice(device_num, global_rank, communication_backend, stages)) {
if (!InitDevice(device_num, global_rank, comm_info.communication_backend, stages)) {
MS_LOG(ERROR) << "Init device failed";
return FAILED;
}
MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank
-<< ", backend: " << backend << ", gradients_mean: " << ParallelContext::GetInstance()->gradients_mean()
<< ", communication_backend: " << comm_info.communication_backend
<< ", gradients_mean: " << ParallelContext::GetInstance()->gradients_mean()
<< ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();
return SUCCESS;
@@ -3714,7 +3724,13 @@ void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr
bool IsInsertVirtualOutput(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
-return (!root->has_flag(TRAINING) && ParallelContext::GetInstance()->dataset_strategy().empty());
auto comm_info = GetCommInfo();
int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
int32_t per_stage_device_num = comm_info.device_num / split_stage_num;
int32_t current_stage = comm_info.global_rank / per_stage_device_num;
MS_LOG(INFO) << "The current stage is: " << current_stage;
return (!root->has_flag(TRAINING) && ParallelContext::GetInstance()->dataset_strategy().empty() &&
current_stage == split_stage_num - 1);
}

bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
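With GetCommInfo factored out, IsInsertVirtualOutput can compute which pipeline stage the current rank belongs to and insert VirtualOutput only on the last stage, which owns the model output. The stage arithmetic in isolation:

    #include <cstdint>

    // With 8 devices and 2 stages, per-stage device num is 4: ranks 0-3 are
    // stage 0, ranks 4-7 are stage 1 (the last stage).
    int32_t CurrentStage(int64_t global_rank, int64_t device_num, int32_t split_stage_num) {
      int64_t per_stage_device_num = device_num / split_stage_num;
      return static_cast<int32_t>(global_rank / per_stage_device_num);
    }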
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {

View File

@@ -47,6 +47,13 @@ struct LossNodeInfo {
CNodePtr loss_node = nullptr;
};
struct CommInfo {
int64_t device_num = 1;
int64_t global_rank = 0;
std::string world_group;
std::string communication_backend;
};
struct ParameterSliceInfo {
Shape slice_shape;
RankList group_ranks;
@@ -178,6 +185,8 @@ void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
std::string MirrorOpName();
CommInfo GetCommInfo();
void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages);
} // namespace parallel
} // namespace mindspore

View File

@@ -18,6 +18,7 @@
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
#include "minddata/dataset/kernels/ir/vision/auto_contrast_ir.h"
#include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
#include "minddata/dataset/kernels/ir/vision/center_crop_ir.h"
@@ -67,6 +68,17 @@
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(
AdjustGammaOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::AdjustGammaOperation, TensorOperation, std::shared_ptr<vision::AdjustGammaOperation>>(
*m, "AdjustGammaOperation")
.def(py::init([](float gamma, float gain) {
auto adjust_gamma = std::make_shared<vision::AdjustGammaOperation>(gamma, gain);
THROW_IF_ERROR(adjust_gamma->ValidateParams());
return adjust_gamma;
}));
}));
PYBIND_REGISTER(
AutoContrastOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::AutoContrastOperation, TensorOperation, std::shared_ptr<vision::AutoContrastOperation>>(

View File

@@ -21,6 +21,7 @@
#endif
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
#include "minddata/dataset/kernels/ir/vision/affine_ir.h"
#include "minddata/dataset/kernels/ir/vision/auto_contrast_ir.h"
#include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
@@ -118,6 +119,19 @@ std::shared_ptr<TensorOperation> Affine::Parse() {
}

#ifndef ENABLE_ANDROID
// AdjustGamma Transform Operation.
struct AdjustGamma::Data {
Data(float gamma, float gain) : gamma_(gamma), gain_(gain) {}
float gamma_;
float gain_;
};
AdjustGamma::AdjustGamma(float gamma, float gain) : data_(std::make_shared<Data>(gamma, gain)) {}
std::shared_ptr<TensorOperation> AdjustGamma::Parse() {
return std::make_shared<AdjustGammaOperation>(data_->gamma_, data_->gain_);
}
// AutoContrast Transform Operation.
struct AutoContrast::Data {
Data(float cutoff, const std::vector<uint32_t> &ignore) : cutoff_(cutoff), ignore_(ignore) {}

View File

@@ -57,7 +57,31 @@ class AutoContrast final : public TensorTransform {
std::shared_ptr<Data> data_;
};

-/// \brief Apply a given image transform on a random selection of bounding box regions of a given image.
/// \brief AdjustGamma TensorTransform.
/// \note Apply gamma correction on input image.
class AdjustGamma final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] gamma Non negative real number, which makes the output image pixel value
/// exponential in relation to the input image pixel value.
/// \param[in] gain The constant multiplier.
explicit AdjustGamma(float gamma, float gain = 1);
/// \brief Destructor.
~AdjustGamma() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief BoundingBoxAugment TensorTransform.
/// \note Apply a given image transform on a random selection of bounding box regions of a given image.
class BoundingBoxAugment final : public TensorTransform {
public:
/// \brief Constructor.
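A usage sketch for the new C++ API (the include path and the surrounding dataset pipeline setup are assumptions; only the transform object itself comes from this header):

    #include "minddata/dataset/include/dataset/vision.h"

    // gamma > 1 darkens mid-tones, gamma < 1 brightens them; gain is a plain multiplier.
    mindspore::dataset::vision::AdjustGamma adjust_gamma(2.2f, 1.0f);
    // The object is then passed to a dataset's Map operation like any other TensorTransform.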

View File

@@ -6,6 +6,7 @@ if(ENABLE_ACL)
add_subdirectory(dvpp)
endif()
add_library(kernels-image OBJECT
adjust_gamma_op.cc
affine_op.cc
auto_contrast_op.cc
bounding_box.cc

View File

@@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/adjust_gamma_op.h"
#include <memory>
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/kernels/image/image_utils.h"
namespace mindspore {
namespace dataset {
const float AdjustGammaOp::kGain = 1.0;
Status AdjustGammaOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
// typecast
CHECK_FAIL_RETURN_UNEXPECTED(input->type() != DataType::DE_STRING,
"AdjustGamma: input tensor type should be [int, float, double], but got string.");
if (input->type().IsFloat()) {
std::shared_ptr<Tensor> input_tensor;
RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)));
return AdjustGamma(input_tensor, output, gamma_, gain_);
} else {
return AdjustGamma(input, output, gamma_, gain_);
}
}
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_ADJUST_GAMMA_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_ADJUST_GAMMA_OP_H_
#include <memory>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class AdjustGammaOp : public TensorOp {
public:
/// Default gain to be used
static const float kGain;
AdjustGammaOp(const float &gamma, const float &gain) : gamma_(gamma), gain_(gain) {}
~AdjustGammaOp() override = default;
/// Provide stream operator for displaying it
friend std::ostream &operator<<(std::ostream &out, const AdjustGammaOp &so) {
so.Print(out);
return out;
}
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kAdjustGammaOp; }
private:
float gamma_;
float gain_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_ADJUST_GAMMA_OP_H_

View File

@@ -872,6 +872,64 @@ Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tens
return Status::OK();
}
Status AdjustGamma(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &gamma,
const float &gain) {
try {
int num_channels = 1;
if (input->Rank() < 2) {
RETURN_STATUS_UNEXPECTED("AdjustGamma: image shape is not <...,H,W,C> or <H,W>.");
}
if (input->Rank() > 2) {
num_channels = input->shape()[-1];
}
if (num_channels != 1 && num_channels != 3) {
RETURN_STATUS_UNEXPECTED("AdjustGamma: channel of input image should be 1 or 3.");
}
if (input->type().IsFloat()) {
for (auto itr = input->begin<float>(); itr != input->end<float>(); itr++) {
*itr = pow((*itr) * gain, gamma);
*itr = std::min(std::max((*itr), 0.0f), 1.0f);
}
*output = input;
} else {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("AdjustGamma: load image failed.");
}
cv::Mat input_img = input_cv->mat();
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
uchar LUT[256] = {};
for (int i = 0; i < 256; i++) {
float f = i / 255.0;
f = pow(f, gamma);
LUT[i] = static_cast<uchar>(floor(std::min(f * (255.0 + 1 - 1e-3) * gain, 255.0)));
}
if (input_img.channels() == 1) {
cv::MatIterator_<uchar> it = input_img.begin<uchar>();
cv::MatIterator_<uchar> it_end = input_img.end<uchar>();
for (; it != it_end; ++it) {
*it = LUT[(*it)];
}
} else {
cv::MatIterator_<cv::Vec3b> it = input_img.begin<cv::Vec3b>();
cv::MatIterator_<cv::Vec3b> it_end = input_img.end<cv::Vec3b>();
for (; it != it_end; ++it) {
(*it)[0] = LUT[(*it)[0]];
(*it)[1] = LUT[(*it)[1]];
(*it)[2] = LUT[(*it)[2]];
}
}
output_cv->mat() = input_img * 1;
*output = std::static_pointer_cast<Tensor>(output_cv);
}
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("AdjustGamma: " + std::string(e.what()));
}
return Status::OK();
}
Status AutoContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &cutoff,
const std::vector<uint32_t> &ignore) {
try {
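For integer images the kernel above precomputes a 256-entry lookup table: each level is normalized to [0, 1], raised to gamma, rescaled by just under 256 (so that 255 maps back to 255), and clamped. One entry recomputed in isolation:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Same formula as the LUT loop in the AdjustGamma hunk above.
    uint8_t GammaLutEntry(int i, float gamma, float gain) {
      float f = static_cast<float>(i) / 255.0f;
      f = std::pow(f, gamma);
      return static_cast<uint8_t>(std::floor(std::min(f * (255.0f + 1 - 1e-3f) * gain, 255.0f)));
    }
    // With gamma = 2.2 and gain = 1: entry 128 maps to 56 (mid-gray darkens),
    // while entry 255 stays 255.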

View File

@@ -234,6 +234,16 @@ Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tens
Status AutoContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &cutoff,
const std::vector<uint32_t> &ignore);
/// \brief Returns image with gamma correction.
/// \param[in] input: Tensor of shape <H,W,3>/<H,W,1>/<H,W> in RGB/Grayscale and any OpenCV compatible type,
/// see CVTensor.
/// \param[in] gamma: Non-negative real number, same as gamma in the equation. A gamma larger than 1 makes the shadows
/// darker, while a gamma smaller than 1 makes dark regions lighter.
/// \param[in] gain: The constant multiplier.
/// \param[out] output: Adjusted image of same shape and type.
Status AdjustGamma(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &gamma,
const float &gain);
/// \brief Returns image with adjusted saturation.
/// \param input: Tensor of shape <H,W,3> in RGB order and any OpenCv compatible type, see CVTensor.
/// \param alpha: Alpha value to adjust saturation by. Should be a positive number.

View File

@@ -38,6 +38,11 @@ Status ValidateFloatScalarPositive(const std::string &op_name, const std::string
return Status::OK();
}
Status ValidateFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, false));
return Status::OK();
}
Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value) {
if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) {
std::string err_msg =

View File

@@ -36,6 +36,9 @@ Status ValidateIntScalarPositive(const std::string &op_name, const std::string &
// Helper function to validate a positive float scalar
Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar);
// Helper function to validate a non-negative float scalar
Status ValidateFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar);
// Helper function to validate scalar
template <typename T>
Status ValidateScalar(const std::string &op_name, const std::string &scalar_name, const T scalar,

View File

@@ -2,6 +2,7 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
set(DATASET_KERNELS_IR_VISION_SRC_FILES
adjust_gamma_ir.cc
affine_ir.cc
auto_contrast_ir.cc
bounding_box_augment_ir.cc

View File

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/adjust_gamma_op.h"
#endif
#include "minddata/dataset/kernels/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace vision {
#ifndef ENABLE_ANDROID
// AdjustGammaOperation
AdjustGammaOperation::AdjustGammaOperation(float gamma, float gain) : gamma_(gamma), gain_(gain) {}
Status AdjustGammaOperation::ValidateParams() {
// gamma
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("AdjustGamma", "gamma", gamma_));
return Status::OK();
}
std::shared_ptr<TensorOp> AdjustGammaOperation::Build() {
std::shared_ptr<AdjustGammaOp> tensor_op = std::make_shared<AdjustGammaOp>(gamma_, gain_);
return tensor_op;
}
Status AdjustGammaOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["gamma"] = gamma_;
args["gain"] = gain_;
*out_json = args;
return Status::OK();
}
#endif
} // namespace vision
} // namespace dataset
} // namespace mindspore

View File

@@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_ADJUST_GAMMA_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_ADJUST_GAMMA_IR_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace vision {
constexpr char kAdjustGammaOperation[] = "AdjustGamma";
class AdjustGammaOperation : public TensorOperation {
public:
explicit AdjustGammaOperation(float gamma, float gain);
~AdjustGammaOperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override { return kAdjustGammaOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float gamma_;
float gain_;
};
} // namespace vision
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VISION_ADJUST_GAMMA_IR_H_

View File

@@ -53,6 +53,7 @@ namespace dataset {
constexpr char kTensorOp[] = "TensorOp";

// image
constexpr char kAdjustGammaOp[] = "AdjustGammaOp";
constexpr char kAffineOp[] = "AffineOp";
constexpr char kAutoContrastOp[] = "AutoContrastOp";
constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";

View File

@@ -593,9 +593,18 @@ bool TaskEmitAction(const ResourcePtr &res) {
context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
} else if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-if (device_target == kAscendDevice && backend != kMsVm) {
auto manager = func_graph->manager();
auto graphs = manager->func_graphs();
bool exist_while =
std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); });
if (device_target == kAscendDevice && backend != kMsVm && !exist_while) {
MS_LOG(INFO) << "Run graph mode with multigraph sink.";
bc_ptr->set_is_multi_graph_sink(true);
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
} else {
MS_LOG(INFO) << "Run graph mode with vm.";
bc_ptr->set_is_multi_graph_sink(false);
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
}
}

View File

@@ -290,6 +290,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.merge_addn_,
irpass.float_tuple_getitem_switch_,
irpass.float_env_getitem_switch_,
irpass.inline_,
irpass.incorporate_getitem_set_,
irpass.incorporate_call_,
irpass.incorporate_call_switch_,

View File

@@ -142,20 +142,21 @@ std::string GetCompileExceptionInfo() {
return oss.str();
}

-void SetGpuLoopSink(const ResourcePtr &resource) {
void SetLoopCount(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
auto func_graph = resource->func_graph();
if (func_graph != nullptr && func_graph->manager() != nullptr) {
auto manager = func_graph->manager();
size_t graph_nums = manager->func_graphs().size();
-int64_t sinksize = ConfigManager::GetInstance().iter_num();
int64_t loop_size = ConfigManager::GetInstance().iter_num();
-if (graph_nums == 1 || MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
-resource->set_gpu_loopsink(true, sinksize);
-} else {
-resource->set_gpu_loopsink(false, sinksize);
const auto context_ptr = MsContext::GetInstance();
if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
resource->set_vm_loop(!context_ptr->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK), loop_size);
} else if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice) {
bool run_with_mind_rt = graph_nums == 1 || context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT);
resource->set_vm_loop(!run_with_mind_rt, loop_size);
}
-MS_LOG(INFO) << "Change gpu_loopsink_flag_ to " << resource->gpu_loopsink_flag() << ", set loopsink size to "
-<< sinksize;
MS_LOG(INFO) << "Change vm_loop_flag to " << resource->vm_loop_flag() << ", set loop_size to " << loop_size;
}
}
@@ -206,7 +207,8 @@ void CacheFuncGraph(const ResourcePtr &resource) {
ChangeFileMode(realpath.value(), S_IRWXU);
std::ofstream fout(realpath.value());
if (!fout.is_open()) {
-MS_LOG(EXCEPTION) << "Open cache file '" << realpath.value() << "' failed!";
MS_LOG(EXCEPTION) << "Open cache file '" << realpath.value() << "' failed!"
<< " Errno:" << errno << " ErrInfo:" << strerror(errno);
}
FuncGraphPtr fg = resource->func_graph();
mind_ir::ModelProto fg_model = GetBinaryProto(fg, true);
@ -825,7 +827,7 @@ void Pipeline::Run(const std::string &phase_s) {
   MS_LOG(DEBUG) << "Action " << action.first << " end.";
 };
 if (action.first == "task_emit") {
-  SetGpuLoopSink(resource_);
+  SetLoopCount(resource_);
 } else if (action.first == "validate") {
   CacheValidateFuncGraph(phase_s, resource_);
 }
@ -1001,13 +1003,17 @@ py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) {
   MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase_s;
 }
 // Set loopsink size for each phase.
-bool is_loopsink = info_[phase_s]->resource->gpu_loopsink_flag();
-int64_t sinksize = info_[phase_s]->resource->gpu_loopsink_size();
-ConfigManager::GetInstance().set_gpu_loopsink_size(is_loopsink ? sinksize : 1);
-// If target is not gpu or is loopsink, keep vmloop 1.
-bool g = (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
-int64_t vm_loop = (!g || is_loopsink) ? 1 : sinksize;
-MS_LOG(INFO) << "VM loop size " << vm_loop << ", loopsink size " << (is_loopsink ? sinksize : 1);
+bool vm_loop_flag = info_[phase_s]->resource->vm_loop_flag();
+int64_t loop_size = info_[phase_s]->resource->loop_size();
+int64_t vm_loop = 1;
+if (vm_loop_flag) {
+  vm_loop = loop_size;
+} else {
+  // Set the loop size in config if graphs nums is 1(is_loop_sin=True), then there will be a loop embrace
+  // 'Execute(graph)' in GPUSession.
+  ConfigManager::GetInstance().set_gpu_loopsink_size(loop_size);
+}
+MS_LOG(INFO) << "VM loop size " << vm_loop << ", loopsink size " << vm_loop;
 py::object ret;
 MS_LOG(DEBUG) << "Eval run" << backend;
 for (int64_t i = 0; i < vm_loop; i++) {
@ -1157,9 +1163,6 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
   // Convert CNodeList to LinConvertResult.
   auto segment = std::make_shared<GraphSegment>(std::vector<AnfNodePtr>{app_init}, false);
   auto runner = convert_fn(segment, "");
-  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
-    backend->Link(runner.graph_id);
-  }
   ConfigManager::GetInstance().set_iter_num(size);
   // PS cache does not support loop sink.
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
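For context, the reworked loop handling in `ExecutorPy::Run` above reduces to a single dispatch: when `vm_loop_flag` is set the host VM drives the iterations itself, otherwise the whole `loop_size` is handed to the GPU loop sink and the device iterates internally. A hedged, self-contained sketch of that dispatch (the names are illustrative stand-ins, not the real pipeline API):

#include <cstdint>
#include <functional>
#include <iostream>

// Illustrative only: run_graph stands in for one device-side execution,
// Execute for the tail of ExecutorPy::Run.
void Execute(bool vm_loop_flag, int64_t loop_size, const std::function<void()> &run_graph) {
  int64_t vm_loop = 1;
  int64_t sink_size = 1;
  if (vm_loop_flag) {
    vm_loop = loop_size;    // the host VM drives every iteration
  } else {
    sink_size = loop_size;  // the device loops internally (loop sink)
  }
  std::cout << "VM loop size " << vm_loop << ", loopsink size " << sink_size << '\n';
  for (int64_t i = 0; i < vm_loop; ++i) {
    run_graph();
  }
}

int main() {
  Execute(true, 3, [] { std::cout << "host-driven step\n"; });
  Execute(false, 3, [] { std::cout << "single sunk launch\n"; });
  return 0;
}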

View File

@ -75,12 +75,12 @@ class Resource : public ResourceBase {
   const abstract::AbstractBasePtrList &args_spec() const { return args_spec_; }
   void set_args_spec(const abstract::AbstractBasePtrList &args_spec) { args_spec_ = args_spec; }
-  void set_gpu_loopsink(const bool &flag, const int64_t size) {
-    gpu_loopsink_flag_ = flag;
-    gpu_loopsink_size_ = size;
+  void set_vm_loop(const bool &flag, const int64_t size) {
+    vm_loop_flag_ = flag;
+    loop_size_ = size;
   }
-  bool gpu_loopsink_flag() { return gpu_loopsink_flag_; }
-  int64_t gpu_loopsink_size() { return gpu_loopsink_size_; }
+  bool vm_loop_flag() { return vm_loop_flag_; }
+  int64_t loop_size() { return loop_size_; }
   // Reclaim resource and clear the cache.
   // ExecutorPy::Compile() can be called multiple times, so cache
   // should be cleared.
@ -92,8 +92,8 @@ class Resource : public ResourceBase {
   abstract::AbstractBasePtrList args_spec_;
   py::object input_;
   bool is_cleaned_;
-  bool gpu_loopsink_flag_{false};
-  int64_t gpu_loopsink_size_{1};
+  bool vm_loop_flag_{false};
+  int64_t loop_size_{1};
 };
 using ResourcePtr = std::shared_ptr<pipeline::Resource>;

View File

@ -76,15 +76,17 @@ class OrderEnforcer {
   }
 }
-bool CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
+std::unordered_set<AnfNodePtr> CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::unordered_set<AnfNodePtr> loads;
   auto inputs = cnode->inputs();
   for (size_t index = 1; index < inputs.size(); index++) {
     auto input = cnode->input(index);
     if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
-      return true;
+      loads.insert(input);
     }
   }
-  return false;
+  return loads;
 }
 std::vector<AnfNodePtr> FindUpdateStateUsers(const CNodePtr &cnode) {
@ -155,23 +157,31 @@ class OrderEnforcer {
 // u3 = UpdateState(u', maketuple2, addn) # need put addn or other-op into u3 inputs
 // assign = Assign(para2, inputs, u3)
 void HandleMakeTupleUsers(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
   auto maketuple = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(maketuple);
-  if (CheckMakeTupleHaveLoad(maketuple)) {
+  std::unordered_set<AnfNodePtr> loads = CheckMakeTupleHaveLoad(maketuple);
+  if (!loads.empty()) {
     auto update_state = FindLastUpdateState(maketuple);
     if (update_state != nullptr) {
       std::unordered_set<AnfNodePtr> maketuple_users = GetSpecialOperatorRealUsers(maketuple);
-      std::unordered_set<AnfNodePtr> no_push_maketuple_users;
+      std::unordered_set<AnfNodePtr> no_push_all_users;
       // Push and Pull at the end of the execution order,
       // In order to ensure push and pull operator cut into the same graph, do not put push operator into updatestate
       for (auto maketuple_user : maketuple_users) {
         if (!IsPrimitiveCNode(maketuple_user, prim::kPrimPush)) {
-          no_push_maketuple_users.insert(maketuple_user);
+          no_push_all_users.insert(maketuple_user);
+        }
+      }
+      for (auto load : loads) {
+        std::unordered_set<AnfNodePtr> load_users = GetSpecialOperatorRealUsers(load);
+        for (auto load_user : load_users) {
+          no_push_all_users.insert(load_user);
         }
       }
       auto update_state_cnode = update_state->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(update_state_cnode);
-      AddInputEdges(update_state_cnode, no_push_maketuple_users);
+      AddInputEdges(update_state_cnode, no_push_all_users);
     }
   }
 }
@ -265,6 +275,8 @@ class OrderEnforcer {
 // Add load users as input edges of the update_state node.
 void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) {
   auto sorted_load_users = SortLoadUsers(load_users);
+  auto inputs = update_state->inputs();
+  size_t origin_size = inputs.size();
   for (auto &load_user : sorted_load_users) {
     if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
       continue;
@ -272,10 +284,16 @@ class OrderEnforcer {
     if (!IsDependOn(load_user, update_state)) {
       processed_nodes_.insert(load_user);
       if (!IsInUpdateState(load_user, update_state)) {
-        manager_->AddEdge(update_state, load_user);
+        inputs.emplace_back(load_user);
       }
     }
   }
+  if (inputs.size() > origin_size) {
+    auto new_update_state = func_graph_->NewCNode(inputs);
+    new_update_state->set_abstract(update_state->abstract());
+    new_update_state->set_scope(update_state->scope());
+    manager_->Replace(update_state, new_update_state);
+  }
 }
 // Sort load users by their topo sort order.
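A side note on the reworked `AddInputEdges`: instead of mutating the `UpdateState` node edge by edge via `manager_->AddEdge`, the pass now collects all extra inputs and performs one `Replace` with a freshly built node, and only if something was actually added. A generic sketch of that copy-extend-swap idiom, with standard-library types standing in for `CNode`:

#include <iostream>
#include <string>
#include <vector>

// Stand-in for "rebuild the node with extended inputs, then swap it in".
std::vector<std::string> ExtendInputs(const std::vector<std::string> &node,
                                      const std::vector<std::string> &candidates) {
  std::vector<std::string> inputs = node;  // copy; never mutate the live node
  size_t origin_size = inputs.size();
  for (const auto &c : candidates) {
    inputs.push_back(c);
  }
  if (inputs.size() > origin_size) {
    return inputs;  // stands in for NewCNode + manager Replace
  }
  return node;      // nothing added, keep the old node
}

int main() {
  auto node = ExtendInputs({"UpdateState", "u", "maketuple"}, {"addn"});
  std::cout << node.size() << " inputs\n";  // 4
  return 0;
}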

View File

@ -47,7 +47,6 @@
 #include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
 #include "debug/env_config_parser.h"
 #endif
-#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h"
 #include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"
 #include "utils/config_manager.h"
 #include "runtime/device/ascend/profiling/reporter/op_name_task_stream_reporter.h"
@ -250,6 +249,8 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
   MS_EXCEPTION_IF_NULL(context_ptr);
   uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
+  // DestroyHccl must be called before FreeDeviceMemory
+  (void)DestroyHccl();
   if (mem_manager_ != nullptr) {
     mem_manager_->FreeDeviceMemory();
   }
@ -259,7 +260,6 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
     MS_LOG(EXCEPTION) << "Reg SetTaskFailCallback failed, error: " << rt_ret;
   }
-  (void)DestroyHccl();
   (void)ResetDevice(device_id);
   (void)ProfilingManager::GetInstance().StopProfiling();
   current_graph_ = nullptr;
@ -291,9 +291,7 @@ bool AscendKernelRuntime::Init() {
     MS_LOG(WARNING) << "Init ErrorManager failed.";
   }
   try {
-    OpTilingCalculater::GetInstance().Init();
     // Start up profiling before rtSetDevice
     bool ret = InitDevice();
     if (!ret) {
       return ret;
@ -876,7 +874,7 @@ bool AscendKernelRuntime::HcclInit() {
   }
   MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str;
   bool ret = hccl::HcclAdapter::GetInstance().InitHccl(context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID), rank_id_str,
-                                                       full_path);
+                                                       full_path, mode == kGraphMode);
   free(full_path);
   if (!ret) {
     MS_LOG(ERROR) << "Hcom init failed.";

View File

@ -1992,6 +1992,28 @@ CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull<KernelGraphPtr>
   return recv_node_ptr;
 }
+bool AscendStreamAssign::IsNopNodeTarget(const AnfNodePtr &nop_node, const CNodePtr &target_node,
+                                         const CNodePtr &cur_node, bool exclude_hcom) {
+  MS_EXCEPTION_IF_NULL(nop_node);
+  auto cnode = nop_node->cast<CNodePtr>();
+  auto new_inputs = cnode->inputs();
+  for (size_t i = 1; i < new_inputs.size(); i++) {
+    if (opt::IsNopNode(new_inputs[i])) {
+      if (IsNopNodeTarget(new_inputs[i], target_node, cur_node, exclude_hcom)) {
+        return true;
+      }
+    } else {
+      auto new_real_input = AnfAlgo::VisitKernel(new_inputs[i], 0);
+      if (target_node == new_real_input.first) {
+        if (!(exclude_hcom && IsHcom(cur_node))) {
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
 vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::iterator begin,
                                                             vector<CNodePtr>::iterator end, const CNodePtr &node,
                                                             bool exclude_hcom) {
@ -2000,19 +2022,9 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it
   for (size_t i = 1; i < inputs.size(); i++) {
     auto input = inputs[i];
     if (opt::IsNopNode(input)) {
-      CNodePtr cnode = input->cast<CNodePtr>();
-      auto new_inputs = cnode->inputs();
-      for (size_t j = 1; j < new_inputs.size(); j++) {
-        auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0);
-        // find target node except hcom op. insert event for hcom in:InsertEventHcomDependCommonBak function
-        // only insert one time
-        if (node == new_real_input.first) {
-          if (!(exclude_hcom && IsHcom(*begin))) {
-            MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]";
-            return begin;
-          }
-        }
-      }
+      if (IsNopNodeTarget(input, node, *begin, exclude_hcom)) {
+        return begin;
+      }
     } else {
       auto real_input = AnfAlgo::VisitKernel(input, 0);
       if (node == real_input.first) {
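The new `IsNopNodeTarget` replaces the old single-level scan with a recursive walk, so a chain of nop nodes (a nop feeding another nop) can no longer hide the real producer from `FindTargetOp`. The same look-through pattern in miniature, on a plain tree rather than MindSpore's IR (all types here are stand-ins):

#include <iostream>
#include <memory>
#include <vector>

struct Node {
  bool is_nop;  // pass-through node: look through it
  int id;       // identity of a real producer
  std::vector<std::shared_ptr<Node>> inputs;
};

// Recursively search through nop nodes for a real input with the target id.
bool HasRealInput(const Node &node, int target_id) {
  for (const auto &in : node.inputs) {
    if (in->is_nop) {
      if (HasRealInput(*in, target_id)) {
        return true;
      }
    } else if (in->id == target_id) {
      return true;
    }
  }
  return false;
}

int main() {
  auto real = std::make_shared<Node>(Node{false, 42, {}});
  auto nop_inner = std::make_shared<Node>(Node{true, 0, {real}});
  auto nop_outer = std::make_shared<Node>(Node{true, 0, {nop_inner}});
  Node consumer{false, 1, {nop_outer}};
  std::cout << std::boolalpha << HasRealInput(consumer, 42) << '\n';  // true: found through two nops
  return 0;
}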

View File

@ -175,7 +175,8 @@ class AscendStreamAssign {
   uint32_t GetIndexByKey(const NotNull<KernelGraphPtr> &graph_ptr, const CNodeKey &key);
   uint32_t GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr);
   void GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &graph_ptr);
+  bool IsNopNodeTarget(const AnfNodePtr &nop_node, const CNodePtr &target_node, const CNodePtr &cur_node,
+                       bool exclude_hcom);
   bool IsTaskSink();
   bool IsHcom(const CNodePtr &cur_cnode_ptr);
   bool IsIndependentNode(const CNodePtr &node_ptr);

View File

@ -133,9 +133,11 @@ void DataDumper::SetOpMappingInfo(NotNull<aicpu::dump::OpMappingInfo *> dump_inf
   }
   uint32_t graph_id = kernel_graph_->graph_id();
   uint32_t rank_id = 0;
-  auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
   auto env_rank_id = common::GetEnv("RANK_ID");
-  if (!(env_table_file.empty() || env_rank_id.empty())) {
+  if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
     // get actual rank id if it's distribution training case.
     if (!CommManager::GetInstance().GetRankID(kHcclWorldGroup, &rank_id)) {
       MS_LOG(INFO) << "Failed to get rank id.";

View File

@ -19,12 +19,12 @@
 #include <memory>
 #include "framework/common/debug/log.h"
 #include "utils/log_adapter.h"
-#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h"
 #include "register/op_tiling.h"
 #include "utils/convert_utils_base.h"
 #include "utils/ms_context.h"
 #include "runtime/device/kernel_runtime_manager.h"
 #include "pipeline/jit/static_analysis/static_analysis.h"
+#include "runtime/device/ascend/executor/tiling/op_tiling_adapter.h"
 #include "common/trans.h"
 namespace mindspore {
@ -131,14 +131,17 @@ void AiCoreDynamicKernel::ComputeTiling() {
   auto cnode = cnode_ptr_.lock();
   MS_EXCEPTION_IF_NULL(cnode);
   MS_LOG(INFO) << "Start compute tiling of:" << cnode->fullname_with_scope();
-  optiling::OpRunInfo op_run_info;
-  OpTilingCalculater::GetInstance().CalculateTiling(NOT_NULL(cnode), op_compile_info_, depend_tensor_map_,
-                                                    NOT_NULL(&op_run_info));
-  block_dim_ = op_run_info.block_dim;
-  workspaces_size_ = op_run_info.workspaces;
-  tiling_data_ = op_run_info.tiling_data.str();
-  tiling_key_ = op_run_info.tiling_key;
+  // start compute tiling
+  optiling::utils::OpRunInfo op_run_info_v2(-1, true, 0);
+  tiling::OpTilingCalculateAdapter converter;
+  ge::ComputeGraphPtr ge_graph = std::make_shared<ge::ComputeGraph>("default");
+  auto ge_node = converter.AnfNodeToGeNodeAdapter(cnode, &ge_graph, depend_tensor_map_);
+  (void)optiling::OpParaCalculateV2(*ge_node, op_run_info_v2);
+  block_dim_ = op_run_info_v2.GetBlockDim();
+  op_run_info_v2.GetAllWorkspaces(workspaces_size_);
+  tiling_data_ = op_run_info_v2.GetAllTilingData().str();
+  tiling_key_ = op_run_info_v2.GetTilingKey();
 }
 void AiCoreDynamicKernel::AllocateWorkspace() {
void AiCoreDynamicKernel::AllocateWorkspace() { void AiCoreDynamicKernel::AllocateWorkspace() {

View File

@ -0,0 +1,292 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/executor/tiling/op_tiling_adapter.h"
#include <algorithm>
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/ascend/ge_types_convert.h"
#include "utils/utils.h"
#include "external/graph/tensor.h"
#include "external/register/op_tiling_registry.h"
#include "graph/utils/graph_utils.h"
#include "common/ge_inner_error_codes.h"
#include "graph/utils/op_desc_utils.h"
namespace mindspore {
namespace device {
namespace tiling {
constexpr auto COMPILE_INFO_KEY = "compile_info_key";
constexpr auto COMPILE_INFO_JSON = "compile_info_json";
constexpr auto ATOMIC_COMPILE_INFO_KEY = "_atomic_compile_info_key";
constexpr auto ATOMIC_COMPILE_INFO_JSON = "_atomic_compile_info_json";
constexpr auto ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends";
constexpr auto CONSTANTOP = "Constant";
constexpr auto ATTR_NAME_WEIGHTS = "value";
constexpr auto PARAM_DYNAMIC = "dynamic";
std::string OpTilingCalculateAdapter::GetRealOpType(const std::string &op_type) {
static const std::map<std::string, std::string> kOpTypeMap = {
{"SparseApplyFtrl", "SparseApplyFtrlD"},
{"SparseApplyProximalAdagrad", "SparseApplyProximalAdagradD"},
{"SparseGatherV2", "Gather"},
{"Pad", "PadD"},
{"Concat", "ConcatD"},
{"Softmax", "SoftmaxV2"},
{"DropoutDoMask", "DropOutDoMask"},
};
auto iter = kOpTypeMap.find(op_type);
if (iter == kOpTypeMap.end()) {
return op_type;
}
return iter->second;
}
std::string OpTilingCalculateAdapter::GetOutputName(const CNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
if (index >= output_names_.size()) {
return "unknown_name";
}
return output_names_[index];
}
std::string OpTilingCalculateAdapter::GetInputName(const CNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
if (index >= input_names_.size()) {
return "unknown_name";
}
return input_names_[index];
}
void OpTilingCalculateAdapter::ConvertInputShapeAndType(const CNodePtr &node, ge::OpDescPtr *op_desc) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(*op_desc);
auto input_size = AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < input_size; i++) {
// ms info
auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(node, i);
auto input_node = input_node_with_idx.first;
auto input_index = input_node_with_idx.second;
auto ms_ori_shape = AnfAlgo::GetOutputInferShape(input_node, input_index);
auto ms_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index);
auto ms_format = AnfAlgo::GetOutputFormat(input_node, input_index);
auto ms_dtype = AnfAlgo::GetOutputDeviceDataType(input_node, input_index);
// ge info
std::vector<int64_t> ge_shape;
std::vector<int64_t> ge_ori_shape;
ge::DataType ge_dtype = ascend::GeTypesConvert::TransTypeIdToGeDataType(ms_dtype);
ge::Format ge_format = ascend::GeTypesConvert::GetGeFormat(ms_format, ms_shape.size());
std::transform(ms_shape.begin(), ms_shape.end(), std::back_inserter(ge_shape),
[](size_t s) { return static_cast<int64_t>(s); });
std::transform(ms_ori_shape.begin(), ms_ori_shape.end(), std::back_inserter(ge_ori_shape),
[](size_t s) { return static_cast<int64_t>(s); });
auto input_name = GetInputName(node, i);
ge::GeTensorDesc ge_tensor_desc;
ge_tensor_desc.SetFormat(ge_format);
ge_tensor_desc.SetDataType(ge_dtype);
ge_tensor_desc.SetShape(ge::GeShape(ge_shape));
ge_tensor_desc.SetOriginShape(ge::GeShape(ge_ori_shape));
ge_tensor_desc.SetName(input_name);
(void)(*op_desc)->AddInputDesc(input_name, ge_tensor_desc);
}
}
void OpTilingCalculateAdapter::ConvertOutputShapeAndType(const CNodePtr &node, ge::OpDescPtr *op_desc) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(*op_desc);
auto output_size = AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < output_size; i++) {
auto ms_shape = AnfAlgo::GetOutputDeviceShape(node, i);
auto ms_ori_shape = AnfAlgo::GetOutputInferShape(node, i);
auto ms_format = AnfAlgo::GetOutputFormat(node, i);
auto ms_dtype = AnfAlgo::GetOutputDeviceDataType(node, i);
std::vector<int64_t> ge_shape;
std::vector<int64_t> ge_ori_shape;
ge::DataType ge_dtype = ascend::GeTypesConvert::TransTypeIdToGeDataType(ms_dtype);
ge::Format ge_format = ascend::GeTypesConvert::GetGeFormat(ms_format, ms_shape.size());
std::transform(ms_shape.begin(), ms_shape.end(), std::back_inserter(ge_shape),
[](size_t s) { return static_cast<int64_t>(s); });
std::transform(ms_ori_shape.begin(), ms_ori_shape.end(), std::back_inserter(ge_ori_shape),
[](size_t s) { return static_cast<int64_t>(s); });
auto output_name = GetOutputName(node, i);
ge::GeTensorDesc ge_tensor_desc;
ge_tensor_desc.SetFormat(ge_format);
ge_tensor_desc.SetDataType(ge_dtype);
ge_tensor_desc.SetShape(ge::GeShape(ge_shape));
ge_tensor_desc.SetOriginShape(ge::GeShape(ge_ori_shape));
ge_tensor_desc.SetName(output_name);
(void)(*op_desc)->AddOutputDesc(output_name, ge_tensor_desc);
}
}
void OpTilingCalculateAdapter::ConvertCompileInfo(const CNodePtr &node, ge::OpDescPtr *op_desc) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(*op_desc);
if (!AnfAlgo::HasNodeAttr(kAttrCompileInfo, node)) {
MS_LOG(EXCEPTION) << "Get compile_info failed";
}
auto compile_info_attr = AnfAlgo::GetNodeAttr<std::string>(node, kAttrCompileInfo);
MS_LOG(INFO) << "For op " << op_name_ << ", get compile_info: " << compile_info_attr;
std::string compile_info_key = std::to_string(std::hash<std::string>()(compile_info_attr));
(void)ge::AttrUtils::SetStr(*(*op_desc), COMPILE_INFO_KEY, compile_info_key);
(void)ge::AttrUtils::SetStr(*(*op_desc), COMPILE_INFO_JSON, compile_info_attr);
}
ge::NodePtr OpTilingCalculateAdapter::NewConstantOp(const CNodePtr &node, const std::string &name,
const tensor::TensorPtr &tensor_data, ge::ComputeGraphPtr *ge_graph,
size_t index) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(*ge_graph);
MS_EXCEPTION_IF_NULL(tensor_data);
ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>(name, CONSTANTOP);
auto ms_format = AnfAlgo::GetPrevNodeOutputFormat(node, index);
auto ms_ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
auto ms_dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(node, index);
auto ms_shape = AnfAlgo::GetInputDeviceShape(node, index);
std::vector<int64_t> ge_shape;
std::vector<int64_t> ge_ori_shape;
std::transform(ms_shape.begin(), ms_shape.end(), std::back_inserter(ge_shape),
[](size_t s) { return static_cast<int64_t>(s); });
std::transform(ms_ori_shape.begin(), ms_ori_shape.end(), std::back_inserter(ge_ori_shape),
[](size_t s) { return static_cast<int64_t>(s); });
ge::DataType ge_dtype = ascend::GeTypesConvert::TransTypeIdToGeDataType(ms_dtype);
ge::Format ge_format = ascend::GeTypesConvert::GetGeFormat(ms_format, ms_shape.size());
ge::GeTensorDesc ge_tensor_desc;
ge_tensor_desc.SetFormat(ge_format);
ge_tensor_desc.SetDataType(ge_dtype);
ge_tensor_desc.SetShape(ge::GeShape(ge_shape));
ge_tensor_desc.SetOriginShape(ge::GeShape(ge_ori_shape));
ge_tensor_desc.SetName(name);
ge::GeTensorPtr ge_tensor = std::make_shared<ge::GeTensor>(
ge_tensor_desc, static_cast<uint8_t *>(tensor_data->data_c()), IntToSize(tensor_data->DataSize()));
(void)op_desc->AddOutputDesc(name, ge_tensor_desc);
ge::NodePtr constant_op = (*ge_graph)->AddNode(op_desc);
ge::OpDescUtils::SetWeights(constant_op, {ge_tensor});
(void)ge::AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, ge_tensor);
return constant_op;
}
std::vector<std::tuple<std::size_t, ge::NodePtr>> OpTilingCalculateAdapter::ConvertDepends(
const CNodePtr &node, const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map, ge::OpDescPtr *op_desc,
ge::ComputeGraphPtr *ge_graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(*op_desc);
auto depends_list_me = abstract::GetDependsFormMap(node);
if (depends_list_me.empty()) {
MS_LOG(INFO) << "The node " << op_name_ << " has no infer depend ";
return {};
}
auto has_input_name_attr = AnfAlgo::HasNodeAttr("input_names", node);
if (!has_input_name_attr) {
MS_LOG(EXCEPTION) << "Node should has attr: input_names. " << node->fullname_with_scope();
}
auto input_names_attr = AnfAlgo::GetNodeAttr<std::vector<std::string>>(node, "input_names");
std::vector<std::string> op_infer_depends;
std::vector<std::tuple<std::size_t, ge::NodePtr>> constant_ops;
for (auto index : depends_list_me) {
if (LongToSize(index) >= input_names_attr.size()) {
MS_LOG(EXCEPTION) << "Input index " << index << " should less input_names' size " << input_names_attr.size();
}
auto iter = depend_tensor_map.find(LongToSize(index));
if (iter == depend_tensor_map.end()) {
MS_LOG(EXCEPTION) << "Input index " << index << " should less than depend_tensor_map' size "
<< input_names_attr.size();
}
auto depend_name = input_names_attr[index];
auto const_tensor = iter->second;
ge::NodePtr ge_constant_op = NewConstantOp(node, depend_name, const_tensor, ge_graph, index);
constant_ops.emplace_back(std::tuple<std::size_t, ge::NodePtr>(index, ge_constant_op));
op_infer_depends.emplace_back(depend_name);
}
(void)(*op_desc)->SetOpInferDepends(op_infer_depends);
return constant_ops;
}
void OpTilingCalculateAdapter::AddEdge(const ge::NodePtr &ge_node,
const std::vector<std::tuple<std::size_t, ge::NodePtr>> &constant_ops) {
MS_EXCEPTION_IF_NULL(ge_node);
MS_LOG(INFO) << "Add edge for op " << op_name_;
for (const auto &item : constant_ops) {
auto index = std::get<0>(item);
auto constant_op = std::get<1>(item);
(void)ge_node->AddLinkFrom(index, constant_op);
}
}
void OpTilingCalculateAdapter::InitOpIoName(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(INFO) << "Get the every input name of " << op_name_;
auto op_info_ptr = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name_, node);
MS_EXCEPTION_IF_NULL(op_info_ptr);
auto primitive = AnfAlgo::GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(primitive);
auto inputs_ptr = op_info_ptr->inputs_ptr();
size_t dynamic_input_index = 0;
std::vector<int64_t> dynamic_inputs_list;
if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
dynamic_inputs_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes));
}
for (const auto &item : inputs_ptr) {
MS_EXCEPTION_IF_NULL(item);
if (item->param_type() == PARAM_DYNAMIC) {
if (dynamic_input_index >= dynamic_inputs_list.size()) {
MS_LOG(EXCEPTION) << "Dynamic input index should less than the dynamic input's size.";
}
auto real_inputs_num = dynamic_inputs_list[dynamic_input_index];
for (auto k = 0; k < real_inputs_num; k++) {
std::string input_name = item->name() + "_dynamic_" + std::to_string(k);
input_names_.emplace_back(input_name);
}
dynamic_input_index++;
} else {
input_names_.emplace_back(item->name());
}
}
// output io names
auto outputs_ptr = op_info_ptr->outputs_ptr();
for (const auto &out_item : outputs_ptr) {
MS_EXCEPTION_IF_NULL(out_item);
output_names_.emplace_back(out_item->name());
}
}
ge::NodePtr OpTilingCalculateAdapter::AnfNodeToGeNodeAdapter(
const CNodePtr &node, ge::ComputeGraphPtr *ge_graph, const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map) {
MS_EXCEPTION_IF_NULL(node);
op_name_ = AnfAlgo::GetCNodeName(node);
auto op_type = GetRealOpType(op_name_);
(void)InitOpIoName(node);
ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>(op_name_, op_type);
MS_EXCEPTION_IF_NULL(op_desc);
ConvertInputShapeAndType(node, &op_desc);
ConvertOutputShapeAndType(node, &op_desc);
ConvertCompileInfo(node, &op_desc);
std::vector<std::tuple<std::size_t, ge::NodePtr>> constant_ops =
ConvertDepends(node, depend_tensor_map, &op_desc, ge_graph);
MS_EXCEPTION_IF_NULL(*ge_graph);
auto ge_node = (*ge_graph)->AddNode(op_desc);
MS_EXCEPTION_IF_NULL(ge_node);
AddEdge(ge_node, constant_ops);
return ge_node;
}
} // namespace tiling
} // namespace device
} // namespace mindspore
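`GetRealOpType` in the new adapter keeps the map-with-fallback idiom carried over from the deleted calculater: look the MindSpore op type up in a fixed alias table and fall back to the input when no GE-side alias exists. Reduced to its essentials (a trimmed table, same logic):

#include <iostream>
#include <map>
#include <string>

std::string GetRealOpType(const std::string &op_type) {
  // A subset of the alias table shown in the diff.
  static const std::map<std::string, std::string> kOpTypeMap = {
      {"SparseGatherV2", "Gather"},
      {"Pad", "PadD"},
      {"Softmax", "SoftmaxV2"},
  };
  auto iter = kOpTypeMap.find(op_type);
  // Fall back to the original name when there is no GE-side alias.
  return iter == kOpTypeMap.end() ? op_type : iter->second;
}

int main() {
  std::cout << GetRealOpType("Pad") << '\n';     // PadD
  std::cout << GetRealOpType("MatMul") << '\n';  // MatMul (no alias, unchanged)
  return 0;
}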

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_ADAPTER_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_ADAPTER_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <tuple>
#include "backend/kernel_compiler/oplib/opinfo.h"
#include "register/op_tiling.h"
#include "external/graph/operator.h"
#include "graph/utils/graph_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace device {
namespace tiling {
class OpTilingCalculateAdapter {
public:
OpTilingCalculateAdapter() = default;
~OpTilingCalculateAdapter() = default;
ge::NodePtr AnfNodeToGeNodeAdapter(const CNodePtr &node, ge::ComputeGraphPtr *ge_graph,
const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map);
private:
void ConvertInputShapeAndType(const CNodePtr &node, ge::OpDescPtr *op_desc);
void ConvertOutputShapeAndType(const CNodePtr &node, ge::OpDescPtr *op_desc);
void ConvertCompileInfo(const CNodePtr &node, ge::OpDescPtr *op_desc);
void ConvertAttrs(const CNodePtr &node, ge::OpDescPtr *op_desc);
std::vector<std::tuple<std::size_t, ge::NodePtr>> ConvertDepends(
const CNodePtr &node, const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map, ge::OpDescPtr *op_desc,
ge::ComputeGraphPtr *ge_graph);
ge::NodePtr NewConstantOp(const CNodePtr &node, const std::string &name, const tensor::TensorPtr &tensor_data,
ge::ComputeGraphPtr *ge_graph, size_t index);
void AddEdge(const ge::NodePtr &ge_node, const std::vector<std::tuple<std::size_t, ge::NodePtr>> &constant_ops);
std::string GetRealOpType(const std::string &op_type);
std::string GetInputName(const CNodePtr &node, size_t index);
std::string GetOutputName(const CNodePtr &node, size_t index);
void InitOpIoName(const CNodePtr &node);
std::string op_name_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
};
} // namespace tiling
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_ADAPTER_H_

View File

@ -1,205 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h"
#include <dlfcn.h>
#include <map>
#include <vector>
#include <memory>
#include <string>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/ascend/ge_types_convert.h"
#include "utils/utils.h"
#include "external/graph/tensor.h"
#include "external/register/op_tiling_registry.h"
namespace mindspore {
namespace device {
namespace ascend {
ge::Tensor MakeTempGeTensor(const TypeId &type_id, const std::vector<size_t> &shape, const std::string &format) {
auto ge_type = GeTypesConvert::TransTypeIdToGeDataType(type_id);
std::vector<int64_t> int_shape;
std::transform(shape.begin(), shape.end(), std::back_inserter(int_shape), SizeToLong);
auto ge_format = GeTypesConvert::GetGeFormat(format, shape.size());
ge::Tensor ge_tensor;
ge_tensor.SetTensorDesc(ge::TensorDesc(ge::Shape(int_shape), ge_format, ge_type));
return ge_tensor;
}
void FeedTeOpTensorInputArg(const NotNull<CNodePtr> &cnode,
NotNull<std::vector<optiling::TeOpTensorArg> *> tensor_arg_list) {
MS_LOG(INFO) << "FeedTeOpTensorInputArg start, node:" << cnode->fullname_with_scope();
auto input_size = AnfAlgo::GetInputTensorNum(cnode.get());
// Skip Dynamic Shape Depend Input
for (size_t i = 0; i < input_size; ++i) {
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode.get(), i);
auto input_node = input_node_with_index.first;
auto input_index = input_node_with_index.second;
auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index);
auto output_ori_shape = AnfAlgo::GetOutputInferShape(input_node, input_index);
auto output_format = AnfAlgo::GetOutputFormat(input_node, input_index);
auto output_dtype = AnfAlgo::GetOutputDeviceDataType(input_node, input_index);
auto iter = type_name_map.find(output_dtype);
if (iter == type_name_map.end()) {
MS_LOG(EXCEPTION) << "Cannot found typeId:" << output_dtype;
}
auto ge_output_dtype = iter->second;
optiling::TeOpTensorArg tensor_arg;
optiling::TeOpTensor tensor;
tensor_arg.arg_type = optiling::TA_SINGLE;
tensor.dtype = ge_output_dtype;
tensor.shape.insert(tensor.shape.end(), output_shape.begin(), output_shape.end());
tensor.ori_shape.insert(tensor.ori_shape.end(), output_ori_shape.begin(), output_ori_shape.end());
tensor.format = GeTypesConvert::GetGeTilingFormat(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
MS_LOG(INFO) << "Tiling Format:" << tensor.format;
tensor_arg.tensor.emplace_back(tensor);
tensor_arg_list->emplace_back(tensor_arg);
}
}
void FeedTeOpTensorOutputArg(const NotNull<CNodePtr> &cnode,
NotNull<std::vector<optiling::TeOpTensorArg> *> tensor_arg_list) {
MS_LOG(INFO) << "FeedTeOpTensorOutputArg start, node:" << cnode->fullname_with_scope();
auto output_size = AnfAlgo::GetOutputTensorNum(cnode.get());
for (size_t i = 0; i < output_size; ++i) {
auto output_shape = AnfAlgo::GetOutputDeviceShape(cnode.get(), i);
auto output_ori_shape = AnfAlgo::GetOutputInferShape(cnode.get(), i);
auto output_format = AnfAlgo::GetOutputFormat(cnode.get(), i);
auto data_type = AnfAlgo::GetOutputDeviceDataType(cnode.get(), i);
auto iter = type_name_map.find(data_type);
if (iter == type_name_map.end()) {
MS_LOG(EXCEPTION) << "Cannot found typeId:" << data_type;
}
optiling::TeOpTensorArg tensor_arg;
optiling::TeOpTensor tensor;
tensor_arg.arg_type = optiling::TA_SINGLE;
tensor.dtype = iter->second;
tensor.shape.insert(tensor.shape.end(), output_shape.begin(), output_shape.end());
tensor.ori_shape.insert(tensor.ori_shape.end(), output_ori_shape.begin(), output_ori_shape.end());
tensor.format = GeTypesConvert::GetGeTilingFormat(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
MS_LOG(INFO) << "Tiling Format:" << tensor.format;
tensor_arg.tensor.emplace_back(tensor);
tensor_arg_list->emplace_back(tensor_arg);
}
}
void FeedTeOpConstTensor(const NotNull<CNodePtr> &cnode, const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
NotNull<std::map<std::string, optiling::TeConstTensorData> *> const_inputs) {
MS_LOG(INFO) << "FeedTeOpConstTensor start, node:" << cnode->fullname_with_scope();
auto depends_list_me = abstract::GetDependsFormMap(cnode);
if (depends_list_me.empty()) {
MS_LOG(INFO) << "No input depend found, " << cnode->fullname_with_scope();
return;
}
std::vector<int> depends_list;
(void)std::transform(depends_list_me.begin(), depends_list_me.end(), std::back_inserter(depends_list),
[](const int64_t &value) { return static_cast<int>(value); });
for (auto index : depends_list) {
auto iter = depend_tensor_map.find(IntToSize(index));
if (iter == depend_tensor_map.end()) {
MS_LOG(EXCEPTION) << "Index not found in depend_tensor_map";
}
auto const_tensor = iter->second;
auto have_input_names_attr = AnfAlgo::HasNodeAttr("input_names", cnode);
if (!have_input_names_attr) {
MS_LOG(EXCEPTION) << "cnode:" << cnode->fullname_with_scope() << " no input_names attr";
}
auto input_names_attr = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode.get(), "input_names");
if (IntToSize(index) >= input_names_attr.size()) {
MS_LOG(EXCEPTION) << "input index" << index << " >= input_name_attr.size:" << input_names_attr.size();
}
auto input_name = input_names_attr[index];
MS_LOG(INFO) << "input_name is " << input_name;
auto type_id = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode.get(), IntToSize(index));
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode.get(), IntToSize(index));
auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode.get(), IntToSize(index));
const_inputs->try_emplace(
input_name,
optiling::TeConstTensorData{static_cast<const uint8_t *>(const_tensor->data_c()),
IntToSize(const_tensor->DataSize()), MakeTempGeTensor(type_id, shape, format)});
}
MS_LOG(INFO) << "FeedTeOpConstTensor end";
}
void OpTilingCalculater::Init() {
MS_LOG(INFO) << "Start init OpTilingCalculater";
tiling_func_map_ = optiling::OpTilingRegistryInterf::RegisteredOpInterf();
if (tiling_func_map_.empty()) {
MS_LOG(EXCEPTION) << "Get register tiling func failed.";
}
}
std::string GetRealOpType(const std::string &op_type) {
static const std::map<std::string, std::string> kOpTypeMap = {
{"SparseApplyFtrl", "SparseApplyFtrlD"},
{"SparseApplyProximalAdagrad", "SparseApplyProximalAdagradD"},
{"SparseGatherV2", "Gather"},
{"Pad", "PadD"},
{"Concat", "ConcatD"},
{"Softmax", "SoftmaxV2"},
{"DropoutDoMask", "DropOutDoMask"},
};
auto iter = kOpTypeMap.find(op_type);
if (iter == kOpTypeMap.end()) {
return op_type;
}
return iter->second;
}
void OpTilingCalculater::CalculateTiling(const NotNull<CNodePtr> &cnode, const optiling::OpCompileInfo &op_compile_info,
const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
const NotNull<optiling::OpRunInfo *> op_run_info) {
optiling::TeOpParas op_param;
std::string op_type = AnfAlgo::GetCNodeName(cnode.get());
MS_LOG(INFO) << "[DynamicShape] calculate tiling, op_type:" << op_type;
FeedTeOpTensorInputArg(cnode, NOT_NULL(&op_param.inputs));
FeedTeOpTensorOutputArg(cnode, NOT_NULL(&op_param.outputs));
FeedTeOpConstTensor(cnode, depend_tensor_map, NOT_NULL(&op_param.const_inputs));
op_type = GetRealOpType(op_type);
auto iter = tiling_func_map_.find(op_type);
if (iter == tiling_func_map_.end()) {
iter = tiling_func_map_.find("AutoTiling");
if (iter == tiling_func_map_.end()) {
MS_LOG(EXCEPTION) << "AutoTiling Func Not Found";
}
}
MS_LOG(INFO) << "Get tiling func:" << iter->first;
if (iter != tiling_func_map_.end()) {
bool ret = (iter->second)(op_param, op_compile_info, *op_run_info);
if (!ret) {
MS_LOG(EXCEPTION) << "Calculate tiling failed";
}
} else {
MS_LOG(EXCEPTION) << "Tiling func not found";
}
MS_LOG(INFO) << "CalculateTiling success";
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -1,55 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_CALCULATE_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_CALCULATE_H_
#include <map>
#include <memory>
#include <string>
#include "utils/ms_utils.h"
#include "utils/contract.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "register/op_tiling.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace device {
namespace ascend {
class OpTilingCalculater {
public:
static OpTilingCalculater &GetInstance() {
static OpTilingCalculater instance;
return instance;
}
void Init();
void CalculateTiling(const NotNull<CNodePtr> &cnode, const optiling::OpCompileInfo &op_compile_info,
const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
NotNull<optiling::OpRunInfo *> op_run_info);
private:
OpTilingCalculater() = default;
~OpTilingCalculater() = default;
DISABLE_COPY_AND_ASSIGN(OpTilingCalculater);
std::map<std::string, optiling::OpTilingFunc> tiling_func_map_;
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_CALCULATE_H_

View File

@ -137,43 +137,53 @@ bool HcclAdapter::InitHccl() {
   return true;
 }
-bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) {
-  MS_LOG(INFO) << "Start init hccl adapter.";
+bool HcclAdapter::InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file,
+                           bool is_graph_mode) {
+  MS_LOG(INFO) << "Start init hccl adapter for " << (is_graph_mode ? "graph mode." : "pynative mode.");
   std::lock_guard<std::mutex> lock(init_mutex_);
   if (init_flag_) {
     MS_LOG(INFO) << "Hccl has been inited, skip.";
     return true;
   }
+  is_graph_mode_ = is_graph_mode;
   InitPlugin();
+  if (is_graph_mode_) {
     bool ret = InitKernelInfoStore(device_id, rank_id, rank_file);
     if (!ret) {
       return false;
     }
-  ret = InitHcclComm(rank_id, rank_file);
-  if (!ret) {
-    return false;
-  }
     ret = InitHcclExec();
     if (!ret) {
       return false;
     }
+  } else {
+    bool ret = InitHcclComm(rank_id, rank_file);
+    if (!ret) {
+      return false;
+    }
+  }
   init_flag_ = true;
   MS_LOG(INFO) << "Init hccl adapter success.";
   return true;
 }
 bool HcclAdapter::FinalizeHccl() {
-  MS_LOG(INFO) << "Start destroy hccl adapter.";
   std::lock_guard<std::mutex> lock(init_mutex_);
+  MS_LOG(INFO) << "Start destroy hccl adapter for " << (is_graph_mode_ ? "graph mode." : "pynative mode.");
   if (!init_flag_) {
     MS_LOG(INFO) << "Hccl has never been inited, skip.";
     return true;
   }
+  if (is_graph_mode_) {
     (void)FinalizeHcclExec();
-    (void)FinalizeHcclComm();
     (void)FinalizeKernelInfoStore();
+  } else {
+    (void)FinalizeHcclComm();
+  }
   FinalizePlugin();
   init_flag_ = false;
   MS_LOG(INFO) << "Destroy hccl adapter success.";

View File

@ -42,7 +42,7 @@ class HcclAdapter {
   static HcclAdapter &GetInstance();
   // common
-  bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
+  bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file, bool is_graph_mode);
   bool InitHccl();
   bool FinalizeHccl();
@ -121,6 +121,7 @@ class HcclAdapter {
   std::shared_ptr<::ge::OpsKernelBuilder> ops_kernel_builder_ = nullptr;
   bool init_flag_ = false;
+  bool is_graph_mode_ = false;
   std::mutex init_mutex_;
 };
 } // namespace mindspore::hccl

View File

@ -78,7 +78,8 @@ class DfGraphConvertor {
 void DrawComputeGraph(const std::string &name) {
   std::ofstream fout(name);
   if (!fout.is_open()) {
-    MS_LOG(ERROR) << "Open file '" << name << "' failed!";
+    MS_LOG(ERROR) << "Open file '" << name << "' failed!"
+                  << " Errno:" << errno << " ErrInfo:" << strerror(errno);
     return;
   }
   fout << compute_sout_.str();
@ -87,7 +88,8 @@ class DfGraphConvertor {
 void DrawInitGraph(const std::string &name) {
   std::ofstream fout(name);
   if (!fout.is_open()) {
-    MS_LOG(ERROR) << "Open file '" << name << "' failed!";
+    MS_LOG(ERROR) << "Open file '" << name << "' failed!"
+                  << " Errno:" << errno << " ErrInfo:" << strerror(errno);
     return;
   }
   fout << init_sout_.str();
@ -96,7 +98,8 @@ class DfGraphConvertor {
 void DrawSaveCheckpointGraph(const std::string &name) {
   std::ofstream fout(name);
   if (!fout.is_open()) {
-    MS_LOG(ERROR) << "Open file '" << name << "' failed!";
+    MS_LOG(ERROR) << "Open file '" << name << "' failed!"
+                  << " Errno:" << errno << " ErrInfo:" << strerror(errno);
     return;
   }
   fout << checkpoint_sout_.str();
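These three draw helpers gain the same diagnostic already added to the pipeline cache above: when a `std::ofstream` fails to open, the log carries `errno` and `strerror(errno)` instead of a bare failure message. Stand-alone, the pattern looks like the sketch below; note that the C++ standard does not guarantee `errno` is set after a failed stream open, but on POSIX systems the underlying open(2) sets it, which is what this logging relies on:

#include <cerrno>
#include <cstring>
#include <fstream>
#include <iostream>

int main() {
  std::ofstream fout("/nonexistent_dir/graph.dot");
  if (!fout.is_open()) {
    // errno was set by the failed open underneath; decode it for the log.
    std::cerr << "Open file failed!"
              << " Errno:" << errno << " ErrInfo:" << std::strerror(errno) << '\n';
    return 1;
  }
  return 0;
}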

View File

@ -335,6 +335,7 @@ constexpr auto kAttrDataShape = "data_shape";
 constexpr auto kAttrFormat = "format";
 constexpr auto kAttrReshapeType = "reshape_type";
 constexpr auto kAttrAxis = "axis";
+constexpr auto kAttrAxes = "axes";
 constexpr auto kAttrKeepDims = "keep_dims";
 constexpr auto kAttrShapeGamma = "shape_gamma";
 constexpr auto kAttrPerm = "perm";

View File

@ -289,14 +289,6 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s
   return outputs;
 }
-void MsBackend::Link(GraphId graph_id) {
-  MS_EXCEPTION_IF_NULL(target_sess_);
-  if (graph_id == kInvalidGraphId) {
-    graph_id = target_sess_->GetFinalRunGraph();
-  }
-  target_sess_->BuildGraph(graph_id);
-}
 MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
   convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
   target_sess_ = session::SessionFactory::Get().Create(target);

View File

@ -61,7 +61,6 @@ class Backend {
   virtual bool GetCond(const BaseRef &c, bool *value);
   virtual bool GetIndex(const BaseRef &c, int64_t *value);
   virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
-  virtual void Link(GraphId) {}
   virtual void SetDebugger() {}
   bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
@ -82,7 +81,6 @@ class MsBackend : public Backend {
   VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");
   VectorRef MsSimuRunGraph(const GraphId &g);
-  void Link(GraphId) override;
   GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
   VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
   void ClearSessionGraphs();

View File

@ -388,6 +388,13 @@ int64_t CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
   MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << (size - 1);
   AddInst(Instruction::kCall, Ref(fn));
   Ret(static_cast<int64_t>(size - 1));
+  for (size_t i = size - 1; i > 0; i--) {
+    const auto iter = slots_.find(inputs[i]);
+    if (iter != slots_.end() && iter->second >= height_) {
+      slots_.erase(inputs[i]);
+    }
+  }
   return RET_SUCCESS;
 }
@ -573,9 +580,6 @@ BackendPtr CreateBackend() {
   if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
     backend->set_is_multi_graph_sink(false);
     context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
-  } else {
-    backend->set_is_multi_graph_sink(true);
-    context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
   }
 }
 return backend;
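The new loop in `AddCall` drops argument slots whose recorded height is at or above the current frame, so stale entries cannot alias values produced after the call. Generically, that is a conditional map erase keyed by a threshold; a minimal sketch:

#include <iostream>
#include <iterator>
#include <map>
#include <string>

int main() {
  std::map<std::string, int> slots = {{"a", 1}, {"b", 5}, {"c", 7}};
  const int height = 5;

  // Erase every slot whose recorded height is at or above the threshold,
  // mirroring the slots_.erase(...) loop in the diff.
  for (auto it = slots.begin(); it != slots.end();) {
    it = (it->second >= height) ? slots.erase(it) : std::next(it);
  }

  for (const auto &kv : slots) {
    std::cout << kv.first << '\n';  // only "a" remains
  }
  return 0;
}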

View File

@ -759,11 +759,9 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
   new_func_graph->set_param_default_value(item.first, cloner[item.second]);
 }
-if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
   if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
     new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
   }
-}
 if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
   new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));

View File

@ -54,7 +54,7 @@ from .validators import check_prob, check_crop, check_center_crop, check_resize_
     check_uniform_augment_cpp, \
     check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
     check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
-    check_cut_mix_batch_c, check_posterize, check_gaussian_blur, check_rotate, check_slice_patches
+    check_cut_mix_batch_c, check_posterize, check_gaussian_blur, check_rotate, check_slice_patches, check_adjust_gamma
 from ..transforms.c_transforms import TensorOperation
@ -107,6 +107,37 @@ def parse_padding(padding):
     return padding
+class AdjustGamma(ImageTensorOperation):
+    r"""
+    Apply gamma correction on input image. Input image is expected to be in [..., H, W, C] or [H, W, C] format.
+
+    .. math::
+        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}
+
+    See `Gamma Correction`_ for more details.
+
+    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
+
+    Args:
+        gamma (float): Non negative real number.
+            The output image pixel value is exponentially related to the input image pixel value.
+            gamma larger than 1 makes the shadows darker,
+            while gamma smaller than 1 makes dark regions lighter.
+        gain (float, optional): The constant multiplier (default=1).
+
+    Examples:
+        >>> transforms_list = [c_vision.Decode(), c_vision.AdjustGamma(gamma=10.0, gain=1.0)]
+        >>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
+        ...                                                 input_columns=["image"])
+    """
+
+    @check_adjust_gamma
+    def __init__(self, gamma, gain=1):
+        self.gamma = gamma
+        self.gain = gain
+
+    def parse(self):
+        return cde.AdjustGammaOperation(self.gamma, self.gain)
 class AutoContrast(ImageTensorOperation):
     """
     Apply automatic contrast on input image. This operator calculates histogram of image, reassign cutoff percent

View File

@ -31,7 +31,8 @@ from .validators import check_prob, check_center_crop, check_five_crop, check_re
     check_normalize_py, check_normalizepad_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
     check_ten_crop, check_num_channels, check_pad, check_rgb_to_hsv, check_hsv_to_rgb, \
     check_random_perspective, check_random_erasing, check_cutout, check_linear_transform, check_random_affine, \
-    check_mix_up, check_positive_degrees, check_uniform_augment_py, check_auto_contrast, check_rgb_to_bgr
+    check_mix_up, check_positive_degrees, check_uniform_augment_py, check_auto_contrast, check_rgb_to_bgr, \
+    check_adjust_gamma
 from .utils import Inter, Border
 from .py_transforms_util import is_pil
@ -1375,7 +1376,6 @@ class RgbToBgr:
         return util.rgb_to_bgrs(rgb_imgs, self.is_hwc)
-
 class RgbToHsv:
     """
     Convert a NumPy RGB image or a batch of NumPy RGB images to HSV images.
@ -1525,6 +1525,44 @@ class RandomSharpness:
return util.random_sharpness(img, self.degrees) return util.random_sharpness(img, self.degrees)
+class AdjustGamma:
+    """
+    Adjust gamma of the input PIL image.
+
+    Args:
+        gamma (float): Non-negative real number, same as gamma in the equation.
+        gain (float, optional): The constant multiplier.
+
+    Examples:
+        >>> from mindspore.dataset.transforms.py_transforms import Compose
+        >>> transforms_list = Compose([py_vision.Decode(),
+        ...                            py_vision.AdjustGamma(gamma=10.0),
+        ...                            py_vision.ToTensor()])
+        >>> # apply the transform to dataset through map function
+        >>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
+        ...                                                 input_columns="image")
+    """
+
+    @check_adjust_gamma
+    def __init__(self, gamma, gain=1.0):
+        self.gamma = gamma
+        self.gain = gain
+        self.random = False
+
+    def __call__(self, img):
+        """
+        Call method.
+
+        Args:
+            img (PIL image): Image to be augmented with AdjustGamma.
+
+        Returns:
+            img (PIL image), Augmented image.
+        """
+        return util.adjust_gamma(img, self.gamma, self.gain)
+
+
 class AutoContrast:
     """
     Automatically maximize the contrast of the input PIL image.
@@ -19,7 +19,6 @@ import math
 import numbers
 import random
 import colorsys
 import numpy as np
 from PIL import Image, ImageOps, ImageEnhance, __version__
@@ -1243,6 +1242,7 @@ def rgb_to_bgr(np_rgb_img, is_hwc):
     np_bgr_img = np_rgb_img[::-1, :, :]
     return np_bgr_img

+
 def rgb_to_bgrs(np_rgb_imgs, is_hwc):
     """
     Convert RGB imgs to BGR imgs.
@@ -1473,6 +1473,32 @@ def random_sharpness(img, degrees):
     return ImageEnhance.Sharpness(img).enhance(v)
+def adjust_gamma(img, gamma, gain):
+    """
+    Adjust gamma of the input PIL image.
+
+    Args:
+        img (PIL image): Image to be augmented with AdjustGamma.
+        gamma (float): Non-negative real number, same as gamma in the equation.
+        gain (float, optional): The constant multiplier.
+
+    Returns:
+        img (PIL image), Augmented image.
+    """
+    if not is_pil(img):
+        raise TypeError("img should be PIL image. Got {}.".format(type(img)))
+
+    gamma_table = [(255 + 1 - 1e-3) * gain * pow(x / 255., gamma) for x in range(256)]
+    if len(img.split()) == 3:
+        gamma_table = gamma_table * 3
+        img = img.point(gamma_table)
+    elif len(img.split()) == 1:
+        img = img.point(gamma_table)
+    return img
+
+
 def auto_contrast(img, cutoff, ignore):
     """
     Automatically maximize the contrast of the input PIL image.
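The adjust_gamma helper added above precomputes a 256-entry lookup table and applies it with PIL's Image.point, replicating the table once per band for RGB images. A hedged usage sketch follows; Pillow being installed and a local "sample.jpg" are assumptions for illustration only:

from PIL import Image

gamma, gain = 2.2, 1.0
gamma_table = [(255 + 1 - 1e-3) * gain * pow(x / 255., gamma) for x in range(256)]
img = Image.open("sample.jpg").convert("RGB")
out = img.point(gamma_table * 3)   # one copy of the table per R, G, B band
out.save("sample_gamma.jpg")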
@@ -19,10 +19,10 @@ from functools import wraps
 import numpy as np
 from mindspore._c_dataengine import TensorOp, TensorOperation
-from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
-    check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, parse_user_args, type_check, \
-    type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, check_value_cutoff, check_value_ratio, \
-    check_odd
+from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER, \
+    check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
+    parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, \
+    check_value_cutoff, check_value_ratio, check_odd
 from .utils import Inter, Border, ImageBatchFormat, SliceMode
@@ -788,6 +788,22 @@ def check_bounding_box_augment_cpp(method):
     return new_method
+def check_adjust_gamma(method):
+    """Wrapper method to check the parameters of AdjustGamma ops (Python and C++)."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        [gamma, gain], _ = parse_user_args(method, *args, **kwargs)
+        type_check(gamma, (float, int), "gamma")
+        check_value(gamma, (0, FLOAT_MAX_INTEGER))
+
+        if gain is not None:
+            type_check(gain, (float, int), "gain")
+            check_value(gain, (FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER))
+        return method(self, *args, **kwargs)
+
+    return new_method
 def check_auto_contrast(method):
     """Wrapper method to check the parameters of AutoContrast ops (Python and C++)."""
@@ -226,13 +226,13 @@ build_lite() {
   compile_nnie_script=${BASEPATH}/mindspore/lite/tools/providers/NNIE/Hi3516D/compile_nnie.sh
   cd ${BASEPATH}/../
   if [[ "${local_lite_platform}" == "x86_64" ]]; then
-    sh ${compile_nnie_script} -I x86_64 -b nnie_master -j $THREAD_NUM
+    sh ${compile_nnie_script} -I x86_64 -b nnie_r1.4 -j $THREAD_NUM
     if [[ $? -ne 0 ]]; then
       echo "compile x86_64 for nnie failed."
       exit 1
     fi
   elif [[ "${local_lite_platform}" == "arm32" ]]; then
-    sh ${compile_nnie_script} -I arm32 -b nnie_master -j $THREAD_NUM
+    sh ${compile_nnie_script} -I arm32 -b nnie_r1.4 -j $THREAD_NUM
     if [[ $? -ne 0 ]]; then
       echo "compile arm32 for nnie failed."
       exit 1
@@ -133,6 +133,7 @@ set(TRAIN_SRC
         ${CMAKE_CURRENT_SOURCE_DIR}/train/accuracy_monitor.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/../tools/common/storage.cc
         )
 if(ENABLE_V0)
@@ -209,8 +209,11 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
         MS_LOG(ERROR) << "Graph output shouldn't have data";
         return RET_ERROR;
       }
+      // a tensor that is both an input and an output is treated as an input
+      if (!dst_tensor->IsGraphInput()) {
         dst_tensor->set_category(Tensor::GRAPH_OUTPUT);
+      }
     }
     if (src_tensor->name() != nullptr) {
       dst_tensor->set_tensor_name(src_tensor->name()->str());
     }
@@ -364,16 +367,13 @@ void LiteSession::InitGraphInOutTensorsMap(const lite::Model *model) {
   InitGraphInputMap(model);
   InitGraphOutputNodeMap(model);
   InitGraphOutputTensorMap(model);
-  for (auto *tensor : this->inputs_) {
-    tensor->set_category(Tensor::Category::GRAPH_INPUT);
-  }
-  for (auto *tensor : this->outputs_) {
-    tensor->set_category(Tensor::Category::GRAPH_OUTPUT);
-  }
 }

 void LiteSession::IsolateOutputTensor() {
   for (Tensor *src_tensor : outputs_) {
+    if (src_tensor->IsGraphInput()) {
+      continue;
+    }
     Tensor *new_tensor =
       new Tensor(src_tensor->data_type(), src_tensor->shape(), src_tensor->format(), Tensor::GRAPH_OUTPUT);
     new_tensor->set_allocator(src_tensor->allocator()); /* GPU use opencl allocator */
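The two hunks above move the input/output category assignment into ConvertTensors and make it asymmetric: a tensor that appears in both the graph's input and output lists keeps GRAPH_INPUT. A minimal Python sketch of that rule; the names are illustrative, not the lite runtime API:

GRAPH_INPUT, GRAPH_OUTPUT, VAR = "GRAPH_INPUT", "GRAPH_OUTPUT", "VAR"

def assign_categories(tensors, graph_inputs, graph_outputs):
    categories = {t: VAR for t in tensors}
    for t in graph_inputs:
        categories[t] = GRAPH_INPUT
    for t in graph_outputs:
        if categories[t] != GRAPH_INPUT:   # an input-also-output stays an input
            categories[t] = GRAPH_OUTPUT
    return categories

print(assign_categories(["a", "b"], graph_inputs=["a"], graph_outputs=["a", "b"]))
# {'a': 'GRAPH_INPUT', 'b': 'GRAPH_OUTPUT'}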
@@ -24,16 +24,24 @@
 namespace mindspore::lite {
 void MindrtExecutor::PrepareInputData(const std::vector<kernel::LiteKernel *> &kernels,
                                       const std::vector<Tensor *> &inputs) {
-  for (size_t i = 0; i < inputs.size(); ++i) {
   for (size_t j = 0; j < kernels.size(); ++j) {
     auto in_tensor_size = kernels[j]->in_tensors().size();
     for (size_t k = 0; k < in_tensor_size; ++k) {
-      if (inputs[i] != kernels[j]->in_tensors()[k]) {
+      auto tensor = kernels[j]->in_tensors()[k];
+      if (!tensor->IsGraphInput()) {
         continue;
       }
-      auto data = std::make_shared<OpData<Tensor>>(op_actors_[j]->GetAID(), inputs[i], static_cast<int>(k));
-      input_data_.emplace_back(data);
+      size_t idx = std::find(inputs.begin(), inputs.end(), tensor) - inputs.begin();
+      if (idx == inputs.size()) {
+        MS_LOG(ERROR) << "The input is not found.";
+        return;
+      }
+      auto data = std::make_shared<OpData<Tensor>>(op_actors_[j]->GetAID(), inputs.at(idx), static_cast<int>(k));
+      if (data == nullptr) {
+        MS_LOG(ERROR) << "new opdata failed.";
+        return;
+      }
+      input_data_.emplace_back(data);
     }
   }
-  }
 }
@@ -42,6 +50,9 @@ void MindrtExecutor::PrepareOutputData(const std::vector<kernel::LiteKernel *> &
                                        const std::vector<Tensor *> &outputs) {
   for (size_t i = 0; i < outputs.size(); ++i) {
     Tensor *graph_output_tensor = outputs[i];
+    if (graph_output_tensor->IsGraphInput()) {
+      continue;
+    }
     auto current_output_map =
       std::find_if(output_tensor_map_->begin(), output_tensor_map_->end(), [&](const auto output_map_tensor) {
         if (graph_output_tensor == output_map_tensor.second) {
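PrepareInputData now iterates kernel inputs, keeps only tensors flagged as graph inputs, and looks up each one's position in the session input list (the std::find call); PrepareOutputData symmetrically skips outputs that are really inputs. A rough Python equivalent of the input side, with illustrative names:

def prepare_input_data(kernels, session_inputs, is_graph_input):
    input_data = []
    for j, kernel_inputs in enumerate(kernels):
        for k, tensor in enumerate(kernel_inputs):
            if not is_graph_input(tensor):
                continue
            if tensor not in session_inputs:
                raise RuntimeError("The input is not found.")
            input_data.append((j, k, session_inputs.index(tensor)))
    return input_data

kernels = [["x", "w0"], ["y", "w1"]]
print(prepare_input_data(kernels, ["x", "y"], lambda t: t in ("x", "y")))
# [(0, 0, 0), (1, 0, 1)]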
@@ -74,6 +74,10 @@ void GroupConvolutionBaseCPUKernel::FreeSubKernel() {
     sub_conv = nullptr;
   }
   group_convs_.clear();
+  if (group_conv_creator_ != nullptr) {
+    delete group_conv_creator_;
+    group_conv_creator_ = nullptr;
+  }
 }

 int GroupConvolutionBaseCPUKernel::PreProcess() {
@@ -26,7 +26,8 @@ using mindspore::schema::PrimitiveType_PadFusion;
 namespace mindspore::kernel {
 namespace {
-constexpr size_t kPadMaxInputSize = 2;
+constexpr size_t kPadCommonInputSize = 2;
+constexpr size_t kPadMaxInputSize = 3;
 }  // namespace
 int PadFp16CPUKernel::RunImpl(int task_id) {
   PadFp16(input_, output_, in_, out_, pad_param_->paddings_, task_id, op_parameter_->thread_num_);
@@ -53,8 +54,14 @@ int PadFp16CPUKernel::RunMirrorPadImpl(int task_id) {
     for (int b = 0; b < block.size_[1]; b++) {
       int out_b_index = out_a_index + b * block.out_stride_[1];
       for (int c = 0; c < block.size_[2]; ++c) {
-        int output_index = out_b_index + c * block.out_stride_[2];
-        MirrorPadFp16(input_data, output_data, in_, pad_param_, output_index, output_index + block.size_[3]);
+        int out_c_index = out_b_index + c * block.out_stride_[2];
+        for (int d = 0; d < block.size_[3]; ++d) {
+          int out_d_index = out_c_index + d * block.out_stride_[3];
+          for (int e = 0; e < block.size_[4]; ++e) {
+            int output_index = out_d_index + e * block.out_stride_[4];
+            MirrorPadFp16(input_data, output_data, in_, pad_param_, output_index, output_index + block.size_[5]);
+          }
+        }
       }
     }
   }
@@ -88,13 +95,16 @@ int PadFp16CPUKernel::Run() {
   MS_ASSERT(output_ != nullptr);
   int ret = 0;
   if (pad_param_->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) {
-    if (in_tensors_.size() == kPadMaxInputSize) {
+    if (in_tensors_.size() >= kPadCommonInputSize) {
       ret = CopyPaddingFromInput();
       if (ret != RET_OK) {
         MS_LOG(ERROR) << "PadFp16CPUKernel CopyPaddingFromInput failed";
         return RET_ERROR;
       }
     }
+    if (in_tensors_.size() == kPadMaxInputSize) {
+      pad_param_->constant_value_ = reinterpret_cast<float *>(in_tensors_.at(2)->data_c())[0];
+    }
     if (pad_param_->constant_value_ - 0.0f < 1e-5) {
       memset(output_, 0, output_tensor->ElementsNum() * sizeof(float16_t));
     } else {
@@ -28,7 +28,8 @@ using mindspore::schema::PrimitiveType_PadFusion;
 namespace mindspore::kernel {
 namespace {
 constexpr size_t kMirrorPadInputSize = 2;
-constexpr size_t kPadMaxInputSize = 2;
+constexpr size_t kPadCommonInputSize = 2;
+constexpr size_t kPadMaxInputSize = 3;
 }  // namespace
 int PadCPUKernel::Init() {
   if (!InferShapeDone()) {
@@ -40,30 +41,30 @@ int PadCPUKernel::Init() {
 int PadCPUKernel::ReSize() {
   auto input = in_tensors_.at(0);
   auto rank = input->shape().size();
-  if (rank > COMM_SHAPE_SIZE) {
-    MS_LOG(ERROR) << "Pad input rank should <= " << COMM_SHAPE_SIZE << ", got " << rank;
+  if (rank > DEFAULT_PAD_NDIMS) {
+    MS_LOG(ERROR) << "Pad input rank should <= " << DEFAULT_PAD_NDIMS << ", got " << rank;
     return RET_ERROR;
   }
   auto output = out_tensors_.at(0);
   if (pad_param_->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) {
-    auto ret = ExtendShape(in_, COMM_SHAPE_SIZE, input->shape().data(), rank);
+    auto ret = ExtendShape(in_, DEFAULT_PAD_NDIMS, input->shape().data(), rank);
     if (ret != RET_OK) {
       return ret;
     }
-    ret = ExtendShape(out_, COMM_SHAPE_SIZE, output->shape().data(), rank);
+    ret = ExtendShape(out_, DEFAULT_PAD_NDIMS, output->shape().data(), rank);
     if (ret != RET_OK) {
       return ret;
     }
-    if (pad_param_->padding_length < MAX_SHAPE_SIZE) {
-      int ori_paddings[MAX_SHAPE_SIZE];
+    if (pad_param_->padding_length < MAX_PAD_SIZE) {
+      int ori_paddings[MAX_PAD_SIZE];
       for (auto i = 0; i < pad_param_->padding_length; ++i) {
         ori_paddings[i] = pad_param_->paddings_[i];
       }
-      ret = ExtendPaddings(pad_param_->paddings_, MAX_SHAPE_SIZE, ori_paddings, pad_param_->padding_length);
+      ret = ExtendPaddings(pad_param_->paddings_, MAX_PAD_SIZE, ori_paddings, pad_param_->padding_length);
       if (ret != RET_OK) {
         return ret;
       }
-      pad_param_->padding_length = MAX_SHAPE_SIZE;
+      pad_param_->padding_length = MAX_PAD_SIZE;
     }
   }
   return RET_OK;
@@ -71,8 +72,8 @@ int PadCPUKernel::ReSize() {
 void PadCPUKernel::InitMirrorPadBlock() {
   mirror_pad_block_.clear();
-  std::vector<int> left_pads(COMM_SHAPE_SIZE);
-  for (size_t i = 0; i < COMM_SHAPE_SIZE; ++i) {
+  std::vector<int> left_pads(DEFAULT_PAD_NDIMS);
+  for (size_t i = 0; i < DEFAULT_PAD_NDIMS; ++i) {
     left_pads[i] = pad_param_->paddings_[2 * i];
   }
@@ -83,7 +84,7 @@ void PadCPUKernel::InitMirrorPadBlock() {
   /* init separate dims */
   int cur_input = 1;
   int cur_output = 1;
-  for (size_t i = 0; i < COMM_SHAPE_SIZE; ++i) {
+  for (size_t i = 0; i < DEFAULT_PAD_NDIMS; ++i) {
     if (cur_input > 1) {
       input_separate_dims.emplace_back(cur_input);
       output_separate_dims.emplace_back(cur_output);
@@ -149,7 +150,7 @@ void PadCPUKernel::InitMirrorPadBlock() {
     }
     MirrorPadBlock block;
-    const int size_offset = COMM_SHAPE_SIZE - static_cast<int>(pad_region.size());
+    const int size_offset = DEFAULT_PAD_NDIMS - static_cast<int>(pad_region.size());
     for (size_t i = 0; i < pad_region.size(); ++i) {
       int di = size_offset + i;
       int si = remain_dim_offset + i;
@@ -265,8 +266,14 @@ int PadCPUKernel::RunMirrorPadImpl(int task_id) {
     for (int b = 0; b < block.size_[1]; b++) {
       int out_b_index = out_a_index + b * block.out_stride_[1];
       for (int c = 0; c < block.size_[2]; ++c) {
-        int output_index = out_b_index + c * block.out_stride_[2];
-        MirrorPad(input_data, output_data, in_, pad_param_, output_index, output_index + block.size_[3]);
+        int out_c_index = out_b_index + c * block.out_stride_[2];
+        for (int d = 0; d < block.size_[3]; ++d) {
+          int out_d_index = out_c_index + d * block.out_stride_[3];
+          for (int e = 0; e < block.size_[4]; ++e) {
+            int output_index = out_d_index + e * block.out_stride_[4];
+            MirrorPad(input_data, output_data, in_, pad_param_, output_index, output_index + block.size_[5]);
+          }
+        }
       }
     }
   }
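The deepened loop nest above is plain stride arithmetic: each output offset is a dot product of the loop indices with the block's output strides, now over five outer dimensions with contiguous runs of block.size_[5] elements. A small sketch of the indexing, with illustrative names and values:

def output_offsets(size, out_stride):
    # start offset of every contiguous run in a 6-D mirror-pad block
    for a in range(size[0]):
        for b in range(size[1]):
            for c in range(size[2]):
                for d in range(size[3]):
                    for e in range(size[4]):
                        yield (a * out_stride[0] + b * out_stride[1] +
                               c * out_stride[2] + d * out_stride[3] +
                               e * out_stride[4])

print(list(output_offsets([1, 1, 2, 2, 2, 4], [64, 32, 16, 8, 4, 1])))
# [0, 4, 8, 12, 16, 20, 24, 28]; each run covers size[5] = 4 elements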
@@ -310,8 +317,8 @@ int PadCPUKernel::CheckPaddings(int *paddings, int length, int *input_shape, int
 }

 int PadCPUKernel::CopyPaddingFromInput() {
-  if (in_tensors_.size() != kMirrorPadInputSize) {
-    MS_LOG(ERROR) << "Pad Reflect or Symmetric mode need 2 inputs, got " << in_tensors_.size();
+  if (in_tensors_.size() < kMirrorPadInputSize) {
+    MS_LOG(ERROR) << "Pad Reflect or Symmetric mode need at least 2 inputs, got " << in_tensors_.size();
     return RET_ERROR;
   }
   auto padding_tensor = in_tensors_.at(1);
@@ -327,28 +334,28 @@ int PadCPUKernel::CopyPaddingFromInput() {
     return RET_ERROR;
   }
-  auto ret = ExtendShape(in_, COMM_SHAPE_SIZE, input_shape.data(), rank);
+  auto ret = ExtendShape(in_, DEFAULT_PAD_NDIMS, input_shape.data(), rank);
   if (ret != RET_OK) {
     return ret;
   }
-  ret = ExtendPaddings(pad_param_->paddings_, MAX_SHAPE_SIZE, paddings, padding_tensor->ElementsNum());
+  ret = ExtendPaddings(pad_param_->paddings_, MAX_PAD_SIZE, paddings, padding_tensor->ElementsNum());
   if (ret != RET_OK) {
     return ret;
   }
-  pad_param_->padding_length = MAX_SHAPE_SIZE;
+  pad_param_->padding_length = MAX_PAD_SIZE;
   return RET_OK;
 }

 void PadCPUKernel::CalculateStrides() {
-  pad_param_->in_strides[COMM_SHAPE_SIZE - 1] = 1;
-  for (auto i = COMM_SHAPE_SIZE - 2; i >= 0; --i) {
+  pad_param_->in_strides[DEFAULT_PAD_NDIMS - 1] = 1;
+  for (auto i = DEFAULT_PAD_NDIMS - 2; i >= 0; --i) {
     pad_param_->in_strides[i] = in_[i + 1] * pad_param_->in_strides[i + 1];
   }
-  for (auto i = 0; i < COMM_SHAPE_SIZE; ++i) {
+  for (auto i = 0; i < DEFAULT_PAD_NDIMS; ++i) {
     out_[i] = in_[i] + pad_param_->paddings_[i * 2] + pad_param_->paddings_[i * 2 + 1];
   }
-  pad_param_->out_strides[COMM_SHAPE_SIZE - 1] = 1;
-  for (auto i = COMM_SHAPE_SIZE - 2; i >= 0; --i) {
+  pad_param_->out_strides[DEFAULT_PAD_NDIMS - 1] = 1;
+  for (auto i = DEFAULT_PAD_NDIMS - 2; i >= 0; --i) {
     pad_param_->out_strides[i] = out_[i + 1] * pad_param_->out_strides[i + 1];
   }
 }
@@ -358,7 +365,7 @@ int PadCPUKernel::HandleMirrorPad() {
   if (in_tensors_.size() == 1) {
     auto input_shape = in_tensors_.at(0)->shape();
     int rank = static_cast<int>(input_shape.size());
-    ret = ExtendShape(in_, COMM_SHAPE_SIZE, input_shape.data(), rank);
+    ret = ExtendShape(in_, DEFAULT_PAD_NDIMS, input_shape.data(), rank);
     if (ret != RET_OK) {
       return ret;
     }
@@ -368,7 +375,7 @@ int PadCPUKernel::HandleMirrorPad() {
       return ret;
     }
   }
-  ret = CheckPaddings(pad_param_->paddings_, COMM_SHAPE_SIZE, in_, pad_param_->pad_mode_);
+  ret = CheckPaddings(pad_param_->paddings_, DEFAULT_PAD_NDIMS, in_, pad_param_->pad_mode_);
   if (ret != RET_OK) {
     return ret;
   }
@@ -391,13 +398,16 @@ int PadCPUKernel::Run() {
   }
   int error_code = 0;
   if (pad_param_->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) {
-    if (in_tensors_.size() == kPadMaxInputSize) {
+    if (in_tensors_.size() >= kPadCommonInputSize) {
       error_code = CopyPaddingFromInput();
       if (error_code != RET_OK) {
         MS_LOG(ERROR) << "Pad run error, error_code[" << error_code << "]";
         return RET_ERROR;
       }
     }
+    if (in_tensors_.size() == kPadMaxInputSize) {
+      pad_param_->constant_value_ = reinterpret_cast<float *>(in_tensors_.at(2)->data_c())[0];
+    }
     auto output = out_tensors_.at(0);
     int output_size = output->ElementsNum();
     auto output_data = reinterpret_cast<float *>(output->data_c());
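With kPadMaxInputSize raised to 3, constant-mode Pad can take its paddings from a second input tensor and its scalar fill value from a third. In NumPy terms, the semantics look roughly like this; a sketch under that reading, not MindSpore code:

import numpy as np

def constant_pad(data, paddings, constant_value=0.0):
    pad_width = [tuple(p) for p in paddings]   # (before, after) per dimension
    return np.pad(data, pad_width, mode="constant",
                  constant_values=float(constant_value))

x = np.ones((2, 2), dtype=np.float32)
print(constant_pad(x, [[1, 1], [0, 2]], constant_value=5.0))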
@@ -55,8 +55,8 @@ class PadCPUKernel : public InnerKernel {
   int HandleMirrorPad();
   int CopyPaddingFromInput();
   PadParameter *pad_param_ = nullptr;
-  int in_[4] = {0};
-  int out_[4] = {0};
+  int in_[DEFAULT_PAD_NDIMS] = {0};
+  int out_[DEFAULT_PAD_NDIMS] = {0};
   std::vector<MirrorPadBlock> mirror_pad_block_;
 };
@@ -316,8 +316,10 @@ void Tensor::FreeData() {
     this->data_ = nullptr;
   } else {
     allocator_->Free(this->data_);
+    if (!IS_STATIC_ALLOCATOR(allocator_) || (allocator_->RefCount(this->data_) != 0)) {
       this->data_ = nullptr;
     }
+  }
 }

 void *Tensor::ReallocData() {
@@ -34,12 +34,15 @@
 namespace mindspore {
 namespace lite {
+#define STATIC_ALLOCATION -271964
+#define IS_STATIC_ALLOCATOR(allocator) ((allocator != nullptr) && (allocator->RefCount(nullptr) == STATIC_ALLOCATION))
 struct LiteQuantParam {
   double scale;
   int32_t zeroPoint;
   float var_corr{1};
   float mean_corr{0};
-  bool inited;
+  bool inited{false};
   std::vector<float> clusters{};
   int bitNum;
   int roundType;
@@ -133,7 +136,6 @@ class Tensor : public mindspore::tensor::MSTensor {
   void set_format(mindspore::Format format) override { this->format_ = format; }
   mindspore::Format format() const override { return this->format_; }
   virtual int ref_count() const { return ref_count_; }
-
   virtual int init_ref_count() const { return this->init_ref_count_; }
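IS_STATIC_ALLOCATOR relies on a sentinel: a static allocator answers a RefCount query on nullptr with the magic constant STATIC_ALLOCATION, so callers can identify it without any type information. The idea expressed in Python, with illustrative names:

STATIC_ALLOCATION = -271964

class StaticAllocator:
    def ref_count(self, ptr):
        # the nullptr query doubles as an identity probe
        return STATIC_ALLOCATION if ptr is None else 1

def is_static_allocator(allocator):
    return allocator is not None and allocator.ref_count(None) == STATIC_ALLOCATION

print(is_static_allocator(StaticAllocator()))  # True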
@@ -0,0 +1,90 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/train/opt_allocator.h"
+#include <limits>
+#include "nnacl/op_base.h"
+
+namespace mindspore {
+size_t OptAllocator::FindFree(size_t size) {
+  size_t min_size = std::numeric_limits<size_t>::max();
+  size_t min_addr = std::numeric_limits<size_t>::max();
+  for (auto const &itr : arena_) {
+    // best fit
+    if (itr.second >= size) {
+      if (min_size > itr.second) {
+        min_size = itr.second;
+        min_addr = itr.first;
+      }
+    }
+  }
+  return min_addr;
+}
+
+void OptAllocator::Reorder(size_t addr) {
+  size_t length = arena_[addr];
+  size_t post = addr + length;
+  // connect to upper block
+  auto it = arena_.find(post);
+  if (it != arena_.end()) {
+    size_t post_size = it->second;
+    arena_[addr] = length + post_size;
+    arena_.erase(post);
+  }
+  // connect to lower block
+  auto itr = arena_.lower_bound(addr);
+  if (itr != arena_.begin()) {
+    itr--;
+    size_t last = itr->first;
+    if ((last + arena_[last]) == addr) {
+      arena_[last] = arena_[last] + arena_[addr];
+      arena_.erase(addr);
+    }
+  }
+}
+
+size_t OptAllocator::Malloc(size_t size) {
+  size = UP_DIV(size, align_size_) * align_size_;
+  size_t addr = FindFree(size);
+  // free block not found
+  if (addr == std::numeric_limits<size_t>::max()) {
+    if (!arena_.empty()) {
+      addr = arena_.rbegin()->first;
+      if (addr + arena_[addr] < heap_) {
+        addr = heap_;
+      } else {
+        arena_.erase(addr);
+      }
+    } else {
+      addr = heap_;
+    }
+    heap_ = addr + size;
+  } else {
+    if (arena_[addr] > size) {
+      arena_[addr + size] = arena_[addr] - size;
+    }
+    arena_.erase(addr);
+  }
+  alloc_[addr] = size;
+  return addr;
+}
+
+void OptAllocator::Free(size_t addr) {
+  arena_[addr] = alloc_[addr];
+  alloc_.erase(addr);
+  Reorder(addr);
+}
+}  // namespace mindspore
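opt_allocator.cc implements a best-fit planner over offsets rather than real memory: Malloc rounds the request up to the alignment, picks the smallest free block that fits (splitting off any remainder), and otherwise grows the heap; Free returns the block to the arena and coalesces it with adjacent free neighbours. A condensed Python sketch of the same bookkeeping, simplified in one respect: it always grows the heap on a miss, whereas the C++ version can also extend through a trailing free block:

class BestFitPlanner:
    def __init__(self, align=32):
        self.align = align
        self.arena = {}   # free blocks: offset -> size
        self.alloc = {}   # live blocks: offset -> size
        self.heap = 0     # high-water mark

    def malloc(self, size):
        size = -(-size // self.align) * self.align         # round up to alignment
        fits = [a for a, s in self.arena.items() if s >= size]
        if not fits:
            addr, self.heap = self.heap, self.heap + size  # grow the heap
        else:
            addr = min(fits, key=lambda a: self.arena[a])  # best fit
            if self.arena[addr] > size:
                self.arena[addr + size] = self.arena[addr] - size
            del self.arena[addr]
        self.alloc[addr] = size
        return addr

    def free(self, addr):
        self.arena[addr] = self.alloc.pop(addr)
        upper = addr + self.arena[addr]                    # merge upper neighbour
        if upper in self.arena:
            self.arena[addr] += self.arena.pop(upper)
        for prev in list(self.arena):                      # merge lower neighbour
            if prev != addr and prev + self.arena[prev] == addr:
                self.arena[prev] += self.arena.pop(addr)
                break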
Some files were not shown because too many files have changed in this diff.