!242 sync from mindspore to incubator 0623

Merge pull request !242 from yanghaoran/code_sync_0623
mindspore-ci-bot 2020-06-24 11:35:45 +08:00 committed by Gitee
commit db5df9db60
1924 changed files with 144381 additions and 22002 deletions

.gitmodules

@ -10,6 +10,9 @@
[submodule "third_party/protobuf"]
path = third_party/protobuf
url = https://github.com/protocolbuffers/protobuf.git
[submodule "akg"]
path = akg
url = https://gitee.com/mindspore/akg.git
[submodule "graphengine"]
path = graphengine
url = https://gitee.com/ms-incubator/graphengine.git
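The hunk above registers akg as a new git submodule alongside graphengine. A minimal sketch of bringing an existing checkout up to date after this change (build.sh runs the equivalent commands automatically, as shown in its diff further below):
```
# Fetch the submodules referenced by the updated .gitmodules
git submodule update --init graphengine
git submodule update --init --recursive akg
```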


@ -7,7 +7,7 @@ endif ()
include(${CMAKE_SOURCE_DIR}/cmake/options.cmake)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/modules/")
if (ENABLE_GE)
if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
endif ()
@ -86,10 +86,18 @@ if (ENABLE_GE OR ENABLE_D OR ENABLE_TESTCASES)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
endif()
if (ENABLE_AKG AND ENABLE_D)
add_subdirectory("${CMAKE_SOURCE_DIR}/akg")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
add_subdirectory(mindspore/ccsrc)
if (ENABLE_TESTCASES)
add_subdirectory(tests)
endif()
include(cmake/package.cmake)
if (ENABLE_SERVING)
add_subdirectory(serving)
endif()
include(cmake/package.cmake)
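Together with the build-script changes further below, these guards make AKG an Ascend-only subdirectory and keep serving opt-in. A hypothetical manual configure call using the cache variables named in this diff (build.sh normally assembles these flags itself and sets up the rest of the environment, so treat this only as an illustration):
```
# Assumed out-of-tree build directory; ENABLE_D/ENABLE_AKG/ENABLE_SERVING are the switches shown above
mkdir -p build/mindspore && cd build/mindspore
cmake -DENABLE_D=ON -DENABLE_AKG=ON -DENABLE_SERVING=ON ../..
```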


@ -7,6 +7,7 @@
* DeepFM: a factorization-machine based neural network for CTR prediction on Criteo dataset.
* DeepLabV3: significantly improves over our previous DeepLab versions without DenseCRF post-processing and attains comparable performance with other state-of-the-art models on the PASCAL VOC 2007 semantic image segmentation benchmark.
* Faster-RCNN: towards real-time object detection with region proposal networks on COCO 2017 dataset.
* SSD: a single-stage object detection method on COCO 2017 dataset.
* GoogLeNet: a deep convolutional neural network architecture codenamed Inception V1 for classification and detection on CIFAR-10 dataset.
* Wide&Deep: jointly trained wide linear models and deep neural networks for recommender systems on Criteo dataset.
* Frontend and User Interface
@ -62,7 +63,7 @@
## Contributors
Thanks goes to these wonderful people:
Alexey Shevlyakov, Amir Lashkari, anthony, baihuawei, biffex, buxue, caifubi, candanzg, caojian05, Cathy Wong, changzherui, chenfei, chengxianbin, chenhaozhe, chenzomi, chujinjin, cristoval, dengwentao, eric, etone-chan, fary86, gaojing, gengdongjie, gongchen, guohongzilong, guozhijian, heleiwang, hesham, He Wei, Hoai Linh Tran h00472437, hongxing, huangdongrun, huanghui, Jamie Nisbet, Jesse Lee, jiangjinsheng, jiangzhiwen, jinyaohui, jjfeing, jonwe, jonyguo, Junhan Hu, Kang, kingfo, kswang, laiyongqiang, leopz, lichenever, lihongkang, limingqi107, liubuyu, liuliyan2, liuwenhao4, liuxiao, liuxiao, liyong, lizhenyu, lvliang, Margaret_wangrui, meixiaowei, ms_yan, Nat Sutyanyong, ougongchang, panfengfeng, panyifeng, Peilin Wang, peixu_ren, qianlong, rick_sanchez, seatea, sheng, shijianning, simson, sunsuodong, Tinazhang, VectorSL, wandongdong, wangcong, wanghua, wangnan39, Wei Luning, wenchunjiang, wilfChen, WilliamLian, wsc, wukesong, wuxuejian, Xiaoda Zhang, xiefangqi, xulei2020, Yang, yangjie159, yangruoqi713, yangyongjie, yangzhenzhang, Yanjun Peng, yanzhenxiang2020, yao_yf, Yi Huaijie, yoonlee666, yujianfeng, YuJianfeng, yvetteliu, z00478463, zhangdengcheng, Zhang Qinghua, zhangz0911gm, zhaojichen, zhaoting, zhaozhenlong, zhoufeng, zhouneng, zhousiyi, zhouyuanshen, Zirui Wu, Ziyan, zjun, ZPaC, lihongzhang
Alexey Shevlyakov, Amir Lashkari, anthony, baihuawei, biffex, buxue, caifubi, candanzg, caojian05, Cathy Wong, changzherui, chenfei, chengxianbin, chenhaozhe, chenzomi, chujinjin, cristoval, dengwentao, eric, etone-chan, fary86, gaojing, gengdongjie, gongchen, guohongzilong, guozhijian, heleiwang, hesham, He Wei, Hoai Linh Tran, hongxing, huangdongrun, huanghui, Jamie Nisbet, Jesse Lee, jiangjinsheng, jiangzhiwen, jinyaohui, jjfeing, jonwe, jonyguo, Junhan Hu, Kang, kingfo, kswang, laiyongqiang, leopz, lichenever, lihongkang, limingqi107, liubuyu, liuliyan2, liuwenhao4, liuxiao, liuxiao, liyong, lizhenyu, lvliang, Margaret_wangrui, meixiaowei, ms_yan, Nat Sutyanyong, ougongchang, panfengfeng, panyifeng, Peilin Wang, peixu_ren, qianlong, rick_sanchez, seatea, sheng, shijianning, simson, sunsuodong, Tinazhang, VectorSL, wandongdong, wangcong, wanghua, wangnan39, Wei Luning, wenchunjiang, wilfChen, WilliamLian, wsc, wukesong, wuxuejian, Xiaoda Zhang, xiefangqi, xulei2020, Yang, yangjie159, yangruoqi713, yangyongjie, yangzhenzhang, Yanjun Peng, yanzhenxiang2020, yao_yf, Yi Huaijie, yoonlee666, yujianfeng, YuJianfeng, yvetteliu, zhangdengcheng, Zhang Qinghua, zhangz0911gm, zhaojichen, zhaoting, zhaozhenlong, zhoufeng, zhouneng, zhousiyi, zhouyuanshen, Zirui Wu, Ziyan, zjun, ZPaC, lihongzhang
Contributions of any kind are welcome!


@ -2245,14 +2245,14 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Please also refer to the file CONTRIBUTING.md, which clarifies licensing of
external contributions to this project including patches, pull requests, etc.
Software: SQLite 3.31.1
Software: SQLite 3.32.2
Copyright notice:
Copyright 2008 D. Richard Hipp and Hipp, Wyrick & Company, Inc.
Copyright (C) 1996, 1997, 1998, 1999, 2000, 2001, 2003, 2004, 2005, 2006, 2007 2008 Free Software Foundation, Inc.
(c) The page number is greater than the largest page that existed in
Copyright (c) 1991-2011 Unicode, Inc.
Copyright 2008 D. Richard Hipp and Hipp, Wyrick & Company, Inc.
Copyright (c) 2002 by David Gravereaux.
Copyright (c) 2006 by Pat Thoyts
(c) The page number is greater than the largest page that existed in
Copyright (C) 1996, 1997, 1998, 1999, 2000, 2001, 2003, 2004, 2005, 2006, 2007 2008 Free Software Foundation, Inc.
License: Public Domain
Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means.
@ -3053,6 +3053,646 @@ Copyright 2003 Google Inc.
Copyright 2009 Google Inc.
Copyright (C) 1991-2, RSA Data Security, Inc. Created 1991. All
Software: tinyxml2 8.0.0
Copyright 2011, John Resig.
Copyright 2011, The Dojo Foundation.
Software: icu 67.1
Copyright (C) 2000-2004, International Business Machines Corporation
Copyright (C) 2002-2014, International Business Machines(C) Copyright IBM Corp. 1998-2011 - All Rights Reserved
Copyright (C) 2003-2008, International Business Machines
Copyright (C) 2005-2006, International Business Machines
Copyright (C) 2016 and later: Unicode, Inc. and others.
Copyright (c) 2001-2010 International Business Machines
Copyright (C) 2009, International Business Machines
Copyright (c) 2010-2015 International Business Machines Corporation and others. All rights reserved.
Copyright (C) 2002-2015, International Business Machines verbatim (minus copyright and #include) and copied together into this file.
Copyright (c) 1997-2014, International Business Machines Corporation and others. All Rights Reserved.
Copyright (c) 1997-2008, International Business Machines Corporation and
Copyright (c) 1997-2003, International Business Machines Corporation and
Copyright (c) 1996-2012, International Business Machines Corporation and
Copyright (c) 1997-2016, International Business Machines
Copyright (c) 1997-2013 International Business Machines
Copyright (c) 1997-2016, International Business Machines Corporation and
Copyright (c) 1997-2001, International Business Machines Corporation and
Copyright (c) 1997-2012, International Business Machines Corporation and
Copyright (c) 1997-2005, International Business Machines Corporation and
Copyright (c) 1997-2010, International Business Machines Corporation and
Copyright (c) 2011-2016, International Business Machines Corporation
Copyright (c) 1997-2009, International Business Machines Corporation and
Copyright (c) 1997-2002,2008, International Business Machines Corporation and
Copyright (c) 1997-2009,2014, International Business Machines
Copyright (C) 2000-2009, International Business Machines
Copyright (c) 1997-2015, International Business Machines Corporation and
Copyright (c) 1997-2013, International Business Machines Corporation and
Copyright (c) 2001-2016, International Business Machines Corporation and
Copyright (c) 1997-2016, International Business Machines Corporation
Copyright (c) 1997-2003, 2007-2009 International Business Machines Corporation and
Copyright (c) 2011-2014, International Business Machines Corporation
Copyright (c) 2003-2009, International Business Machines
Copyright (c) 2016, International Business Machines Corporation
Copyright (c) 1997-2004, International Business Machines Corporation and
Copyright (C) 2002-2016, International Business Machines
Copyright (C) 1998-2014, International Business Machines Corporation
Copyright (c) 2003-2013, International Business Machines Corporation and
Copyright (c) 2005-2016, International Business Machines Corporation and
Copyright (c) 1999-2013, International Business Machines Corporation and
Copyright (c) 2003-2015, International Business Machines Corporation and
Copyright (C) 2003-2016, International Business Machines
Copyright (C) 2003-2014, International Business Machines
Copyright (C) 2003, International Business Machines
Copyright (c) 1998-2016, International Business Machines Corporation and
Copyright (c) 2004-2015, International Business Machines Corporation and
Copyright (c) 2009-2016, International Business Machines Corporation and
Copyright (C) 2003-2012, International Business Machines
Copyright (c) 2000-2016, International Business Machines Corporation and
Copyright (C) 2001-2014, International Business Machines
Copyright (C) 2001-2016, International Business Machines
Copyright (c) 1997-2014, International Business Machines © 2017 and later: Unicode, Inc. and others.
Copyright (C) 2007-2016, International Business Machines © 2018 and later: Unicode, Inc. and others.
Copyright (c) 2015, International Business Machines Corporation
Copyright (c) 2014-2016, International Business Machines Corporation
Copyright (c) 2002-2016, International Business Machines
Copyright (c) 2001-2011,2015 International Business Machines
Copyright (c) 2001-2016 International Business Machines
Copyright (c) 2005-2013, International Business Machines Corporation and
Copyright (c) 1998-2014, International Business Machines Corporation and
Copyright (C) 1997-2016 International Business Machines
Copyright (C) 2009-2014, International Business Machines Corporation and
Copyright (c) 2002-2014, International Business Machines Corporation
Copyright (c) 2002-2007, International Business Machines Corporation
Copyright (C) 1996-2012, International Business Machines Corporation
Copyright (C) 1996-2008, International Business Machines Corporation
Copyright (C) 2007-2013, International Business Machines Corporation and
Copyright (C) 2008-2015, International Business Machines
Copyright (C) 2003-2013, International Business Machines Corporation and
Copyright (C) 2003-2013, International Business Machines Corporation
Copyright (C) 1997-2016, International Business Machines Corporation and
Copyright (C) 2001-2011, International Business Machines
Copyright (C) 2001-2008, International Business Machines
Copyright (C) 2003 - 2009, International Business Machines Corporation and
Copyright (C) 2003 - 2008, International Business Machines Corporation and
Copyright (C) 2007-2014, International Business Machines Corporation
Copyright (C) 2007-2013, International Business Machines Corporation
Copyright (C) 1997-2013, International Business Machines Corporation and
Copyright (C) 1996-2014, International Business Machines Corporation and
Copyright (C) 2010-2014, International Business Machines
Copyright (C) 2010-2015, International Business Machines
Copyright (C) 2013-2014, International Business Machines
Copyright (C) 1996-2015, International Business Machines
Copyright (C) 1996-2014, International Business Machines
Copyright (C) 2012-2015, International Business Machines
Copyright (C) 2012-2014, International Business Machines
Copyright (C) 2013-2015, International Business Machines
Copyright (C) 2013-2016, International Business Machines
Copyright (C) 1999-2016, International Business Machines
Copyright (C) 1999-2015, International Business Machines
Copyright (C) 1999-2014, International Business Machines
Copyright (C) 2015-2016, International Business Machines Corporation and others.
Copyright (C) 2003 - 2013, International Business Machines Corporation and
Copyright (C) 1999-2011, International Business Machines
Copyright (C) 2005-2016, International Business Machines
Copyright (C) 2005-2012, International Business Machines
Copyright (C) 2005-2015, International Business Machines
Copyright (C) 2005-2013, International Business Machines
Copyright (C) 2005-2014, International Business Machines
Copyright (c) 2004, International Business Machines
Copyright (c) 2004-2014 International Business Machines
Copyright (c) 2004-2014, International Business Machines
Copyright (C) 2013, International Business Machines Corporation
Copyright (C) 1997-2015, International Business Machines Corporation and
Copyright (C) 2016, International Business Machines
Copyright (c) IBM Corporation, 2000-2012. All rights reserved.
Copyright (c) IBM Corporation, 2000-2011. All rights reserved.
Copyright (c) IBM Corporation, 2000-2014. All rights reserved.
Copyright (c) IBM Corporation, 2000-2010. All rights reserved.
Copyright (c) IBM Corporation, 2000-2016. All rights reserved.
Copyright 2010 the V8 project authors. All rights reserved.
Copyright 2006-2008 the V8 project authors. All rights reserved.
Copyright 2012 the V8 project authors. All rights reserved.
Copyright (C) 2008-2016, International Business Machines Corporation and
Copyright (C) 2007-2016, International Business Machines Corporation and
Copyright (C) 2007-2012, International Business Machines Corporation and
Copyright (c) 2001-2011, International Business Machines
Copyright (c) 2001-2007, International Business Machines
Copyright (C) 2010-2014, International Business Machines Corporation and
Copyright (C) 1997-2010, International Business Machines Corporation and
Copyright (C) 1997-2012, International Business Machines Corporation and
Copyright (C) 2009-2015, International Business Machines Corporation and
Copyright (C) 2009-2012, International Business Machines Corporation and
Copyright (c) 2002-2012, International Business Machines Corporation
Copyright (c) 2002-2011, International Business Machines Corporation
Copyright (C) 2008-2013, International Business Machines Corporation and
Copyright (c) 2003-2008, International Business Machines
Copyright (C) 2003-2016, International Business Machines Corporation
Copyright (C) 2003-2014, International Business Machines Corporation
Copyright (C) 2003-2008, International Business Machines Corporation
Copyright (C) 2005-2008, International Business Machines
Copyright (C) 2003-2015, International Business Machines Corporation
Copyright (C) 2003-2009,2012,2016 International Business Machines Corporation and
Copyright (c) 2004-2016, International Business Machines © 2020 and later: Unicode, Inc. and others.
Copyright (C) 2007-2008, International Business Machines Corporation and
Copyright (C) 2001-2007, International Business Machines
Copyright (C) 1997-2012, International Business Machines
Copyright (C) 1997-2015, International Business Machines
Copyright (C) 2001-2010, International Business Machines
Copyright (c) 2000-2005, International Business Machines
Copyright (c) 2000-2007, International Business Machines © 2019 and later: Unicode, Inc. and others.
Copyright (C) 2010-2015, International Business Machines Corporation and
Copyright (C) 2015, International Business Machines Corporation and
Copyright (c) 2003-2013, International Business Machines
Copyright (C) 2001-2012, International Business Machines
Copyright (C) 2001-2011, International Business Machines Corporation
Copyright (C) 2014-2016, International Business Machines
Copyright (C) 1997-2015, International Business Machines Corporation
Copyright (C) 1999-2007, International Business Machines
Copyright (C) 1999-2007, International Business Machines Corporation
Copyright (C) 1999-2011, International Business Machines Corporation
Copyright (C) {1999-2001}, International Business Machines Corporation and others. All Rights Reserved.
Copyright (C) 2002-2016 International Business Machines Corporation and others.
Copyright (C) 2002-2016, International Business Machines Corporation and others.
Copyright (C) 2002-2016 International Business Machines Corporation
Copyright (C) 2002-2015, International Business Machines Corporation and others.
Copyright (C) 2012 International Business Machines Corporation
Copyright (C) 2002-2015 International Business Machines Corporation
Copyright (C) 2004-2015, International Business Machines Corporation and others.
Copyright (C) 2003-2010, International Business Machines Corporation and others.
Copyright (c) 2008-2011, International Business Machines Corporation and
Copyright (c) 2008-2010, International Business Machines Corporation and
Copyright (C) 2014-2016, International Business Machines Corporation and
Copyright (C) 2013, International Business Machines Corporation and
Copyright (c) 2014, International Business Machines
Copyright (C) 2014, International Business Machines
Copyright (C) 2013, International Business Machines
Copyright (C) 2001-2008,2010 IBM and others. All rights reserved.
Copyright (C) 2010 , Yahoo! Inc.
Copyright (c) 1997-2011, International Business Machines Corporation and
Copyright (C) 2013-2014, International Business Machines Corporation and
Copyright (C) 2009-2013, International Business Machines Corporation and
Copyright (C) 1996-2012, International Business Machines Corporation and
Copyright (C) 2015, International Business Machines Corporation
Copyright (c) 2001-2012, International Business Machines Corporation
Copyright (C) 2001-2014 IBM and others. All rights reserved.
Copyright (C) 2008-2014, Google, International Business Machines Corporation and
Copyright (C) 2008, Google, International Business Machines Corporation and
Copyright (C) 2008-2015, Google, International Business Machines Corporation
Copyright (c) 2001-2014, International Business Machines
Copyright (c) 2002-2010, International Business Machines Corporation
Copyright (C) 2011-2015, International Business Machines Corporation and
Copyright (C) 2011-2016, International Business Machines Corporation and
Copyright (C) 2011-2012, International Business Machines Corporation and
Copyright (C) 1996-2016, International Business Machines
Copyright (C) 1998-2014, International Business Machines
Copyright (C) 2004-2016, International Business Machines
Copyright (C) 2010-2011, International Business Machines
Copyright (C) 2009-2015, International Business Machines
Copyright (C) 2015, International Business Machines
Copyright (C) 2012-2016, International Business Machines
Copyright (C) 1999-2012, International Business Machines
Copyright (C) 2001, International Business Machines
Copyright (C) 2013, International Business Machines Corporation and others.
Copyright (C) 2010-2012, International Business Machines
Copyright (C) 2004-2015, International Business Machines
Copyright (C) 2003-2006, International Business Machines
Copyright (C) 2013-2015, International Business Machines Corporation and others.
Copyright (C) 2001-2015 IBM and others. All rights reserved.
Copyright (C) 2008-2015, International Business Machines Corporation
Copyright (C) 2008-2016, International Business Machines
Copyright (C) 2008-2013, International Business Machines Corporation
Copyright (C) 2004-2012, International Business Machines Corporation and
Copyright (C) 1997-2009,2014 International Business Machines
Copyright (C) 2009-2011, International Business Machines Corporation and
Copyright (C) 2009-2016, International Business Machines Corporation and
Copyright (C) 2009-2013, International Business Machines
Copyright (C) 2008-2011, International Business Machines
Copyright (C) 2007-2014, International Business Machines Corporation and
Copyright (C) 2009-2010, International Business Machines Corporation and
Copyright (C) 2001-2016 International Business Machines Corporation
Copyright (c) 2002-2011, International Business Machines
Copyright (C) 2001-2012 IBM, Inc. All Rights Reserved.
Copyright (c) 2013-2016 International Business Machines Corporation and others. All rights reserved.
Copyright (c) 2013-2015 International Business Machines Corporation and others. All rights reserved.
Copyright (c) 2007-2012, International Business Machines Corporation and
Copyright (c) 2007-2012, International Business Machines
Copyright (C) 2010, International Business Machines
Copyright (C) 1997-2011, International Business Machines
Copyright (C) 1997-2005, International Business Machines
Copyright (C) 2009-2011, International Business Machines
Copyright (C) 2003-2015, International Business Machines
Copyright (C) 2009-2016, International Business Machines
Copyright (C) 2008-2012, International Business Machines
Copyright (C) 2008, International Business Machines
Copyright (C) 2011-2014, International Business Machines
Copyright (C) 2011-2013, International Business Machines
Copyright (C) 2005, International Business Machines
Copyright (C) 1999-2013, International Business Machines
Copyright (C) 1998-2016, International Business Machines
Copyright (c) 2007-2014, International Business Machines Corporation and
Copyright (C) 2003-2013, International Business Machines
Copyright (c) 2007-2016, International Business Machines Corporation and
Copyright (c) 2008-2015, International Business Machines
Copyright (C) 1999-2010, International Business Machines
Copyright (C) 2000-2015, International Business Machines
Copyright (C) 2000-2011, International Business Machines
Copyright (C) 2000-2012, International Business Machines
Copyright (C) 2000-2010, International Business Machines
Copyright (C) 2004-2010, International Business Machines
Copyright (C) 2004-2005, International Business Machines
Copyright (c) 2013-2014, International Business Machines
Copyright (c) 1991-2013 Unicode, Inc. © 2019 Unicode®, Inc.
Copyright (C) 2018 and later: Unicode, Inc. and others.
Copyright (c) 2008-2013 International Business Machines
Copyright (C) 2002-2010, International Business Machines
Copyright (c) 2012-2015 International Business Machines © 2020 Unicode®, Inc.
Copyright (c) 2005-2013 IBM Corporation and others. All rights reserved
Copyright (c) 2011-2012, International Business Machines Corporation and
Copyright (C) 1998-2000, International Business Machines © 2017 Unicode®, Inc.
Copyright (c) 2007-2015 International Business Machines
Copyright (C) 2004-2006, International Business Machines
Copyright (C) 2003-2005, International Business Machines
Copyright (c) 1999-2014 International Business Machines
Copyright (c) 2003, International Business Machines
Copyright (C) 2014 International Business Machines
Copyright (c) 2001-2003 International Business Machines
Copyright (c) 2004-2011 International Business Machines
Copyright (C) 2015-2016, International Business Machines
Copyright (c) 2001-2015 International Business Machines
Copyright (C) 2003-2012, International Business Machines Corporation and COPYRIGHT AND PERMISSION NOTICE
Copyright (c) 2003 National Electronics and Computer Technology Center and others
Copyright (C) 2005-2010, International Business Machines
Copyright (c) 2007-2009 IBM Corporation and others. All rights reserved
Copyright (C) 2004-2016 International Business Machines
Copyright (C) 1998-2013, International Business Machines
Copyright (C) 1998-2010, International Business Machines
Copyright (c) 1999-2004, International Business Machines
Copyright (C) 2002-2006 International Business Machines Corporation
Copyright (C) 1999-2006, International Business Machines
Copyright (C) 2002-2016 IBM, Inc. All Rights Reserved.
Copyright (c) 2002-2006, International Business Machines(C) Copyright IBM Corp. 1998-2007 - All Rights Reserved
Copyright (C) 1999-2003, International Business Machines
Copyright (C) 1998-2006, International Business Machines Corporation and
Copyright (C) 1998-2003, International Business Machines Corporation and
Copyright (C) 2003 - 2008, International Business Machines
Copyright (C) 1999-2008, International Business Machines
Copyright (C) 1999-2001, International Business Machines
Copyright (C) 1999-2005, International Business Machines
Copyright (C) 2016 and later: Unicode, Inc. and others.
Copyright (c) 2001-2010 IBM Corporation and others. All Rights Reserved.
Copyright (C) 1998-2005, International Business Machines Corporation and
Copyright (C) 1998-2001, International Business Machines Corporation and
Copyright (c) 2002-2005, International Business Machines Corporation and others. All Rights Reserved.
Copyright (C) 2000-2014, International Business Machines
Copyright (C) 1996-2013, International Business Machines
Copyright (c) 2002-2006, International Business Machines Corporation and
Copyright (c) 2004-2010, International Business Machines Corporation and
Copyright (C) 2004-2011, International Business Machines
Copyright (c) 2002-2005, International Business Machines Corporation and
Copyright (c) 2002-2014, International Business Machines
Copyright (c) 1997-2012, International Business Machines
Copyright (c) 2002-2008, International Business Machines Corporation and others. All Rights Reserved.
Copyright (C) 2011-2013, Apple Inc.; Unicode, Inc.; and others. All Rights Reserved.
Copyright (C) 2011-2013, Apple Inc. and others. All Rights Reserved.
Copyright (c) 2005-2007,2010 Apple Inc., Unicode Inc.,and others. All Rights Reserved.
Copyright (c) 1999-2003, International Business Machines Corporation and
Copyright (c) 2003-2014, International Business Machines
Copyright (c) 2002-2010, International Business Machines Corporation and others. All Rights Reserved.
Copyright (c) 1999-2010, International Business Machines Corporation and
Copyright (c) 1999-2002, International Business Machines Corporation and
Copyright (C) 2002-2003, International Business Machines
Copyright (C) 2002, International Business Machines
Copyright (c) 2007, International Business Machines Corporation and
Copyright (C) 2007, International Business Machines
Copyright (C) 2001-2006, International Business Machines
Copyright (C) 2010-2014, International Business Machines Corporation and others.
Copyright (C) 2005-2016, International Business Machines Corporation and
Copyright (C) 2015-2016, International Business Machines Corporation and
Copyright (C) 2008-2012, International Business Machines Corporation
Copyright (c) 2006-2015 International Business Machines Corporation and others. All rights reserved.
Copyright (c) 2014-2015 International Business Machines Corporation and others. All rights reserved.
Copyright (C) 2002-2011, International Business Machines
Copyright (c) 2003-2010, International Business Machines Corporation and others. All Rights Reserved.
Copyright (C) 2012 IBM Corporation and Others. All Rights Reserved.
Copyright (C) 1998-2012, International Business Machines Corporation
Copyright (c) 2009, International Business Machines Corporation and
Copyright (C) The Internet Society (2002). All Rights Reserved.
Copyright (c) 2015, International Business Machines Corporation and
Copyright (c) 2002, International Business Machines Corporation and others. All Rights Reserved.
Copyright (C) 1998-2016, International Business Machines Corporation
Copyright (c) 2011-2016,International Business Machines
Copyright (C) 2012 International Business Machines Corporation and Others. All Rights Reserved.
Copyright (C) 2011, International Business Machines Corporation and others. All Rights Reserved.
Copyright (C) 2011, International Business Machines Corporation and others. All Rights Reserved.
Copyright (c) 2011-2012,International Business Machines
Copyright (c) 2007, International Business Machines Corporation and others. All Rights Reserved.
Copyright (C) 2007-2007, International Business Machines(C) Copyright IBM Corp. 1998-2014 - All Rights Reserved
Copyright (C) 1998-2002, International Business Machines
Copyright (c) 2001-2007, International Business Machines Corporation and others. All Rights Reserved.(C) Copyright IBM Corp. 1998-2013 - All Rights Reserved
Copyright (C) 1998-2015, International Business Machines
Copyright (C) 2001-2014 International Business Machines
Copyright (C) 2011-2016, International Business Machines
Copyright (C) 2011-2015, International Business Machines
Copyright (c) 1999-2014, International Business Machines Corporation and
Copyright (c) 1999-2009, International Business Machines Corporation and
Copyright (c) 2010,International Business Machines
Copyright (c) 2010-2016,International Business Machines
Copyright (c) 2002-2005, International Business Machines
Copyright (C) 2000-2003, International Business Machines
Copyright (c) 2008-2014, International Business Machines Corporation and
Copyright (C) 2001 - 2005, International Business Machines
Copyright (C) 2001-2005, International Business Machines
Copyright (C) 1995-2014, International Business Machines
Copyright (c) 2000-2004 IBM, Inc. and Others.
Copyright (c) 2002-2014, International Business Machines Corporation and
Copyright (c) 2007-2013, International Business Machines Corporation and
Copyright (c) 2002-2012, International Business Machines Corporation and
Copyright (C) 2002-2012, International Business Machines
Copyright (C) 2009-2011, International Business Machines Corporation, Google and Others.
Copyright (c) 2002, International Business Machines Corporation and others. All Rights Reserved.
Copyright (C) 2009-2014, International Business Machines
Copyright (C) 2008, International Business Machines Corporation and others.
Copyright (C) 2000-2016, International Business Machines
Copyright (C) 2011-2014 International Business Machines
Copyright (C) 1997-2014, International Business Machines
Copyright (C) 1997-2013, International Business Machines
Copyright (c) 2004-2006, International Business Machines
Copyright (C) 1997-2016, International Business Machines
Copyright (C) 1997-2006, International Business Machines
Copyright (C) 1997-2011, International Business Machines Corporation and others.
Copyright (C) 1997-2013, International Business Machines Corporation and others.
Copyright (c) 2004-2015, International Business Machines
Copyright (C) 2009-2017, International Business Machines Corporation,Google, and others. All Rights Reserved.
Copyright (C) 1997-2016, International Business Machines Corporation and others.
Copyright (C) 2008-2015, International Business Machines Corporation and
Copyright (C) 1997-2015, International Business Machines Corporation and others.
Copyright (C) 2014-2016, International Business Machines Corporation and others.
Copyright (c) 2014-2016, International Business Machines
Copyright (C) 2001-2011 IBM and others. All rights reserved.
Copyright (C) 1996-2014, International Business Machines Corporation and others.
Copyright (C) 1996-2016, International Business Machines Corporation and
Copyright (C) 2009-2016, International Business Machines Corporation,
Copyright (C) 2009-2010, Google, International Business Machines Corporation and
Copyright (C) 2008-2014, Google, International Business Machines Corporation
Copyright (C) 1996-2015, International Business Machines Corporation and
Copyright (c) 1996-2015, International Business Machines Corporation and others.
Copyright (C) 2010-2012,2015 International Business Machines
Copyright (C) 2007-2015, International Business Machines
Copyright (C) 2013-2014, International Business Machines Corporation and others.
Copyright (C) 2010-2013, International Business Machines
Copyright (c) 2002-2005, International Business Machines Corporation
Copyright (C) 2001-2011,2014 IBM and others. All rights reserved.
Copyright (C) 2008-2016, International Business Machines Corporation
Copyright (C) 2004 - 2008, International Business Machines Corporation and
Copyright (C) 1997-2011,2014-2015 International Business Machines
Copyright (C) 2001-2003, International Business Machines
Copyright (C) 1999-2009, International Business Machines
Copyright (C) 2020 and later: Unicode, Inc. and others.
Copyright (c) 2002, International Business Machines Corporation and
Copyright (C) 2000-2008, International Business Machines
Copyright (C) 1998-2006, International Business Machines
Copyright (C) 1998-2001, International Business Machines Corporation
Copyright (C) 1998-2004, International Business Machines Corporation
Copyright (C) 2000, International Business Machines
Copyright (c) 1999-2016, International Business Machines Corporation and
Copyright (c) 2015, International Business Machines Corporation and others. All Rights Reserved.
Copyright (c) 1999-2012, International Business Machines Corporation and
Copyright (C) 1998-2011, International Business Machines
Copyright (C) 2008-2014, International Business Machines Corporation and
Copyright (C) 2003-2004, International Business Machines
Copyright (c) 2003-2005, International Business Machines Corporation and others. All Rights Reserved.
Copyright (C) 2002-2006 IBM, Inc. All Rights Reserved.
Copyright (C) 2004-2008, International Business Machines
Copyright (c) 2002-2016 International Business Machines Corporation and
Copyright (c) 2002-2015, International Business Machines Corporation and
Copyright (C) 2002-2016, International Business Machines Corporation
Copyright (c) 2002-2010,International Business Machines
Copyright (c) 2002-2014,International Business Machines
Copyright (c) 2002-2016,International Business Machines
Copyright (C) 2016 International Business Machines Corporation
Copyright © 2019 and later: Unicode, Inc. and others.
Copyright (c) 2016, International Business Machines Corporation and others. All Rights Reserved.
Copyright (c) 2016 International Business Machines Corporation and others. All Rights Reserved.
Copyright (c) 2015-2016, International Business Machines Corporation and others. All Rights Reserved.
Copyright (c) 2005-2006, International Business Machines Corporation and
Copyright (c) 1997-2004, International Business Machines Corporation
Copyright (c) 2012-2016, International Business Machines Corporation
Copyright (c) 2012-2014, International Business Machines Corporation and
Copyright (c) 1997-2014, International Business Machines Corporation
Copyright (c) 1996-2016, International Business Machines Corporation and
Copyright (c) 2003-2013, International Business Machines Corporation
Copyright (c) 2003-2008, International Business Machines Corporation
Copyright (c) 1997-2015, International Business Machines Corporation
Copyright (c) 2002-2016, International Business Machines Corporation and
Copyright (c) 1997-2002, International Business Machines Corporation and
Copyright (C) 1996-2012, International Business Machines
Copyright (c) 1997-2013 International Business Machines Corporation and
Copyright (c) 2010-2012, International Business Machines Corporation and
Copyright (c) 1997-2011, International Business Machines Corporation
Copyright (c) 1997-2006, International Business Machines Corporation and
Copyright (c) 2008-2016 International Business Machines Corporation and
Copyright (c) 2008-2016, International Business Machines Corporation and
Copyright (c) 1997-2016 International Business Machines Corporation and
Copyright (c) 2007-2011, International Business Machines
Copyright (c) 2007-2010, International Business Machines
Copyright (C) 2001-2016, International Business Machines Corporation and
Copyright (C) 2001-2003, International Business Machines Corporation and
Copyright (C) 2003-2011, International Business Machines
Copyright (c) 1997-2007, International Business Machines Corporation and
Copyright (c) 1997-2015, International Business Machines
Copyright (C) 2004-2009, International Business Machines Corporation and
Copyright (C) 2004, International Business Machines Corporation and
Copyright (C) 1996-2009, International Business Machines Corporation and
Copyright (C) 1996-2006, International Business Machines Corporation and
Copyright (C) 2011-2013, International Business Machines Corporation
Copyright (C) 2000-2007, International Business Machines
Copyright (c) 2001, International Business Machines Corporation and
Copyright (C) 2012-2013, International Business Machines
Copyright (c) 2010-2016, International Business Machines Corporation and
Copyright (c) 2010-2016, International Business Machines Corporation
Copyright (c) 1997-2010, International Business Machines Corporation
Copyright (c) 1997-2003, International Business Machines
Copyright (C) 2014-2015, International Business Machines Corporation and
Copyright (c) 1997-2013, International Business Machines Corporation
Copyright (c) 1999-2016, International Business Machines
Copyright (c) 1999-2016 International Business Machines Corporation and
Copyright (c) 2016, International Business Machines Corporation and
Copyright (c) 2016, International Business Machines
Copyright (c) 2013-2016, International Business Machines Corporation
Copyright (c) 2013, International Business Machines Corporation
Copyright (C) 2013-2016, International Business Machines Corporation and
Copyright (c) 2001-2010, International Business Machines Corporation and
Copyright (C) 2014, International Business Machines Corporation and
Copyright (c) 1999-2015, International Business Machines Corporation and
Copyright (C) 2001-2016, International Business Machines orporation
Copyright (c) 2001-2008, International Business Machines Corporation and others
Copyright (C) 2003-2016, International Business Machines Corporation and
Copyright (c) 2004, International Business Machines Corporation
Copyright (C) 2001-2009, International Business Machines
Copyright (c) 2004,2011 International Business Machines
Copyright (c) 2004-2011, International Business Machines
Copyright (c) 2000-2016, International Business Machines Corporation
Copyright (c) 2001-2005, International Business Machines Corporation and
Copyright (C) 2001-2004, International Business Machines
Copyright (c) 2001-2009, International Business Machines
Copyright (c) 1997-2009, International Business Machines Corporation
Copyright (c) 1997-2013, International Business Machines
Copyright (c) 1997-2012, International Business Machines Corporation
Copyright (C) 2007-2015, International Business Machines Corporation and
Copyright (C) 2007-2011, International Business Machines Corporation and
Copyright (C) 2007, International Business Machines Corporation and
Copyright (c) 1998-2005, International Business Machines Corporation and
Copyright (c) 2002-2010, International Business Machines Corporation and
Copyright (C) 1999-2016 International Business Machines Corporation and
Copyright (c) 2004-2011, International Business Machines Corporation and
Copyright (c) 2002-2007, International Business Machines Corporation and
Copyright (C) 2003, International Business Machines Corporation and
Copyright (C) 2005-2011, International Business Machines
Copyright (C) 2011-2012, International Business Machines
Copyright (C) 2007-2012, International Business Machines
Copyright (C) 2006-2016, International Business Machines Corporation
Copyright (C) 2006-2012, International Business Machines Corporation and others.
Copyright 2007 Google Inc. All Rights Reserved.
Copyright (c) 2001-2015, International Business Machines
Copyright (C) 2006-2014, International Business Machines Corporation
Copyright (C) 2008, International Business Machines Corporation and
Copyright (C) 2009-2012, International Business Machines
Copyright (C) 2006 International Business Machines Corporation
Copyright (C) 2010-2016, International Business Machines Corporation and
Copyright (C) 2002-2014, International Business Machines Corporation and
Copyright (C) 2002-2005, International Business Machines Corporation and
Copyright (C) 2011, International Business Machines
Copyright (c) 2003-2010 International Business Machines
Copyright (C) 2003-2003, International Business Machines
Copyright (C) 1999-2016 International Business Machines Corporation
Copyright (C) 1999-2014 International Business Machines Corporation
Copyright (C) 1999-2014 International Business Machines
Copyright (C) 2002-2011, International Business Machines Corporation and others.
Copyright (C) 2002-2008, International Business Machines Corporation and others.
Copyright (C) 2002-2008 International Business Machines Corporation
Copyright (c) 2001-2005, International Business Machines
Copyright (C) 2002-2014 International Business Machines Corporation
Copyright (c) 2003-2011, International Business Machines
Copyright (C) 1998-2012, International Business Machines Corporation and
Copyright (C) 2001-2014, International Business Machines Corporation.
Copyright (C) 2001-2011, International Business Machines Corporation.
Copyright (C) 2001-2014, International Business Machines Corporation and
Copyright (C) 2001-2011, International Business Machines Corporation and
Copyright (C) 2001-2012, International Business Machines Corporation and
Copyright 2004 and onwards Google Inc.
Copyright (C) 2004-2014, International Business Machines
Copyright (C) 2006, International Business Machines
Copyright (C) 2004-2012, International Business Machines
Copyright (C) 2001-2013, International Business Machines
Copyright (C) 1998-2004, International Business Machines
Copyright (C) 2000-2013, International Business Machines
Copyright (C) 1999-2015 International Business Machines
Copyright (C) 2000-2006, International Business Machines
Copyright (C) 1999-2004, International Business Machines
Copyright (C) 2003-2007, International Business Machines
Copyright (C) 2002-2006, International Business Machines
Copyright (C) 2001-2015, International Business Machines
Copyright (c) 2001-2012, International Business Machines
Copyright (c) 2002-2004, International Business Machines
Copyright (C) 1999-2016, International Business Machines Corporation and
Copyright (c) 1996-2014, International Business Machines
Copyright (C) 1999-2016, International Business Machines Corporation
Copyright (C) 2009-2014 International Business Machines
Copyright (C) 2004-2007, International Business Machines
Copyright (c) 2001-2016, International Business Machines
Copyright (C) 2003-2009, International Business Machines
Copyright (C) 1999-2013, International Business Machines Corporation and
Copyright (C) 1999-2015, International Business Machines Corporation and
Copyright (c) 2002-2011, International Business Machines Corporation and others. All Rights Reserved.
Copyright (C) 2001-2016 IBM, Inc. All Rights Reserved.
Copyright (C) 1999-2016 International Business Machines
Copyright (C) 2009-2010 IBM Corporation and Others. All Rights Reserved.
Copyright (C) 1998-2012, International Business Machines
Copyright (C) 1991 and later: Unicode, Inc. and others.
Copyright (C) 1997-2000, International Business Machines
Copyright (c) 1999-2007, International Business Machines Corporation and
Copyright (c) 2000 IBM, Inc. and Others.
Copyright (C) 2008-2013, International Business Machines
Copyright (C) 1998-2003, 2006, International Business Machines Corporation
Copyright (c) 2002-2003,International Business Machines
Copyright (C) 2009 International Business Machines
Copyright (C) 2010-2016 International Business Machines
Copyright (C) 2008-2012 IBM, Inc. All Rights Reserved.
Copyright (C) 1998-2008, International Business Machines
Copyright (C) 2010-2016, International Business Machines
Copyright (C) 1999-2006,2013 IBM Corp. All rights reserved.
Copyright (C) 2008-2009, International Business Machines Corporation and
Copyright (C) 2012,2014 International Business Machines
Copyright (c) 1996-2015, International Business Machines Corporation and
Copyright (C) 1997-2005, International Business Machines Corporation and others. All Rights Reserved.
Copyright (C) 1999-2012, International Business Machines Corporation and
Copyright (C) 1996-2013, International Business Machines Corporation
Copyright (C) 1998-2005, International Business Machines
Copyright 2001 and onwards Google Inc.
Copyright (C) 2010-2012,2014, International Business Machines
Copyright (C) 1996-2015, International Business Machines Corporation and others.
Copyright (c) 2003-2004, International Business Machines
Copyright (C) 2000-2004, International Business Machines
Copyright (C) 2002-2013, International Business Machines
Copyright (C) 2002-2011 International Business Machines Corporation and others. All Rights Reserved.
Copyright (C) 1999-2010, International Business Machines Corporation and others.
Copyright (C) 2001-2005, International Business Machines Corporation and others. All Rights Reserved.
Copyright (c) 1996-2016, International Business Machines Corporation
Copyright (C) 1997-2010, International Business Machines
Software: libtiff 4.1.0
Copyright notice:
Copyright © 2015 Open Microscopy Environment / University of Dundee
Copyright (c) 2004, Andrey Kiselev <dron@ak4719.spb.edu>
Copyright (c) 1990-1997 Sam Leffler
Copyright (c) 1991-1997 Silicon Graphics, Inc.
Copyright (c) 1988-1997 Sam Leffler
Copyright (c) 1991-1997 Sam Leffler
Use and Copyright
Copyright (C) 1990, 1995 Frank D. Cringle.
Copyright (c) 1994-1997 Sam Leffler
Copyright (c) 1994-1997 Silicon Graphics, Inc.
Copyright (c) 1997 Greg Ward Larson
Copyright (c) 1997 Silicon Graphics, Inc.
Copyright (c) 2010, Andrey Kiselev <dron@ak4719.spb.edu>
Copyright (c) Joris Van Damme <info@awaresystems.be>
Copyright (c) AWare Systems <http:www.awaresystems.be/>
Copyright (c) 1996-1997 Sam Leffler
Copyright (c) 1996 Pixar
Copyright (c) 1995-1997 Sam Leffler
Copyright (c) 1995-1997 Silicon Graphics, Inc.
Copyright (c) 1988-1996 Sam Leffler
Copyright (c) 1991-1996 Silicon Graphics, Inc.
Copyright (c) 1992-1997 Sam Leffler
Copyright (c) 1992-1997 Silicon Graphics, Inc.
Copyright (c) 2018, Mapbox
Copyright (c) 2017, Planet Labs
Copyright (c) 1990 by Sun Microsystems, Inc.
Copyright 1990 by Digital Equipment Corporation, Maynard, Massachusetts.
Copyright 1991 by Digital Equipment Corporation, Maynard, Massachusetts.
Copyright (c) 2002, Andrey Kiselev <dron@ak4719.spb.edu>
Copyright (c) 2003 Ross Finlayson
Additions (c) Richard Nolde 2006-2010
Copyright (c) 2003, Andrey Kiselev <dron@ak4719.spb.edu>
Copyright (c) 2000, Frank Warmerdam
Copyright (c) 1987, 1993, 1994
Copyright (c) 1989, 1993
Copyright (c) 2009 Frank Warmerdam
Copyright (c) 1987, 1993
Copyright (c) 2005 The DragonFly Project. All rights reserved.
Copyright (c) 2003 Citrus Project,
All rights reserved.
Copyright (c) 1990, 1993
Copyright (c) 1996 Mike Johnson
Copyright (c) 1996 BancTec AB
Copyright (c) 2004, Andrey Kiselev <dron@ak4719.spb.edu>
Copyright (c) 2012, Frank Warmerdam <warmerdam@pobox.com>
Copyright (c) 2019, Even Rouault <even.rouault at spatialys.com>
Copyright (c) 2007, Frank Warmerdam <warmerdam@pobox.com>
Copyright (c) 2019, Thomas Bernard <miniupnp@free.fr>
Copyright (c) 2008, Andrey Kiselev <dron@ak4719.spb.edu>
Copyright (c) 1999, Frank Warmerdam
Copyright (c) 1991-1996 Sam Leffler
Copyright (c) 1996 USAF Phillips Laboratory
Software: opencv 4.2.0
Copyright notice:
Copyright (C) 2016, NVIDIA Corporation, all rights reserved.

akg Submodule

@ -0,0 +1 @@
Subproject commit c460176523d039c8995f1d71089753725ebc0792


@ -49,10 +49,11 @@ usage()
echo " -Q Enable dump memory, default off"
echo " -D Enable dumping of function graph ir, default on"
echo " -z Compile dataset & mindrecord, default on"
echo " -M Enable MPI and NCCL for GPU training, default on"
echo " -M Enable MPI and NCCL for GPU training, gpu default on"
echo " -V Specify the minimum required cuda version, default CUDA 9.2"
echo " -I Compile predict, default off"
echo " -K Compile with AKG, default off"
echo " -s Enable serving module, default off"
}
# check value of input is 'on' or 'off'
@ -86,15 +87,15 @@ checkopts()
ENABLE_DUMPE2E="off"
ENABLE_DUMP_IR="on"
COMPILE_MINDDATA="on"
ENABLE_MPI="on"
ENABLE_MPI="off"
CUDA_VERSION="9.2"
COMPILE_PREDICT="off"
USE_GLOG="on"
PREDICT_PLATFORM=""
ENABLE_AKG="off"
ENABLE_AKG="on"
ENABLE_SERVING="off"
# Process the options
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K' opt
while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K:s' opt
do
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
case "${opt}" in
@ -168,6 +169,7 @@ checkopts()
if [[ "X$OPTARG" == "Xgpu" ]]; then
ENABLE_GPU="on"
ENABLE_CPU="on"
ENABLE_MPI="on"
elif [[ "X$OPTARG" == "Xd" || "X$OPTARG" == "Xascend" ]]; then
ENABLE_D="on"
ENABLE_CPU="on"
@ -234,6 +236,10 @@ checkopts()
ENABLE_AKG="on"
echo "enable compile with akg"
;;
s)
ENABLE_SERVING="on"
echo "enable serving"
;;
*)
echo "Unknown option ${opt}!"
usage
@ -242,9 +248,12 @@ checkopts()
done
}
checkopts "$@"
echo "---------------- mindspore: build start ----------------"
echo "---------------- MindSpore: build start ----------------"
mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
git submodule update --init graphengine
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
git submodule update --init --recursive akg
fi
build_exit()
{
@ -307,9 +316,13 @@ build_mindspore()
if [[ "X$USE_GLOG" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON"
fi
if [[ "X$ENABLE_AKG" = "Xon" ]]; then
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
fi
if [[ "X$ENABLE_SERVING" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_SERVING=ON"
fi
echo "${CMAKE_ARGS}"
if [[ "X$INC_BUILD" = "Xoff" ]]; then
cmake ${CMAKE_ARGS} ../..
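With the option handling above, selecting the GPU backend now also turns on MPI, and the new -s flag enables the serving module. A hedged example invocation (flag letters come from the usage text and getopts string in this diff; the job count is arbitrary):
```
# GPU backend build: -e gpu implies ENABLE_CPU/ENABLE_MPI, -j sets parallel jobs, -s enables serving
bash build.sh -e gpu -j 8 -s
```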


@ -36,6 +36,7 @@ elseif (DEFINED ENV{D_LINK_PATH})
find_library(hccl libhccl.so ${GE_LIB_PATH})
find_library(cce libcce.so ${GE_LIB_PATH})
find_library(resource libresource.so ${GE_LIB_PATH})
find_library(error_manager liberror_manager.so ${GE_LIB_PATH})
else()
# Ascend mode
if(DEFINED ENV{ASCEND_CUSTOM_PATH})
@ -54,6 +55,7 @@ else()
find_library(msprof libmsprof.so ${ASCEND_RUNTIME_PATH})
find_library(register libregister.so ${ASCEND_RUNTIME_PATH})
find_library(resource libresource.so ${ASCEND_RUNTIME_PATH})
find_library(error_manager liberror_manager.so ${ASCEND_RUNTIME_PATH})
endif()
# compile libraries from following directories


@ -1,4 +1,4 @@
set(gtest_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(gtest_CXXFLAGS "-D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
set(gtest_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(gtest
VER 1.8.0


@ -0,0 +1,19 @@
set(LIB_ICU_COMMON icuuc)
set(LIB_ICU_DATA icudata)
set(LIB_ICU_I18N icui18n)
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
message("icu4c thirdparty do not support windows currently.")
else()
mindspore_add_pkg(icu4c
VER 67.1
LIBS ${LIB_ICU_COMMON} ${LIB_ICU_DATA} ${LIB_ICU_I18N}
URL https://github.com/unicode-org/icu/archive/release-67-1.tar.gz
MD5 0c2662a2b0bc80b0eb56495205247c8f
CONFIGURE_COMMAND ./icu4c/source/runConfigureICU Linux --enable-rpath --disable-tests --disable-samples --disable-icuio --disable-extras ICU_DATA_FILTER_FILE=${CMAKE_SOURCE_DIR}/third_party/icu4c/filter.json
)
include_directories(${icu4c_INC})
add_library(mindspore::icuuc ALIAS icu4c::${LIB_ICU_COMMON})
add_library(mindspore::icudata ALIAS icu4c::${LIB_ICU_DATA})
add_library(mindspore::icui18n ALIAS icu4c::${LIB_ICU_I18N})
add_definitions(-D ENABLE_ICU4C)
endif()


@ -8,7 +8,7 @@ elseif (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
set(opencv_CXXFLAGS "${opencv_CXXFLAGS} -Wno-attributes -Wno-unknown-pragmas")
set(opencv_CXXFLAGS "${opencv_CXXFLAGS} -Wno-unused-value -Wno-implicit-fallthrough")
else()
set(opencv_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2")
set(opencv_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
set(opencv_CFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2")
set(opencv_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
endif()


@ -1,9 +1,12 @@
set(protobuf_USE_STATIC_LIBS ON)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2")
else()
elseif (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2")
else()
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
endif()
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS})


@ -1,10 +1,10 @@
if (WIN32)
mindspore_add_pkg(sqlite
VER 3.31.1
VER 3.32.2
LIBS sqlite3
URL https://sqlite.org/2020/sqlite-amalgamation-3310100.zip
MD5 2b7bfcdd97dc281903a9aee966213fe4
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.windows.patch001 ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.windows.patch002 ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.windows.patch003
URL https://sqlite.org/2020/sqlite-amalgamation-3320200.zip
MD5 1eccea18d248eb34c7378b2b3f63f1db
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.windows.patch001
CMAKE_OPTION " "
)
@ -18,11 +18,11 @@ else ()
endif()
set(sqlite_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
mindspore_add_pkg(sqlite
VER 3.31.1
VER 3.32.2
LIBS sqlite3
URL https://github.com/sqlite/sqlite/archive/version-3.31.1.tar.gz
MD5 5f4e7b4016c15f4fb5855615279819da
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.patch001 ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.patch002 ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.patch003
URL https://github.com/sqlite/sqlite/archive/version-3.32.2.tar.gz
MD5 ea6d3b3289b4ac216fb06081a01ef101
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.patch001
CONFIGURE_COMMAND ./configure --enable-shared=no --disable-tcl --disable-editline --enable-json1)
endif ()
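Since both the Windows and Linux branches now point at SQLite 3.32.2, a quick way to sanity-check a manually downloaded amalgamation against the MD5 recorded above (assuming curl and md5sum are available):
```
curl -LO https://sqlite.org/2020/sqlite-amalgamation-3320200.zip
md5sum sqlite-amalgamation-3320200.zip   # expect 1eccea18d248eb34c7378b2b3f63f1db
```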


@ -26,6 +26,9 @@ include_directories(${Python3_INCLUDE_DIRS})
include_directories(${CMAKE_SOURCE_DIR}/third_party)
if (ENABLE_CPU)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/mkl_dnn.cmake)
if (ENABLE_MPI)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ompi.cmake)
endif()
endif()
if (ENABLE_GPU)
@ -36,7 +39,6 @@ if (ENABLE_GPU)
if (ENABLE_MPI)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/nccl.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ompi.cmake)
endif()
endif()
@ -52,6 +54,7 @@ elseif(ENABLE_D OR ENABLE_TESTCASES)
endif()
if (ENABLE_MINDDATA)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/icu4c.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/jpeg_turbo.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/libtiff.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/opencv.cmake)


@ -91,7 +91,20 @@ if (ENABLE_MINDDATA)
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
message("icu4c does not support windows system temporarily")
else()
file(GLOB_RECURSE ICU4C_LIB_LIST
${icu4c_LIBPATH}/libicuuc*
${icu4c_LIBPATH}/libicudata*
${icu4c_LIBPATH}/libicui18n*
)
install(
FILES ${ICU4C_LIB_LIST}
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
endif()
endif ()
if (ENABLE_CPU)
@ -109,19 +122,20 @@ if (ENABLE_CPU)
)
endif ()
if (ENABLE_MPI)
install(
TARGETS _ms_mpi
DESTINATION ${INSTALL_BASE_DIR}
COMPONENT mindspore
)
endif ()
if (ENABLE_GPU)
if (ENABLE_MPI)
install(
TARGETS _ms_mpi
DESTINATION ${INSTALL_BASE_DIR}
COMPONENT mindspore
)
install(
TARGETS gpu_collective
DESTINATION ${INSTALL_LIB_DIR}
COMPONENT mindspore
)
endif ()
install(
TARGETS gpu_queue
DESTINATION ${INSTALL_LIB_DIR}
@ -222,6 +236,16 @@ if (ENABLE_GPU)
endif ()
endif ()
if (ENABLE_D AND ENABLE_AKG)
set (AKG_PATH ${CMAKE_SOURCE_DIR}/build/mindspore/akg)
install(
DIRECTORY
${AKG_PATH}/akg
DESTINATION ${INSTALL_PY_DIR}/..
COMPONENT mindspore
)
endif ()
if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset)
install(
DIRECTORY ${CMAKE_SOURCE_DIR}/mindspore/dataset


@ -51,7 +51,7 @@ endif ()
# get git commit id
set(GIT_COMMIT_ID "")
execute_process(
COMMAND ${GIT} log --format='[sha1]:%h,[branch]:%d' -1
COMMAND ${GIT} log --format='[sha1]:%h,[branch]:%d' --abbrev=8 -1
OUTPUT_VARIABLE GIT_COMMIT_ID
WORKING_DIRECTORY ${MS_ROOT_DIR}
ERROR_QUIET)
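The added --abbrev=8 pins the abbreviated commit hash length used in the generated version string. Running the same command by hand illustrates the format (the branch decoration in the comment is illustrative, not taken from this repository):
```
# Prints something like '[sha1]:db5df9db,[branch]: (HEAD -> master)'
git log --format='[sha1]:%h,[branch]:%d' --abbrev=8 -1
```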


@ -1,106 +0,0 @@
# Googlenet Example
## Description
This example is for Googlenet model training and evaluation.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the CIFAR-10 binary version dataset.
> Unzip the CIFAR-10 dataset to any path you want and the folder structure should be as follows:
> ```
> .
> ├── cifar-10-batches-bin # train dataset
> └── cifar-10-verify-bin # infer dataset
> ```
## Running the Example
### Training
```
python train.py --data_path=your_data_path --device_id=6 > out.train.log 2>&1 &
```
The Python command above runs in the background; you can view the results in the file `out.train.log`.
After training, you'll get some checkpoint files under the script folder by default.
You will get loss values like the following:
```
# grep "loss is " out.train.log
epoch: 1 step: 390, loss is 1.4842823
epcoh: 2 step: 390, loss is 1.0897788
...
```
### Evaluation
```
python eval.py --data_path=your_data_path --device_id=6 --checkpoint_path=./train_googlenet_cifar10-125-390.ckpt > out.eval.log 2>&1 &
```
The Python command above runs in the background; you can view the results in the file `out.eval.log`.
You will get the accuracy like the following:
```
# grep "result: " out.eval.log
result: {'acc': 0.934}
```
### Distributed Training
```
sh run_distribute_train.sh rank_table.json your_data_path
```
The shell script above runs distributed training in the background; you can view the results in the file `train_parallel[X]/log`.
You will get loss values like the following:
```
# grep "result: " train_parallel*/log
train_parallel0/log:epoch: 1 step: 48, loss is 1.4302931
train_parallel0/log:epcoh: 2 step: 48, loss is 1.4023874
...
train_parallel1/log:epoch: 1 step: 48, loss is 1.3458025
train_parallel1/log:epcoh: 2 step: 48, loss is 1.3729336
...
...
```
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
## Usage:
### Training
```
usage: train.py [--device_target TARGET][--data_path DATA_PATH]
[--device_id DEVICE_ID]
parameters/options:
--device_target the training backend type, default is Ascend.
--data_path the storage path of dataset
--device_id the device used to train the model.
```
### Evaluation
```
usage: eval.py [--device_target TARGET][--data_path DATA_PATH]
[--device_id DEVICE_ID][--checkpoint_path CKPT_PATH]
parameters/options:
--device_target the evaluation backend type, default is Ascend.
--data_path the storage path of the dataset
--device_id the device used to evaluate the model.
--checkpoint_path the checkpoint file path used to evaluate the model.
```
### Distributed Training
```
Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH]
parameters/options:
MINDSPORE_HCCL_CONFIG_PATH HCCL configuration file path.
DATA_PATH the storage path of dataset.
```

View File

@ -24,9 +24,6 @@ This example provides an efficient way to generate MindRecord. Users only need t
1. Download and prepare the Cora dataset as required.
> [Cora dataset download address](https://github.com/jzaldi/datasets/tree/master/cora)
2. Edit write_cora.sh and modify the parameters
```
--mindrecord_file: output MindRecord file.

View File

@ -15,29 +15,27 @@
"""
User-defined API for MindRecord GNN writer.
"""
import csv
import os
import pickle as pkl
import numpy as np
import scipy.sparse as sp
from mindspore import log as logger
# parse args from command line parameter 'graph_api_args'
# args delimiter is ':'
args = os.environ['graph_api_args'].split(':')
CITESEER_CONTENT_FILE = args[0]
CITESEER_CITES_FILE = args[1]
CITESEER_MINDRECRD_LABEL_FILE = CITESEER_CONTENT_FILE + "_label_mindrecord"
CITESEER_MINDRECRD_ID_MAP_FILE = CITESEER_CONTENT_FILE + "_id_mindrecord"
node_id_map = {}
CITESEER_PATH = args[0]
dataset_str = 'citeseer'
# profile: (num_features, feature_data_types, feature_shapes)
node_profile = (2, ["float32", "int64"], [[-1], [-1]])
node_profile = (2, ["float32", "int32"], [[-1], [-1]])
edge_profile = (0, [], [])
node_ids = []
def _normalize_citeseer_features(features):
features = np.array(features)
row_sum = np.array(features.sum(1))
r_inv = np.power(row_sum * 1.0, -1).flatten()
r_inv[np.isinf(r_inv)] = 0.
@ -46,6 +44,14 @@ def _normalize_citeseer_features(features):
return features
def _parse_index_file(filename):
"""Parse index file."""
index = []
for line in open(filename):
index.append(int(line.strip()))
return index
def yield_nodes(task_id=0):
"""
Generate node data
@ -53,30 +59,47 @@ def yield_nodes(task_id=0):
Yields:
data (dict): data row which is dict.
"""
print("Node task is {}".format(task_id))
label_types = {}
label_size = 0
node_num = 0
with open(CITESEER_CONTENT_FILE) as content_file:
content_reader = csv.reader(content_file, delimiter='\t')
line_count = 0
for row in content_reader:
if not row[-1] in label_types:
label_types[row[-1]] = label_size
label_size += 1
if not row[0] in node_id_map:
node_id_map[row[0]] = node_num
node_num += 1
raw_features = [[int(x) for x in row[1:-1]]]
node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_citeseer_features(raw_features),
'feature_2': [label_types[row[-1]]]}
yield node
line_count += 1
print('Processed {} lines for nodes.'.format(line_count))
# print('label types {}.'.format(label_types))
with open(CITESEER_MINDRECRD_LABEL_FILE, 'w') as f:
for k in label_types:
print(k + ',' + str(label_types[k]), file=f)
logger.info("Node task is {}".format(task_id))
names = ['x', 'y', 'tx', 'ty', 'allx', 'ally']
objects = []
for name in names:
with open("{}/ind.{}.{}".format(CITESEER_PATH, dataset_str, name), 'rb') as f:
objects.append(pkl.load(f, encoding='latin1'))
x, y, tx, ty, allx, ally = tuple(objects)
test_idx_reorder = _parse_index_file(
"{}/ind.{}.test.index".format(CITESEER_PATH, dataset_str))
test_idx_range = np.sort(test_idx_reorder)
tx = _normalize_citeseer_features(tx)
allx = _normalize_citeseer_features(allx)
# Fix citeseer dataset (there are some isolated nodes in the graph)
# Find isolated nodes, add them as zero-vecs into the right position
test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
tx_extended[test_idx_range-min(test_idx_range), :] = tx
tx = tx_extended
ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
ty_extended[test_idx_range-min(test_idx_range), :] = ty
ty = ty_extended
features = sp.vstack((allx, tx)).tolil()
features[test_idx_reorder, :] = features[test_idx_range, :]
features = features.A
labels = np.vstack((ally, ty))
labels[test_idx_reorder, :] = labels[test_idx_range, :]
line_count = 0
for i, label in enumerate(labels):
if not 1 in label.tolist():
continue
node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(),
'feature_2': label.tolist().index(1)}
line_count += 1
node_ids.append(i)
yield node
logger.info('Processed {} lines for nodes.'.format(line_count))
def yield_edges(task_id=0):
@ -86,24 +109,21 @@ def yield_edges(task_id=0):
Yields:
data (dict): data row which is dict.
"""
print("Edge task is {}".format(task_id))
# print(map_string_int)
with open(CITESEER_CITES_FILE) as cites_file:
cites_reader = csv.reader(cites_file, delimiter='\t')
logger.info("Edge task is {}".format(task_id))
with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f:
graph = pkl.load(f, encoding='latin1')
line_count = 0
for row in cites_reader:
if not row[0] in node_id_map:
print('Source node {} does not exist.'.format(row[0]))
continue
if not row[1] in node_id_map:
print('Destination node {} does not exist.'.format(row[1]))
continue
line_count += 1
edge = {'id': line_count,
'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0}
yield edge
with open(CITESEER_MINDRECRD_ID_MAP_FILE, 'w') as f:
for k in node_id_map:
print(k + ',' + str(node_id_map[k]), file=f)
print('Processed {} lines for edges.'.format(line_count))
for i in graph:
for dst_id in graph[i]:
if not i in node_ids:
logger.info('Source node {} does not exist.'.format(i))
continue
if not dst_id in node_ids:
logger.info('Destination node {} does not exist.'.format(
dst_id))
continue
edge = {'id': line_count,
'src_id': i, 'dst_id': dst_id, 'type': 0}
line_count += 1
yield edge
logger.info('Processed {} lines for edges.'.format(line_count))
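For orientation, the two generators above (`yield_nodes` / `yield_edges`) are driven by the writer entry point whose diff appears later in this change. A minimal sketch of that consumption loop is given below; the import paths, the module name `mr_api_citeseer`, and the `transform_node` / `transform_edge` method names are assumptions for illustration, not a verbatim copy of writer.py.
```python
# Sketch: feed the user-defined node/edge generators into a MindRecord
# FileWriter. The GraphMapSchema profile setters and the get_schema property
# mirror the other hunks in this change; treat the exact names as assumptions.
from mindspore.mindrecord import FileWriter
from graph_map_schema import GraphMapSchema   # assumed local module
import mr_api_citeseer as mr_api              # hypothetical module name

graph_map_schema = GraphMapSchema()
num, types, shapes = mr_api.node_profile
graph_map_schema.set_node_feature_profile(num, types, shapes)
num, types, shapes = mr_api.edge_profile
graph_map_schema.set_edge_feature_profile(num, types, shapes)

writer = FileWriter("citeseer.mindrecord", 1)
writer.add_schema(graph_map_schema.get_schema, "graph schema")  # property access

for node in mr_api.yield_nodes():
    writer.write_raw_data([graph_map_schema.transform_node(node)])  # assumed name
for edge in mr_api.yield_edges():
    writer.write_raw_data([graph_map_schema.transform_edge(edge)])  # assumed name
writer.commit()
```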

View File

@ -15,29 +15,24 @@
"""
User-defined API for MindRecord GNN writer.
"""
import csv
import os
import pickle as pkl
import numpy as np
import scipy.sparse as sp
# parse args from command line parameter 'graph_api_args'
# args delimiter is ':'
args = os.environ['graph_api_args'].split(':')
CORA_CONTENT_FILE = args[0]
CORA_CITES_FILE = args[1]
CORA_MINDRECRD_LABEL_FILE = CORA_CONTENT_FILE + "_label_mindrecord"
CORA_CONTENT_ID_MAP_FILE = CORA_CONTENT_FILE + "_id_mindrecord"
node_id_map = {}
CORA_PATH = args[0]
dataset_str = 'cora'
# profile: (num_features, feature_data_types, feature_shapes)
node_profile = (2, ["float32", "int64"], [[-1], [-1]])
node_profile = (2, ["float32", "int32"], [[-1], [-1]])
edge_profile = (0, [], [])
def _normalize_cora_features(features):
features = np.array(features)
row_sum = np.array(features.sum(1))
r_inv = np.power(row_sum * 1.0, -1).flatten()
r_inv[np.isinf(r_inv)] = 0.
@ -46,6 +41,14 @@ def _normalize_cora_features(features):
return features
def _parse_index_file(filename):
"""Parse index file."""
index = []
for line in open(filename):
index.append(int(line.strip()))
return index
def yield_nodes(task_id=0):
"""
Generate node data
@ -54,32 +57,32 @@ def yield_nodes(task_id=0):
data (dict): data row which is dict.
"""
print("Node task is {}".format(task_id))
label_types = {}
label_size = 0
node_num = 0
with open(CORA_CONTENT_FILE) as content_file:
content_reader = csv.reader(content_file, delimiter=',')
line_count = 0
for row in content_reader:
if line_count == 0:
line_count += 1
continue
if not row[0] in node_id_map:
node_id_map[row[0]] = node_num
node_num += 1
if not row[-1] in label_types:
label_types[row[-1]] = label_size
label_size += 1
raw_features = [[int(x) for x in row[1:-1]]]
node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_cora_features(raw_features),
'feature_2': [label_types[row[-1]]]}
yield node
line_count += 1
names = ['tx', 'ty', 'allx', 'ally']
objects = []
for name in names:
with open("{}/ind.{}.{}".format(CORA_PATH, dataset_str, name), 'rb') as f:
objects.append(pkl.load(f, encoding='latin1'))
tx, ty, allx, ally = tuple(objects)
test_idx_reorder = _parse_index_file(
"{}/ind.{}.test.index".format(CORA_PATH, dataset_str))
test_idx_range = np.sort(test_idx_reorder)
features = sp.vstack((allx, tx)).tolil()
features[test_idx_reorder, :] = features[test_idx_range, :]
features = _normalize_cora_features(features)
features = features.A
labels = np.vstack((ally, ty))
labels[test_idx_reorder, :] = labels[test_idx_range, :]
line_count = 0
for i, label in enumerate(labels):
node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(),
'feature_2': label.tolist().index(1)}
line_count += 1
yield node
print('Processed {} lines for nodes.'.format(line_count))
print('label types {}.'.format(label_types))
with open(CORA_MINDRECRD_LABEL_FILE, 'w') as f:
for k in label_types:
print(k + ',' + str(label_types[k]), file=f)
def yield_edges(task_id=0):
@ -90,24 +93,13 @@ def yield_edges(task_id=0):
data (dict): data row which is dict.
"""
print("Edge task is {}".format(task_id))
with open(CORA_CITES_FILE) as cites_file:
cites_reader = csv.reader(cites_file, delimiter=',')
with open("{}/ind.{}.graph".format(CORA_PATH, dataset_str), 'rb') as f:
graph = pkl.load(f, encoding='latin1')
line_count = 0
for row in cites_reader:
if line_count == 0:
for i in graph:
for dst_id in graph[i]:
edge = {'id': line_count,
'src_id': i, 'dst_id': dst_id, 'type': 0}
line_count += 1
continue
if not row[0] in node_id_map:
print('Source node {} does not exist.'.format(row[0]))
continue
if not row[1] in node_id_map:
print('Destination node {} does not exist.'.format(row[1]))
continue
edge = {'id': line_count,
'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0}
yield edge
line_count += 1
yield edge
print('Processed {} lines for edges.'.format(line_count))
with open(CORA_CONTENT_ID_MAP_FILE, 'w') as f:
for k in node_id_map:
print(k + ',' + str(node_id_map[k]), file=f)

View File

@ -16,6 +16,7 @@
Graph data convert tool for MindRecord.
"""
import numpy as np
from mindspore import log as logger
__all__ = ['GraphMapSchema']
@ -41,6 +42,7 @@ class GraphMapSchema:
"edge_feature_index": {"type": "int32", "shape": [-1]}
}
@property
def get_schema(self):
"""
Get schema
@ -52,6 +54,7 @@ class GraphMapSchema:
Set node features profile
"""
if num_features != len(features_data_type) or num_features != len(features_shape):
logger.info("Node feature profile is not match.")
raise ValueError("Node feature profile is not match.")
self.num_node_features = num_features
@ -66,6 +69,7 @@ class GraphMapSchema:
Set edge features profile
"""
if num_features != len(features_data_type) or num_features != len(features_shape):
logger.info("Edge feature profile is not match.")
raise ValueError("Edge feature profile is not match.")
self.num_edge_features = num_features
@ -83,6 +87,10 @@ class GraphMapSchema:
Returns:
graph data with union schema
"""
if node is None:
logger.info("node cannot be None.")
raise ValueError("node cannot be None.")
node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"],
"node_feature_index": []}
for i in range(self.num_node_features):
@ -117,6 +125,10 @@ class GraphMapSchema:
Returns:
graph data with union schema
"""
if edge is None:
logger.info("edge cannot be None.")
raise ValueError("edge cannot be None.")
edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e',
"type": edge["type"], "edge_feature_index": []}

View File

@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for MindRecord GNN writer.
"""
social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
[348, 336], [348, 337], [348, 338], [348, 340], [348, 341],
[348, 342], [348, 343], [348, 344], [348, 345], [348, 346],
[348, 347], [347, 351], [347, 327], [347, 329], [347, 331],
[347, 335], [347, 341], [347, 345], [347, 346], [346, 335],
[346, 340], [346, 339], [346, 349], [346, 353], [346, 354],
[346, 341], [346, 345], [345, 335], [345, 336], [345, 341],
[344, 338], [344, 342], [343, 332], [343, 338], [343, 342],
[342, 332], [340, 349], [334, 349], [333, 349], [330, 349],
[328, 349], [359, 349], [358, 352], [358, 349], [358, 354],
[358, 356], [357, 350], [357, 354], [357, 356], [356, 350],
[355, 352], [353, 350], [352, 349], [351, 349], [350, 349]]
# profile: (num_features, feature_data_types, feature_shapes)
node_profile = (0, [], [])
edge_profile = (0, [], [])
def yield_nodes(task_id=0):
"""
Generate node data
Yields:
data (dict): data row which is dict.
"""
print("Node task is {}".format(task_id))
node_list = []
for edge in social_data:
src, dst = edge
if src not in node_list:
node_list.append(src)
if dst not in node_list:
node_list.append(dst)
node_list.sort()
print(node_list)
for node_id in node_list:
node = {'id': node_id, 'type': 1}
yield node
def yield_edges(task_id=0):
"""
Generate edge data
Yields:
data (dict): data row which is dict.
"""
print("Edge task is {}".format(task_id))
line_count = 0
for undirected_edge in social_data:
line_count += 1
edge = {
'id': line_count,
'src_id': undirected_edge[0],
'dst_id': undirected_edge[1],
'type': 1}
yield edge
line_count += 1
edge = {
'id': line_count,
'src_id': undirected_edge[1],
'dst_id': undirected_edge[0],
'type': 1}
yield edge

View File

@ -9,4 +9,4 @@ python writer.py --mindrecord_script citeseer \
--mindrecord_partitions 1 \
--mindrecord_header_size_by_bit 18 \
--mindrecord_page_size_by_bit 20 \
--graph_api_args "$SRC_PATH/citeseer.content:$SRC_PATH/citeseer.cites"
--graph_api_args "$SRC_PATH"

View File

@ -9,4 +9,4 @@ python writer.py --mindrecord_script cora \
--mindrecord_partitions 1 \
--mindrecord_header_size_by_bit 18 \
--mindrecord_page_size_by_bit 20 \
--graph_api_args "$SRC_PATH/cora_content.csv:$SRC_PATH/cora_cites.csv"
--graph_api_args "$SRC_PATH"

View File

@ -0,0 +1,10 @@
#!/bin/bash
MINDRECORD_PATH=/tmp/sns
rm -f $MINDRECORD_PATH/*
python writer.py --mindrecord_script sns \
--mindrecord_file "$MINDRECORD_PATH/sns" \
--mindrecord_partitions 1 \
--mindrecord_header_size_by_bit 14 \
--mindrecord_page_size_by_bit 15

View File

@ -164,7 +164,7 @@ if __name__ == "__main__":
num_features, feature_data_types, feature_shapes = mr_api.edge_profile
graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes)
graph_schema = graph_map_schema.get_schema()
graph_schema = graph_map_schema.get_schema
# init writer
writer = init_writer(graph_schema)
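The dropped parentheses above follow from the `@property` decorator added to `get_schema` in graph_map_schema.py earlier in this diff: the getter is now accessed as an attribute. A tiny stand-in illustration (not the real class):
```python
# Stand-in illustration of the get_schema()-to-get_schema change; the real
# GraphMapSchema lives in graph_map_schema.py and differs in content.
class SchemaHolder:
    def __init__(self):
        self._schema = {"first_id": {"type": "int64"}}

    @property
    def get_schema(self):
        return self._schema

holder = SchemaHolder()
schema = holder.get_schema      # attribute-style access, no call
# holder.get_schema()           # would raise TypeError: 'dict' object is not callable
print(schema)
```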

View File

@ -0,0 +1,82 @@
# Guideline to Convert Training Data CLUERNER2020 to MindRecord For Bert Fine Tuning
<!-- TOC -->
- [What does the example do](#what-does-the-example-do)
- [How to use the example to process CLUERNER2020](#how-to-use-the-example-to-process-cluerner2020)
- [Download CLUERNER2020 and unzip](#download-cluerner2020-and-unzip)
- [Generate MindRecord](#generate-mindrecord)
- [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord)
<!-- /TOC -->
## What does the example do
This example is based on the [CLUERNER2020](https://www.cluebenchmarks.com/introduce.html) training data. It generates MindRecord files, which are finally used for the Bert fine-tuning process.
1. run.sh: entry script that generates the MindRecord files.
2. run_read.sh: entry script that creates a MindDataset from the MindRecord files.
- create_dataset.py: uses MindDataset to read the MindRecord files and generate the dataset.
## How to use the example to process CLUERNER2020
Download CLUERNER2020, convert it to MindRecord, and use MindDataset to read the MindRecord files.
### Download CLUERNER2020 and unzip
1. Download the training data zip.
> [CLUERNER2020 dataset download address](https://www.cluebenchmarks.com/introduce.html) **-> 任务介绍 (Task Introduction) -> CLUENER 细粒度命名实体识别 (Fine-Grained Named Entity Recognition) -> cluener下载链接 (cluener download link)**
2. Unzip the training data to dir example/nlp_to_mindrecord/CLUERNER2020/cluener_public.
```
unzip -d {your-mindspore}/example/nlp_to_mindrecord/CLUERNER2020/data/cluener_public cluener_public.zip
```
### Generate MindRecord
1. Run the run.sh script.
```bash
bash run.sh
```
2. Output like this:
```
...
[INFO] ME(17603:139620983514944,MainProcess):2020-04-28-16:56:12.498.235 [mindspore/mindrecord/filewriter.py:313] The list of mindrecord files created are: ['data/train.mindrecord'], and the list of index files are: ['data/train.mindrecord.db']
...
[INFO] ME(17603,python):2020-04-28-16:56:13.400.175 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
[INFO] ME(17603,python):2020-04-28-16:56:13.400.863 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
[INFO] ME(17603,python):2020-04-28-16:56:13.401.534 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
[INFO] ME(17603,python):2020-04-28-16:56:13.402.179 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
[INFO] ME(17603,python):2020-04-28-16:56:13.402.702 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
...
[INFO] ME(17603:139620983514944,MainProcess):2020-04-28-16:56:13.431.208 [mindspore/mindrecord/filewriter.py:313] The list of mindrecord files created are: ['data/dev.mindrecord'], and the list of index files are: ['data/dev.mindrecord.db']
```
3. Generate files like this:
```bash
$ ls output/
dev.mindrecord dev.mindrecord.db README.md train.mindrecord train.mindrecord.db
```
### Create MindDataset By MindRecord
1. Run the run_read.sh script.
```bash
bash run_read.sh
```
2. Output like this:
```
...
example 1340: input_ids: [ 101 3173 1290 4852 7676 3949 122 3299 123 126 3189 4510 8020 6381 5442 7357 2590 3636 8021 7676 3949 4294 1166 6121 3124 1277 6121 3124 7270 2135 3295 5789 3326 123 126 3189 1355 6134 1093 1325 3173 2399 6590 6791 8024 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1340: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1340: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1340: label_ids: [ 0 18 19 20 2 4 0 0 0 0 0 0 0 34 36 26 27 28 0 34 35 35 35 35 35 35 35 35 35 36 26 27 28 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1341: input_ids: [ 101 1728 711 4293 3868 1168 2190 2150 3791 934 3633 3428 4638 6237 7025 8024 3297 1400 5310 3362 6206 5023 5401 1744 3297 7770 3791 7368 976 1139 1104 2137 511 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1341: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1341: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1341: label_ids: [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 18 19 19 19 19 20 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
...
```

View File

@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create MindDataset by MindRecord"""
import mindspore.dataset as ds
def create_dataset(data_file):
"""create MindDataset"""
num_readers = 4
data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True)
index = 0
for item in data_set.create_dict_iterator():
# print("example {}: {}".format(index, item))
print("example {}: input_ids: {}".format(index, item['input_ids']))
print("example {}: input_mask: {}".format(index, item['input_mask']))
print("example {}: segment_ids: {}".format(index, item['segment_ids']))
print("example {}: label_ids: {}".format(index, item['label_ids']))
index += 1
if index % 1000 == 0:
print("read rows: {}".format(index))
print("total rows: {}".format(index))
if __name__ == '__main__':
create_dataset('output/train.mindrecord')
create_dataset('output/dev.mindrecord')

View File

@ -0,0 +1 @@
cluener_public

View File

@ -0,0 +1 @@
## The input dataset

View File

@ -0,0 +1 @@
## output dir

View File

@ -0,0 +1,40 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
rm -f output/train.mindrecord*
rm -f output/dev.mindrecord*
if [ ! -d "../../../third_party/to_mindrecord/CLUERNER2020" ]; then
echo "The patch base dir ../../../third_party/to_mindrecord/CLUERNER2020 is not exist."
exit 1
fi
if [ ! -f "../../../third_party/patch/to_mindrecord/CLUERNER2020/data_processor_seq.patch" ]; then
echo "The patch file ../../../third_party/patch/to_mindrecord/CLUERNER2020/data_processor_seq.patch is not exist."
exit 1
fi
# patch for data_processor_seq.py
patch -p0 -d ../../../third_party/to_mindrecord/CLUERNER2020/ -o data_processor_seq_patched.py < ../../../third_party/patch/to_mindrecord/CLUERNER2020/data_processor_seq.patch
if [ $? -ne 0 ]; then
echo "Patch ../../../third_party/to_mindrecord/CLUERNER2020/data_processor_seq.py failed"
exit 1
fi
# use patched script
python ../../../third_party/to_mindrecord/CLUERNER2020/data_processor_seq_patched.py \
--vocab_file=../../../third_party/to_mindrecord/CLUERNER2020/vocab.txt \
--label2id_file=../../../third_party/to_mindrecord/CLUERNER2020/label2id.json

View File

@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
python create_dataset.py

View File

@ -0,0 +1,173 @@
# Guideline to Convert Training Data enwiki to MindRecord For Bert Pre Training
<!-- TOC -->
- [What does the example do](#what-does-the-example-do)
- [How to use the example to process enwiki](#how-to-use-the-example-to-process-enwiki)
- [Download enwiki training data](#download-enwiki-training-data)
- [Process the enwiki](#process-the-enwiki)
- [Generate MindRecord](#generate-mindrecord)
- [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord)
<!-- /TOC -->
## What does the example do
This example is based on the [enwiki](https://dumps.wikimedia.org/enwiki) training data. It generates MindRecord files, which are finally used for Bert network training.
1. run.sh: entry script that generates the MindRecord files.
2. run_read.sh: entry script that creates a MindDataset from the MindRecord files.
- create_dataset.py: uses MindDataset to read the MindRecord files and generate the dataset.
## How to use the example to process enwiki
Download the enwiki data, process it, convert it to MindRecord, and use MindDataset to read the MindRecord files.
### Download enwiki training data
> [enwiki dataset download address](https://dumps.wikimedia.org/enwiki) **-> 20200501 -> enwiki-20200501-pages-articles-multistream.xml.bz2**
### Process the enwiki
1. Please follow the steps in [process enwiki](https://github.com/mlperf/training/tree/master/language_model/tensorflow/bert)
- All rights to this processing step belong to the linked website.
### Generate MindRecord
1. Run the run.sh script.
```
bash run.sh input_dir output_dir vocab_file
```
- input_dir: the directory that contains files like 'part-00251-of-00500'.
- output_dir: the directory that will store the output MindRecord files.
- vocab_file: the vocabulary file, which you can download from another open-source project.
2. The output looks like this:
```
...
Begin preprocess Wed Jun 10 09:21:23 CST 2020
Begin preprocess input file: /mnt/data/results/part-00000-of-00500
Begin output file: part-00000-of-00500.mindrecord
Total task: 510, processing: 1
Begin preprocess input file: /mnt/data/results/part-00001-of-00500
Begin output file: part-00001-of-00500.mindrecord
Total task: 510, processing: 2
Begin preprocess input file: /mnt/data/results/part-00002-of-00500
Begin output file: part-00002-of-00500.mindrecord
Total task: 510, processing: 3
Begin preprocess input file: /mnt/data/results/part-00003-of-00500
Begin output file: part-00003-of-00500.mindrecord
Total task: 510, processing: 4
Begin preprocess input file: /mnt/data/results/part-00004-of-00500
Begin output file: part-00004-of-00500.mindrecord
Total task: 510, processing: 4
...
```
3. Generate files like this:
```bash
$ ls {your_output_dir}/
part-00000-of-00500.mindrecord part-00000-of-00500.mindrecord.db part-00001-of-00500.mindrecord part-00001-of-00500.mindrecord.db part-00002-of-00500.mindrecord part-00002-of-00500.mindrecord.db ...
```
### Create MindDataset By MindRecord
1. Run the run_read.sh script.
```bash
bash run_read.sh input_dir
```
- input_dir: the directory which contains mindrecord files.
2. The output looks like this:
```
...
example 633: input_ids: [ 101 2043 19781 4305 2140 4520 2041 1010 103 2034 2455 2002
7879 2003 1996 2455 1997 103 26378 4160 1012 102 7291 2001
1996 103 1011 2343 1997 6327 1010 3423 1998 103 4262 2005
1996 2118 1997 2329 3996 103 102 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0]
example 633: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 633: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 633: masked_lm_positions: [ 8 17 20 25 33 41 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0]
example 633: masked_lm_ids: [ 1996 16137 1012 3580 2451 1012 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0]
example 633: masked_lm_weights: [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0.]
example 633: next_sentence_labels: [1]
...
```

View File

@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create MindDataset by MindRecord"""
import argparse
import mindspore.dataset as ds
def create_dataset(data_file):
"""create MindDataset"""
num_readers = 4
data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True)
index = 0
for item in data_set.create_dict_iterator():
# print("example {}: {}".format(index, item))
print("example {}: input_ids: {}".format(index, item['input_ids']))
print("example {}: input_mask: {}".format(index, item['input_mask']))
print("example {}: segment_ids: {}".format(index, item['segment_ids']))
print("example {}: masked_lm_positions: {}".format(index, item['masked_lm_positions']))
print("example {}: masked_lm_ids: {}".format(index, item['masked_lm_ids']))
print("example {}: masked_lm_weights: {}".format(index, item['masked_lm_weights']))
print("example {}: next_sentence_labels: {}".format(index, item['next_sentence_labels']))
index += 1
if index % 1000 == 0:
print("read rows: {}".format(index))
print("total rows: {}".format(index))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_file", nargs='+', type=str, help='Input mindreord file')
args = parser.parse_args()
create_dataset(args.input_file)

View File

@ -0,0 +1,133 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 3 ]; then
echo "Usage: $0 input_dir output_dir vocab_file"
exit 1
fi
if [ ! -d $1 ]; then
echo "The input dir: $1 is not exist."
exit 1
fi
if [ ! -d $2 ]; then
echo "The output dir: $2 is not exist."
exit 1
fi
rm -fr $2/*.mindrecord*
if [ ! -f $3 ]; then
echo "The vocab file: $3 is not exist."
exit 1
fi
data_dir=$1
output_dir=$2
vocab_file=$3
file_list=()
output_filename=()
file_index=0
function getdir() {
elements=`ls $1`
for element in ${elements[*]};
do
dir_or_file=$1"/"$element
if [ -d $dir_or_file ];
then
getdir $dir_or_file
else
file_list[$file_index]=$dir_or_file
echo "${dir_or_file}" | tr '/' '\n' > dir_file_list.txt # dir dir file to mapfile
mapfile parent_dir < dir_file_list.txt
rm dir_file_list.txt >/dev/null 2>&1
tmp_output_filename=${parent_dir[${#parent_dir[@]}-1]}".mindrecord"
output_filename[$file_index]=`echo ${tmp_output_filename} | sed 's/ //g'`
file_index=`expr $file_index + 1`
fi
done
}
getdir "${data_dir}"
# echo "The input files: "${file_list[@]}
# echo "The output files: "${output_filename[@]}
if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then
echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki is not exist."
exit 1
fi
if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then
echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch is not exist."
exit 1
fi
# patch for create_pretraining_data.py
patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
if [ $? -ne 0 ]; then
echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
exit 1
fi
# get the cpu core count
num_cpu_core=`cat /proc/cpuinfo | grep "processor" | wc -l`
avaiable_core_size=`expr $num_cpu_core / 3 \* 2`
echo "Begin preprocess `date`"
# using patched script to generate mindrecord
file_list_len=`expr ${#file_list[*]} - 1`
for index in $(seq 0 $file_list_len); do
echo "Begin preprocess input file: ${file_list[$index]}"
echo "Begin output file: ${output_filename[$index]}"
python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
--input_file=${file_list[$index]} \
--output_file=${output_dir}/${output_filename[$index]} \
--partition_number=1 \
--vocab_file=${vocab_file} \
--do_lower_case=True \
--max_seq_length=512 \
--max_predictions_per_seq=76 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=10 >/tmp/${output_filename[$index]}.log 2>&1 &
process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
echo "Total task: ${#file_list[*]}, processing: ${process_count}"
if [ $process_count -ge $avaiable_core_size ]; then
while [ 1 ]; do
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
if [ $process_count -gt $process_num ]; then
process_count=$process_num
break;
fi
sleep 2
done
fi
done
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
while [ 1 ]; do
if [ $process_num -eq 0 ]; then
break;
fi
echo "There are still ${process_num} preprocess running ..."
sleep 2
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
done
echo "Preprocess all the data success."
echo "End preprocess `date`"

View File

@ -0,0 +1,44 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 1 ]; then
echo "Usage: $0 input_dir"
exit 1
fi
if [ ! -d $1 ]; then
echo "The input dir: $1 is not exist."
exit 1
fi
file_list=()
file_index=0
# get all the mindrecord file from output dir
function getdir() {
elements=`ls $1/part-*.mindrecord`
for element in ${elements[*]};
do
file_list[$file_index]=$element
file_index=`expr $file_index + 1`
done
}
getdir $1
echo "Get all the mindrecord files: "${file_list[*]}
# create dataset for train
python create_dataset.py --input_file ${file_list[*]}

View File

@ -0,0 +1,113 @@
# Guideline to Convert Training Data zhwiki to MindRecord For Bert Pre Training
<!-- TOC -->
- [What does the example do](#what-does-the-example-do)
- [Run simple test](#run-simple-test)
- [How to use the example to process zhwiki](#how-to-use-the-example-to-process-zhwiki)
- [Download zhwiki training data](#download-zhwiki-training-data)
- [Extract the zhwiki](#extract-the-zhwiki)
- [Generate MindRecord](#generate-mindrecord)
- [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord)
<!-- /TOC -->
## What does the example do
This example is based on the [zhwiki](https://dumps.wikimedia.org/zhwiki) training data. It generates MindRecord files, which are finally used for Bert network training.
1. run.sh: entry script that generates the MindRecord files.
2. run_read.sh: entry script that creates a MindDataset from the MindRecord files.
- create_dataset.py: uses MindDataset to read the MindRecord files and generate the dataset.
## Run simple test
Follow the steps:
```bash
bash run_simple.sh # generate output/simple.mindrecord* by ../../../third_party/to_mindrecord/zhwiki/sample_text.txt
bash run_read_simple.sh # use MindDataset to read output/simple.mindrecord*
```
## How to use the example to process zhwiki
Download the zhwiki data, extract it, convert it to MindRecord, and use MindDataset to read the MindRecord files.
### Download zhwiki training data
> [zhwiki dataset download address](https://dumps.wikimedia.org/zhwiki) **-> 20200401 -> zhwiki-20200401-pages-articles-multistream.xml.bz2**
- Put zhwiki-20200401-pages-articles-multistream.xml.bz2 in the {your-mindspore}/example/nlp_to_mindrecord/zhwiki/data directory.
### Extract the zhwiki
1. Download [wikiextractor](https://github.com/attardi/wikiextractor) script to {your-mindspore}/example/nlp_to_mindrecord/zhwiki/data directory.
```
$ ls data/
README.md wikiextractor zhwiki-20200401-pages-articles-multistream.xml.bz2
```
2. Extract the zhwiki.
```python
python data/wikiextractor/WikiExtractor.py data/zhwiki-20200401-pages-articles-multistream.xml.bz2 --processes 4 --templates data/template --bytes 8M --min_text_length 0 --filter_disambig_pages --output data/extract
```
3. The extraction generates files like this:
```
$ ls data/extract
AA AB
```
### Generate MindRecord
1. Run the run.sh script.
```
bash run.sh
```
> Caution: This process may be slow; please wait patiently. If you do not have a machine with enough memory and CPU cores, it is recommended that you modify the script to generate the MindRecord files step by step.
2. The output looks like this:
```
patching file create_pretraining_data_patched.py (read from create_pretraining_data.py)
Begin preprocess input file: ./data/extract/AA/wiki_00
Begin output file: AAwiki_00.mindrecord
Total task: 5, processing: 1
Begin preprocess input file: ./data/extract/AA/wiki_01
Begin output file: AAwiki_01.mindrecord
Total task: 5, processing: 2
Begin preprocess input file: ./data/extract/AA/wiki_02
Begin output file: AAwiki_02.mindrecord
Total task: 5, processing: 3
Begin preprocess input file: ./data/extract/AB/wiki_02
Begin output file: ABwiki_02.mindrecord
Total task: 5, processing: 4
...
```
3. Generate files like this:
```bash
$ ls output/
AAwiki_00.mindrecord AAwiki_00.mindrecord.db AAwiki_01.mindrecord AAwiki_01.mindrecord.db AAwiki_02.mindrecord AAwiki_02.mindrecord.db ... ABwiki_00.mindrecord ABwiki_00.mindrecord.db ...
```
### Create MindDataset By MindRecord
1. Run the run_read.sh script.
```bash
bash run_read.sh
```
2. The output looks like this:
```
...
example 74: input_ids: [ 101 8168 118 12847 8783 9977 15908 117 8256 9245 11643 8168 8847 8588 11575 8154 8228 143 8384 8376 9197 10241 103 10564 11421 8199 12268 112 161 8228 11541 9586 8436 8174 8363 9864 9702 103 103 119 103 9947 10564 103 8436 8806 11479 103 8912 119 103 103 103 12209 8303 103 8757 8824 117 8256 103 8619 8168 11541 102 11684 8196 103 8228 8847 11523 117 9059 9064 12410 8358 8181 10764 117 11167 11706 9920 148 8332 11390 8936 8205 10951 11997 103 8154 117 103 8670 10467 112 161 10951 13139 12413 117 10288 143 10425 8205 152 10795 8472 8196 103 161 12126 9172 13129 12106 8217 8174 12244 8205 143 103 8461 8277 10628 160 8221 119 102]
example 74: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
example 74: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
example 74: masked_lm_positions: [ 6 22 37 38 40 43 47 50 51 52 55 60 67 76 89 92 98 109 120 0]
example 74: masked_lm_ids: [ 8118 8165 8329 8890 8554 8458 119 8850 8565 10392 8174 11467 10291 8181 8549 12718 13139 112 158 0]
example 74: masked_lm_weights: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]
example 74: next_sentence_labels: [0]
...
```

View File

@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create MindDataset by MindRecord"""
import argparse
import mindspore.dataset as ds
def create_dataset(data_file):
"""create MindDataset"""
num_readers = 4
data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True)
index = 0
for item in data_set.create_dict_iterator():
# print("example {}: {}".format(index, item))
print("example {}: input_ids: {}".format(index, item['input_ids']))
print("example {}: input_mask: {}".format(index, item['input_mask']))
print("example {}: segment_ids: {}".format(index, item['segment_ids']))
print("example {}: masked_lm_positions: {}".format(index, item['masked_lm_positions']))
print("example {}: masked_lm_ids: {}".format(index, item['masked_lm_ids']))
print("example {}: masked_lm_weights: {}".format(index, item['masked_lm_weights']))
print("example {}: next_sentence_labels: {}".format(index, item['next_sentence_labels']))
index += 1
if index % 1000 == 0:
print("read rows: {}".format(index))
print("total rows: {}".format(index))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_file", nargs='+', type=str, help='Input mindreord file')
args = parser.parse_args()
create_dataset(args.input_file)

View File

@ -0,0 +1,3 @@
wikiextractor/
zhwiki-20200401-pages-articles-multistream.xml.bz2
extract/

View File

@ -0,0 +1 @@
## The input dataset

View File

@ -0,0 +1 @@
## Output the mindrecord

View File

@ -0,0 +1,112 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
rm -f output/*.mindrecord*
data_dir="./data/extract"
file_list=()
output_filename=()
file_index=0
function getdir() {
elements=`ls $1`
for element in ${elements[*]};
do
dir_or_file=$1"/"$element
if [ -d $dir_or_file ];
then
getdir $dir_or_file
else
file_list[$file_index]=$dir_or_file
echo "${dir_or_file}" | tr '/' '\n' > dir_file_list.txt # dir dir file to mapfile
mapfile parent_dir < dir_file_list.txt
rm dir_file_list.txt >/dev/null 2>&1
tmp_output_filename=${parent_dir[${#parent_dir[@]}-2]}${parent_dir[${#parent_dir[@]}-1]}".mindrecord"
output_filename[$file_index]=`echo ${tmp_output_filename} | sed 's/ //g'`
file_index=`expr $file_index + 1`
fi
done
}
getdir "${data_dir}"
# echo "The input files: "${file_list[@]}
# echo "The output files: "${output_filename[@]}
if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then
echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki is not exist."
exit 1
fi
if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then
echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch is not exist."
exit 1
fi
# patch for create_pretraining_data.py
patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
if [ $? -ne 0 ]; then
echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
exit 1
fi
# get the cpu core count
num_cpu_core=`cat /proc/cpuinfo | grep "processor" | wc -l`
avaiable_core_size=`expr $num_cpu_core / 3 \* 2`
echo "Begin preprocess `date`"
# using patched script to generate mindrecord
file_list_len=`expr ${#file_list[*]} - 1`
for index in $(seq 0 $file_list_len); do
echo "Begin preprocess input file: ${file_list[$index]}"
echo "Begin output file: ${output_filename[$index]}"
python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
--input_file=${file_list[$index]} \
--output_file=output/${output_filename[$index]} \
--partition_number=1 \
--vocab_file=../../../third_party/to_mindrecord/zhwiki/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=10 >/tmp/${output_filename[$index]}.log 2>&1 & # user defined
process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
echo "Total task: ${#file_list[*]}, processing: ${process_count}"
if [ $process_count -ge $avaiable_core_size ]; then
while [ 1 ]; do
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
if [ $process_count -gt $process_num ]; then
process_count=$process_num
break;
fi
sleep 2
done
fi
done
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
while [ 1 ]; do
if [ $process_num -eq 0 ]; then
break;
fi
echo "There are still ${process_num} preprocess running ..."
sleep 2
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
done
echo "Preprocess all the data success."
echo "End preprocess `date`"

View File

@ -0,0 +1,34 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
file_list=()
file_index=0
# get all the mindrecord file from output dir
function getdir() {
elements=`ls $1/[A-Z]*.mindrecord`
for element in ${elements[*]};
do
file_list[$file_index]=$element
file_index=`expr $file_index + 1`
done
}
getdir "./output"
echo "Get all the mindrecord files: "${file_list[*]}
# create dataset for train
python create_dataset.py --input_file ${file_list[*]}

View File

@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# create dataset for train
python create_dataset.py --input_file=output/simple.mindrecord0

View File

@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
rm -f output/simple.mindrecord*
if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then
echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki is not exist."
exit 1
fi
if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then
echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch is not exist."
exit 1
fi
# patch for create_pretraining_data.py
patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
if [ $? -ne 0 ]; then
echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
exit 1
fi
# using patched script to generate mindrecord
python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
--input_file=../../../third_party/to_mindrecord/zhwiki/sample_text.txt \
--output_file=output/simple.mindrecord \
--partition_number=4 \
--vocab_file=../../../third_party/to_mindrecord/zhwiki/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=10 # user defined

View File

@ -15,6 +15,7 @@
"""train_imagenet."""
import os
import argparse
import numpy as np
from dataset import create_dataset
from lr_generator import get_lr
from config import config
@ -45,6 +46,7 @@ if __name__ == '__main__':
target = args_opt.device_target
ckpt_save_dir = config.save_checkpoint_path
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
np.random.seed(1)
if not args_opt.do_eval and args_opt.run_distribute:
if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))

View File

@ -15,6 +15,7 @@
"""train_imagenet."""
import os
import argparse
import numpy as np
from dataset import create_dataset
from lr_generator import get_lr
from config import config
@ -48,6 +49,7 @@ if __name__ == '__main__':
target = args_opt.device_target
ckpt_save_dir = config.save_checkpoint_path
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
np.random.seed(1)
if not args_opt.do_eval and args_opt.run_distribute:
if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
@ -77,12 +79,12 @@ if __name__ == '__main__':
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
cell.weight.default_input.shape(),
cell.weight.default_input.dtype()).to_tensor()
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
if isinstance(cell, nn.Dense):
cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.default_input.shape(),
cell.weight.default_input.dtype()).to_tensor()
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
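The `shape()`/`dtype()` call sites above become plain attribute accesses because, in this version, a parameter's `shape` and `dtype` are exposed as properties rather than methods. A small hedged sketch of the resulting re-initialization idiom (the single Conv2d cell here is only a stand-in for the real network):
```python
# Hedged sketch of the property-based re-initialization idiom; it assumes the
# MindSpore APIs visible in the hunk (default_input, initializer().to_tensor()).
import mindspore.nn as nn
from mindspore.common import initializer as weight_init

net = nn.Conv2d(3, 16, 3)                      # stand-in for the real model
for _, cell in net.cells_and_names():
    if isinstance(cell, nn.Conv2d):
        cell.weight.default_input = weight_init.initializer(
            weight_init.XavierUniform(),
            cell.weight.default_input.shape,   # property, not a method call
            cell.weight.default_input.dtype).to_tensor()
```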

View File

@ -15,6 +15,7 @@
"""Dataset help for minddata dataset"""
from mindspore._checkparam import check_bool
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
from mindspore.train.dataset_helper import _send_data
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
_to_full_shapes
from mindspore.train.parallel_utils import ParallelMode
@ -67,7 +68,13 @@ class _DatasetIter:
self.loop_size = dataset.get_dataset_size()
else:
self.loop_size = dataset.__loop_size__
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'):
_send_data(dataset)
else:
_send_data(dataset)
self.ind = 0
self.dataset = dataset

View File

@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp
from mindspore.train.callback import _InternalCallbackParam, RunContext, _build_callbacks
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.parallel_utils import ParallelMode
from model.dataset_helper import DatasetHelper
@ -374,7 +374,6 @@ class Model:
self._train_network.set_broadcast_flag()
# build callback list
list_callback = _build_callbacks(callbacks)
cb_params = _InternalCallbackParam()
cb_params.train_network = self._train_network
cb_params.epoch_num = epoch
@ -385,17 +384,17 @@ class Model:
cb_params.parallel_mode = self._parallel_mode
cb_params.device_number = self._device_number
cb_params.train_dataset = train_dataset
cb_params.list_callback = list_callback
cb_params.list_callback = callbacks
if dataset_sink_mode:
if context.get_context("mode") == context.PYNATIVE_MODE:
with _CallbackManager(callbacks) as list_callback:
if not dataset_sink_mode:
self._train_process(epoch, train_dataset, list_callback, cb_params)
elif context.get_context("mode") == context.PYNATIVE_MODE:
logger.warning("The pynative mode cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params)
else:
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
else:
self._train_process(epoch, train_dataset, list_callback, cb_params)
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
"""
@ -408,7 +407,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (_ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
iter_first_order = self._frequency - 1
@ -473,7 +472,7 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (_ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
dataset_helper, _ = self._exec_preprocess(self._train_network,
@ -580,7 +579,7 @@ class Model:
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
list_callback (ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
Returns:
@ -619,7 +618,7 @@ class Model:
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
list_callback (ListCallback): Executor of callback list. Default: None.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
Returns:
@ -678,7 +677,6 @@ class Model:
if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.")
list_callback = _build_callbacks(callbacks)
cb_params = _InternalCallbackParam()
cb_params.eval_network = self._eval_network
cb_params.valid_dataset = valid_dataset
@ -691,9 +689,10 @@ class Model:
self._clear_metrics()
if dataset_sink_mode:
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
return self._eval_process(valid_dataset, list_callback, cb_params)
with _CallbackManager(callbacks) as list_callback:
if dataset_sink_mode:
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
return self._eval_process(valid_dataset, list_callback, cb_params)
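The switch from `_build_callbacks` to a `_CallbackManager` used as a context manager ensures callback teardown runs even if training or evaluation raises. A simplified illustration of the pattern (not the real `_CallbackManager`; the `end` hook name is an assumption):
```
class SketchCallbackManager:
    """Illustrative sketch: run callbacks and guarantee their end hooks fire."""
    def __init__(self, callbacks):
        self.callbacks = list(callbacks or [])

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for cb in self.callbacks:
            end = getattr(cb, "end", None)  # hypothetical end hook
            if callable(end):
                end()
        return False  # never swallow exceptions
```
With this shape, `with SketchCallbackManager(callbacks) as cbs: ...` releases callback resources on both normal and exceptional exit, which is the motivation for the context-manager form used above.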
def predict(self, *predict_data):
"""

View File

@ -151,6 +151,8 @@ class THOR(Optimizer):
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i])
temp_max = self.mul(temp_max, self.feature_map[i])
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
if i == 53:
g = self.cube_matmul_left_fc(temp_g, g)
g = self.cube_matmul_right_fc(g, temp_a, temp_max)

View File

@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""thor_layer"""
import numpy as np
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool, twice, check_int_positive
@ -23,7 +25,6 @@ from mindspore.common.tensor import Tensor
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import operations as P
import numpy as np
C0 = 16
def caculate_device_shape(matrix_dim, channel, is_A):
@ -171,7 +172,6 @@ class Conv2d_Thor(_Conv):
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
self.fake_G = Tensor(
np.reshape(np.identity(self.matrix_G_device_dim).astype(np.float16), self.matrix_G_device_shape))
self.fake_G_inv_max = Tensor(np.zeros([1,]).astype(np.float32))
self.shape = P.Shape()
self.reshape = P.Reshape()
@ -286,7 +286,6 @@ class Conv2d_Thor(_Conv):
matrix_A_inv = self.device_shape_pad(matrix_A_inv)
matrix_A_inv = self.reshape(matrix_A_inv, self.matrix_A_device_temp_shape)
matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))
self.G_inv_max = self.fake_G_inv_max
self.matrix_A_inv = matrix_A_inv
self.matrix_G_inv = self.fake_G
out = self.conv2d(x, self.weight)
@ -339,15 +338,15 @@ class Dense_Thor(Cell):
self.has_bias = check_bool(has_bias)
self.thor = True
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
weight_init.shape()[1] != in_channels:
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

View File

@ -17,6 +17,8 @@ import argparse
import os
import random
import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init
@ -28,7 +30,6 @@ from model.model_thor import Model
from model.resnet import resnet50
from model.thor import THOR
import numpy as np
from config import config
from crossentropy import CrossEntropy
from dataset_imagenet import create_dataset

View File

@ -1,88 +0,0 @@
# SSD Example
## Description
SSD network based on MobileNetV2, with support for training and evaluation.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Dataset
By default, this example uses COCO 2017 as the training dataset; you can also use your own dataset.
1. If the COCO dataset is used. **Set the `dataset` argument to `coco` when running the script.**
Install Cython and pycocotools.
```
pip install Cython
pip install pycocotools
```
Then change `COCO_ROOT` and any other settings you need in `config.py`. The expected directory structure is as follows:
```
└─coco2017
├── annotations # annotation jsons
├── train2017 # train dataset
└── val2017 # infer dataset
```
2. If your own dataset is used. **Set the `dataset` argument to `other` when running the script.**
Organize the dataset information into a TXT file, where each row has the following format:
```
train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
```
Each row is a space-separated image annotation: the first column is the relative path of the image, and the remaining columns are box and class information in the format [xmin,ymin,xmax,ymax,class]. Images are read from the path obtained by joining `IMAGE_DIR` (the dataset directory) with the relative path given in `ANNO_PATH` (the TXT file path); both `IMAGE_DIR` and `ANNO_PATH` are set in `config.py`. A minimal parsing sketch for one such line follows.
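This sketch is illustrative only; the names below are not part of the training scripts.
```
line = "train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2"
parts = line.strip().split(' ')
image_rel_path = parts[0]
boxes = [list(map(int, p.split(','))) for p in parts[1:]]
# boxes -> [[0, 259, 401, 459, 7], [35, 28, 324, 201, 2]]  ([xmin, ymin, xmax, ymax, class])
```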
## Running the example
### Training
To train the model, run `train.py`. If `MINDRECORD_DIR` is empty, [mindrecord](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/converting_datasets.html) files are first generated from `COCO_ROOT` (COCO dataset) or from `IMAGE_DIR` and `ANNO_PATH` (your own dataset). **Note that if `MINDRECORD_DIR` is not empty, it will use the files in `MINDRECORD_DIR` instead of the raw images.**
- Standalone mode
```
python train.py --dataset coco
```
You can run ```python train.py -h``` to get more information.
- Distribute mode
```
sh run_distribute_train.sh 8 150 coco /data/hccl.json
```
The input parameters are the number of devices, the epoch size, the dataset mode and the [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute paths.**
You will get the loss value of each step, as follows:
```
epoch: 1 step: 455, loss is 5.8653416
epoch: 2 step: 455, loss is 5.4292373
epoch: 3 step: 455, loss is 5.458992
...
epoch: 148 step: 455, loss is 1.8340507
epoch: 149 step: 455, loss is 2.0876894
epoch: 150 step: 455, loss is 2.239692
```
### Evaluation
For evaluation, run `eval.py` with `ckpt_path`, where `ckpt_path` is the path of the [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.
```
python eval.py --ckpt_path ssd.ckpt --dataset coco
```
You can run ```python eval.py -h``` to get more information.

View File

@ -1,64 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for SSD models."""
class ConfigSSD:
"""
Config parameters for SSD.
Examples:
ConfigSSD().
"""
IMG_SHAPE = [300, 300]
NUM_SSD_BOXES = 1917
NEG_PRE_POSITIVE = 3
MATCH_THRESHOLD = 0.5
NUM_DEFAULT = [3, 6, 6, 6, 6, 6]
EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256]
EXTRAS_OUT_CHANNELS = [576, 1280, 512, 256, 256, 128]
EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2]
EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25]
FEATURE_SIZE = [19, 10, 5, 3, 2, 1]
SCALES = [21, 45, 99, 153, 207, 261, 315]
ASPECT_RATIOS = [(1,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)]
STEPS = (16, 32, 64, 100, 150, 300)
PRIOR_SCALING = (0.1, 0.2)
# `MINDRECORD_DIR` and `COCO_ROOT` are better to use absolute path.
MINDRECORD_DIR = "MindRecord_COCO"
COCO_ROOT = "coco2017"
TRAIN_DATA_TYPE = "train2017"
VAL_DATA_TYPE = "val2017"
INSTANCES_SET = "annotations/instances_{}.json"
COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush')
NUM_CLASSES = len(COCO_CLASSES)

View File

@ -1,375 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SSD dataset"""
from __future__ import division
import os
import math
import itertools as it
import numpy as np
import cv2
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C
from mindspore.mindrecord import FileWriter
from config import ConfigSSD
config = ConfigSSD()
class GeneratDefaultBoxes():
"""
Generate default boxes for SSD, following the order of (W, H, anchor_sizes).
`self.default_boxes` has a shape of [anchor_sizes, H, W, 4]; the last dimension is [x, y, w, h].
`self.default_boxes_ltrb` has the same shape as `self.default_boxes`; the last dimension is [x1, y1, x2, y2].
"""
def __init__(self):
fk = config.IMG_SHAPE[0] / np.array(config.STEPS)
self.default_boxes = []
for idex, feature_size in enumerate(config.FEATURE_SIZE):
sk1 = config.SCALES[idex] / config.IMG_SHAPE[0]
sk2 = config.SCALES[idex + 1] / config.IMG_SHAPE[0]
sk3 = math.sqrt(sk1 * sk2)
if config.NUM_DEFAULT[idex] == 3:
all_sizes = [(0.5, 1.0), (1.0, 1.0), (1.0, 0.5)]
else:
all_sizes = [(sk1, sk1), (sk3, sk3)]
for aspect_ratio in config.ASPECT_RATIOS[idex]:
w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio)
all_sizes.append((w, h))
all_sizes.append((h, w))
assert len(all_sizes) == config.NUM_DEFAULT[idex]
for i, j in it.product(range(feature_size), repeat=2):
for w, h in all_sizes:
cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
box = [np.clip(k, 0, 1) for k in (cx, cy, w, h)]
self.default_boxes.append(box)
def to_ltrb(cx, cy, w, h):
return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
# For IoU calculation
self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32')
self.default_boxes = np.array(self.default_boxes, dtype='float32')
default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb
default_boxes = GeneratDefaultBoxes().default_boxes
x1, y1, x2, y2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
vol_anchors = (x2 - x1) * (y2 - y1)
matching_threshold = config.MATCH_THRESHOLD
def ssd_bboxes_encode(boxes):
"""
Label anchors with ground truth inputs.
Args:
boxes: ground truth with shape [N, 5]; each row stores [x1, y1, x2, y2, cls].
Returns:
gt_loc: location ground truth with shape [num_anchors, 4].
gt_label: class ground truth with shape [num_anchors, 1].
num_matched_boxes: number of positives in an image.
"""
def jaccard_with_anchors(bbox):
"""Compute jaccard score a box and the anchors."""
# Intersection bbox and volume.
xmin = np.maximum(x1, bbox[0])
ymin = np.maximum(y1, bbox[1])
xmax = np.minimum(x2, bbox[2])
ymax = np.minimum(y2, bbox[3])
w = np.maximum(xmax - xmin, 0.)
h = np.maximum(ymax - ymin, 0.)
# Volumes.
inter_vol = h * w
union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol
jaccard = inter_vol / union_vol
return np.squeeze(jaccard)
pre_scores = np.zeros((config.NUM_SSD_BOXES), dtype=np.float32)
t_boxes = np.zeros((config.NUM_SSD_BOXES, 4), dtype=np.float32)
t_label = np.zeros((config.NUM_SSD_BOXES), dtype=np.int64)
for bbox in boxes:
label = int(bbox[4])
scores = jaccard_with_anchors(bbox)
mask = (scores > matching_threshold)
if not np.any(mask):
mask[np.argmax(scores)] = True
mask = mask & (scores > pre_scores)
pre_scores = np.maximum(pre_scores, scores)
t_label = mask * label + (1 - mask) * t_label
for i in range(4):
t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]
index = np.nonzero(t_label)
# Transform from corner format (ltrb) to center/size (cx, cy, w, h).
bboxes = np.zeros((config.NUM_SSD_BOXES, 4), dtype=np.float32)
bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]
# Encode features.
bboxes_t = bboxes[index]
default_boxes_t = default_boxes[index]
bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.PRIOR_SCALING[0])
bboxes_t[:, 2:4] = np.log(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4]) / config.PRIOR_SCALING[1]
bboxes[index] = bboxes_t
num_match_num = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
return bboxes, t_label.astype(np.int32), num_match_num
def ssd_bboxes_decode(boxes, index):
"""Decode predict boxes to [x, y, w, h]"""
boxes_t = boxes[index]
default_boxes_t = default_boxes[index]
boxes_t[:, :2] = boxes_t[:, :2] * config.PRIOR_SCALING[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2]
boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.PRIOR_SCALING[1]) * default_boxes_t[:, 2:4]
bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32)
bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2
bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2
return bboxes
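A toy call of `ssd_bboxes_encode` defined above (sketch; assumes this module is importable and uses the default `ConfigSSD`, so there are 1917 default boxes):
```
import numpy as np

gt = np.array([[0.1, 0.2, 0.4, 0.6, 7]], dtype=np.float32)  # [x1, y1, x2, y2, class], normalized
loc, label, num_matched = ssd_bboxes_encode(gt)
print(loc.shape, label.shape, num_matched)  # (1917, 4) (1917,) and the count of matched default boxes
```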
def preprocess_fn(image, box, is_training):
"""Preprocess function for dataset."""
def _rand(a=0., b=1.):
"""Generate random."""
return np.random.rand() * (b - a) + a
def _infer_data(image, input_shape, box):
img_h, img_w, _ = image.shape
input_h, input_w = input_shape
scale = min(float(input_w) / float(img_w), float(input_h) / float(img_h))
nw = int(img_w * scale)
nh = int(img_h * scale)
image = cv2.resize(image, (nw, nh))
new_image = np.zeros((input_h, input_w, 3), np.float32)
dh = (input_h - nh) // 2
dw = (input_w - nw) // 2
new_image[dh: (nh + dh), dw: (nw + dw), :] = image
image = new_image
# When the image has only one channel
if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1)
image = np.concatenate([image, image, image], axis=-1)
box = box.astype(np.float32)
box[:, [0, 2]] = (box[:, [0, 2]] * scale + dw) / input_w
box[:, [1, 3]] = (box[:, [1, 3]] * scale + dh) / input_h
return image, np.array((img_h, img_w), np.float32), box
def _data_aug(image, box, is_training, image_size=(300, 300)):
"""Data augmentation function."""
ih, iw, _ = image.shape
w, h = image_size
if not is_training:
return _infer_data(image, image_size, box)
# Random settings
scale_w = _rand(0.75, 1.25)
scale_h = _rand(0.75, 1.25)
flip = _rand() < .5
nw = iw * scale_w
nh = ih * scale_h
scale = min(w / nw, h / nh)
nw = int(scale * nw)
nh = int(scale * nh)
# Resize image
image = cv2.resize(image, (nw, nh))
# place image
new_image = np.zeros((h, w, 3), dtype=np.float32)
dw = (w - nw) // 2
dh = (h - nh) // 2
new_image[dh:dh + nh, dw:dw + nw, :] = image
image = new_image
# Flip image or not
if flip:
image = cv2.flip(image, 1, dst=None)
# Convert image to gray or not
gray = _rand() < .25
if gray:
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# When the image has only one channel
if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1)
image = np.concatenate([image, image, image], axis=-1)
box = box.astype(np.float32)
# Transform box with shape[x1, y1, x2, y2].
box[:, [0, 2]] = (box[:, [0, 2]] * scale * scale_w + dw) / w
box[:, [1, 3]] = (box[:, [1, 3]] * scale * scale_h + dh) / h
if flip:
box[:, [0, 2]] = 1 - box[:, [2, 0]]
box, label, num_match_num = ssd_bboxes_encode(box)
return image, box, label, num_match_num
return _data_aug(image, box, is_training, image_size=config.IMG_SHAPE)
def create_coco_label(is_training):
"""Get image path and annotation from COCO."""
from pycocotools.coco import COCO
coco_root = config.COCO_ROOT
data_type = config.VAL_DATA_TYPE
if is_training:
data_type = config.TRAIN_DATA_TYPE
#Classes need to train or test.
train_cls = config.COCO_CLASSES
train_cls_dict = {}
for i, cls in enumerate(train_cls):
train_cls_dict[cls] = i
anno_json = os.path.join(coco_root, config.INSTANCES_SET.format(data_type))
coco = COCO(anno_json)
classs_dict = {}
cat_ids = coco.loadCats(coco.getCatIds())
for cat in cat_ids:
classs_dict[cat["id"]] = cat["name"]
image_ids = coco.getImgIds()
image_files = []
image_anno_dict = {}
for img_id in image_ids:
image_info = coco.loadImgs(img_id)
file_name = image_info[0]["file_name"]
anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = coco.loadAnns(anno_ids)
image_path = os.path.join(coco_root, data_type, file_name)
annos = []
for label in anno:
bbox = label["bbox"]
class_name = classs_dict[label["category_id"]]
if class_name in train_cls:
x_min, x_max = bbox[0], bbox[0] + bbox[2]
y_min, y_max = bbox[1], bbox[1] + bbox[3]
annos.append(list(map(round, [x_min, y_min, x_max, y_max])) + [train_cls_dict[class_name]])
if len(annos) >= 1:
image_files.append(image_path)
image_anno_dict[image_path] = np.array(annos)
return image_files, image_anno_dict
def anno_parser(annos_str):
"""Parse annotation from string to list."""
annos = []
for anno_str in annos_str:
anno = list(map(int, anno_str.strip().split(',')))
annos.append(anno)
return annos
def filter_valid_data(image_dir, anno_path):
"""Filter valid image file, which both in image_dir and anno_path."""
image_files = []
image_anno_dict = {}
if not os.path.isdir(image_dir):
raise RuntimeError("Path given is not valid.")
if not os.path.isfile(anno_path):
raise RuntimeError("Annotation file is not valid.")
with open(anno_path, "rb") as f:
lines = f.readlines()
for line in lines:
line_str = line.decode("utf-8").strip()
line_split = str(line_str).split(' ')
file_name = line_split[0]
image_path = os.path.join(image_dir, file_name)
if os.path.isfile(image_path):
image_anno_dict[image_path] = anno_parser(line_split[1:])
image_files.append(image_path)
return image_files, image_anno_dict
def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8):
"""Create MindRecord file."""
mindrecord_dir = config.MINDRECORD_DIR
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
if dataset == "coco":
image_files, image_anno_dict = create_coco_label(is_training)
else:
image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)
ssd_json = {
"image": {"type": "bytes"},
"annotation": {"type": "int32", "shape": [-1, 5]},
}
writer.add_schema(ssd_json, "ssd_json")
for image_name in image_files:
with open(image_name, 'rb') as f:
img = f.read()
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
row = {"image": img, "annotation": annos}
writer.write_raw_data([row])
writer.commit()
def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0,
is_training=True, num_parallel_workers=4):
"""Creatr SSD dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
num_parallel_workers=num_parallel_workers, shuffle=is_training)
decode = C.Decode()
ds = ds.map(input_columns=["image"], operations=decode)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
if is_training:
hwc_to_chw = C.HWC2CHW()
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "box", "label", "num_match_num"],
columns_order=["image", "box", "label", "num_match_num"],
operations=compose_map_func, python_multiprocessing=True, num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, python_multiprocessing=True,
num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
else:
hwc_to_chw = C.HWC2CHW()
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "annotation"],
columns_order=["image", "image_shape", "annotation"],
operations=compose_map_func)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
return ds
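Typical usage of the two entry points above, as a sketch (paths come from the `ConfigSSD` defaults and assume a prepared COCO 2017 layout under `COCO_ROOT`):
```
import os

prefix = "ssd.mindrecord"
data_to_mindrecord_byte_image("coco", is_training=True, prefix=prefix, file_num=8)
# FileWriter with file_num=8 produces ssd.mindrecord0 ... ssd.mindrecord7
mindrecord_file = os.path.join(config.MINDRECORD_DIR, prefix + "0")
ds = create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=1, is_training=True)
print(ds.get_dataset_size())
```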

View File

@ -1,206 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics utils"""
import numpy as np
from config import ConfigSSD
from dataset import ssd_bboxes_decode
def calc_iou(bbox_pred, bbox_ground):
"""Calculate iou of predicted bbox and ground truth."""
bbox_pred = np.expand_dims(bbox_pred, axis=0)
pred_w = bbox_pred[:, 2] - bbox_pred[:, 0]
pred_h = bbox_pred[:, 3] - bbox_pred[:, 1]
pred_area = pred_w * pred_h
gt_w = bbox_ground[:, 2] - bbox_ground[:, 0]
gt_h = bbox_ground[:, 3] - bbox_ground[:, 1]
gt_area = gt_w * gt_h
iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0])
ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1])
iw = np.maximum(iw, 0)
ih = np.maximum(ih, 0)
intersection_area = iw * ih
union_area = pred_area + gt_area - intersection_area
union_area = np.maximum(union_area, np.finfo(float).eps)
iou = intersection_area * 1. / union_area
return iou
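A toy example of `calc_iou` (sketch): one predicted box against two ground-truth boxes, one overlapping and one disjoint.
```
import numpy as np

pred = np.array([0.0, 0.0, 2.0, 2.0], dtype=np.float32)
gts = np.array([[1.0, 1.0, 3.0, 3.0], [10.0, 10.0, 12.0, 12.0]], dtype=np.float32)
print(calc_iou(pred, gts))  # approximately [0.1429, 0.0]
```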
def apply_nms(all_boxes, all_scores, thres, max_boxes):
"""Apply NMS to bboxes."""
x1 = all_boxes[:, 0]
y1 = all_boxes[:, 1]
x2 = all_boxes[:, 2]
y2 = all_boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = all_scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
if len(keep) >= max_boxes:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thres)[0]
order = order[inds + 1]
return keep
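A toy NMS example (sketch): with a 0.5 threshold the second box is suppressed by the higher-scoring first box, while the disjoint third box is kept.
```
import numpy as np

boxes = np.array([[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
print(apply_nms(boxes, scores, thres=0.5, max_boxes=10))  # [0, 2]
```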
def calc_ap(recall, precision):
"""Calculate AP."""
correct_recall = np.concatenate(([0.], recall, [1.]))
correct_precision = np.concatenate(([0.], precision, [0.]))
for i in range(correct_recall.size - 1, 0, -1):
correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i])
i = np.where(correct_recall[1:] != correct_recall[:-1])[0]
ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1])
return ap
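A worked example of `calc_ap` (sketch): with recall points 0.5 and 1.0 at precisions 1.0 and 0.5, the interpolated AP is 0.5*1.0 + 0.5*0.5 = 0.75.
```
import numpy as np

recall = np.array([0.5, 1.0])
precision = np.array([1.0, 0.5])
print(calc_ap(recall, precision))  # 0.75
```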
def metrics(pred_data):
"""Calculate mAP of predicted bboxes."""
config = ConfigSSD()
num_classes = config.NUM_CLASSES
all_detections = [None for i in range(num_classes)]
all_pred_scores = [None for i in range(num_classes)]
all_annotations = [None for i in range(num_classes)]
average_precisions = {}
num = [0 for i in range(num_classes)]
accurate_num = [0 for i in range(num_classes)]
for sample in pred_data:
pred_boxes = sample['boxes']
boxes_scores = sample['box_scores']
annotation = sample['annotation']
annotation = np.squeeze(annotation, axis=0)
pred_labels = np.argmax(boxes_scores, axis=-1)
index = np.nonzero(pred_labels)
pred_boxes = ssd_bboxes_decode(pred_boxes, index)
pred_boxes = pred_boxes.clip(0, 1)
boxes_scores = np.max(boxes_scores, axis=-1)
boxes_scores = boxes_scores[index]
pred_labels = pred_labels[index]
top_k = 50
for c in range(1, num_classes):
if len(pred_labels) >= 1:
class_box_scores = boxes_scores[pred_labels == c]
class_boxes = pred_boxes[pred_labels == c]
nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k)
class_boxes = class_boxes[nms_index]
class_box_scores = class_box_scores[nms_index]
cmask = class_box_scores > 0.5
class_boxes = class_boxes[cmask]
class_box_scores = class_box_scores[cmask]
all_detections[c] = class_boxes
all_pred_scores[c] = class_box_scores
for c in range(1, num_classes):
if len(annotation) >= 1:
all_annotations[c] = annotation[annotation[:, 4] == c, :4]
for c in range(1, num_classes):
false_positives = np.zeros((0,))
true_positives = np.zeros((0,))
scores = np.zeros((0,))
num_annotations = 0.0
annotations = all_annotations[c]
num_annotations += annotations.shape[0]
detections = all_detections[c]
pred_scores = all_pred_scores[c]
for index, detection in enumerate(detections):
scores = np.append(scores, pred_scores[index])
if len(annotations) >= 1:
IoUs = calc_iou(detection, annotations)
assigned_anno = np.argmax(IoUs)
max_overlap = IoUs[assigned_anno]
if max_overlap >= 0.5:
false_positives = np.append(false_positives, 0)
true_positives = np.append(true_positives, 1)
else:
false_positives = np.append(false_positives, 1)
true_positives = np.append(true_positives, 0)
else:
false_positives = np.append(false_positives, 1)
true_positives = np.append(true_positives, 0)
if num_annotations == 0:
if c not in average_precisions.keys():
average_precisions[c] = 0
continue
accurate_num[c] = 1
indices = np.argsort(-scores)
false_positives = false_positives[indices]
true_positives = true_positives[indices]
false_positives = np.cumsum(false_positives)
true_positives = np.cumsum(true_positives)
recall = true_positives * 1. / num_annotations
precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
average_precision = calc_ap(recall, precision)
if c not in average_precisions.keys():
average_precisions[c] = average_precision
else:
average_precisions[c] += average_precision
num[c] += 1
count = 0
for key in average_precisions:
if num[key] != 0:
count += (average_precisions[key] / num[key])
mAP = count * 1. / accurate_num.count(1)
return mAP

@ -1 +1 @@
Subproject commit c27e428e9698dd4f9b198008596676bc2d1b49aa
Subproject commit 8891f0546c4a250095ff68e1262f58772b938fd9

44
include/inference.h Normal file
View File

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_MS_SESSION_H
#define MINDSPORE_INCLUDE_MS_SESSION_H
#include <memory>
#include <vector>
#include <string>
#include "include/ms_tensor.h"
namespace mindspore {
class FuncGraph;
namespace inference {
class MS_API MSSession {
public:
MSSession() = default;
static std::shared_ptr<MSSession> CreateSession(const std::string &device, uint32_t device_id);
virtual uint32_t CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) = 0;
virtual MultiTensor RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) = 0;
};
std::shared_ptr<FuncGraph> MS_API LoadModel(const char *model_buf, size_t size, const std::string &device);
void MS_API ExitInference();
} // namespace inference
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_MS_SESSION_H

69
include/ms_tensor.h Normal file
View File

@ -0,0 +1,69 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_MS_TENSOR_H_
#define MINDSPORE_INCLUDE_MS_TENSOR_H_
#include <utility>
#include <vector>
#include <memory>
#include "ir/dtype/type_id.h"
namespace mindspore {
#define MS_API __attribute__((visibility("default")))
namespace inference {
class MS_API MSTensor {
public:
MSTensor() = default;
// brief Create a MSTensor pointer.
//
// param data_type DataTypeId of tensor to be created.
// param shape Shape of tensor to be created.
// return MSTensor pointer.
static MSTensor *CreateTensor(TypeId data_type, const std::vector<int> &shape);
~MSTensor() = default;
virtual TypeId data_type() const = 0;
virtual TypeId set_data_type(const TypeId data_type) = 0;
virtual std::vector<int> shape() const = 0;
virtual size_t set_shape(const std::vector<int> &shape) = 0;
virtual int DimensionSize(size_t index) const = 0;
// brief Get number of element in MSTensor.
//
// return Number of element in MSTensor.
virtual int ElementsNum() const = 0;
virtual std::size_t hash() const = 0;
// brief Get byte size of data in MSTensor.
//
// return Byte size of data in MSTensor.
virtual size_t Size() const = 0;
// brief Get pointer of data in MSTensor.
//
// The data pointer can be used to both write or read data in MSTensor.
//
// return A pointer points to data in MSTensor.
virtual void *MutableData() const = 0;
};
using MultiTensor = std::vector<std::shared_ptr<inference::MSTensor>>;
} // namespace inference
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_MS_TENSOR_H_

View File

@ -35,3 +35,5 @@ from .logical_not import LogicalNot, gpu_schedule_LogicalNot
from .logical_and import LogicalAnd, gpu_schedule_LogicalAnd
from .sub import Sub, gpu_schedule_Sub
from .less_equal import LessEqual, gpu_schedule_LessEqual
from .notequal import NotEqual, gpu_schedule_NotEqual
from .greater_equal import GreaterEqual, gpu_schedule_GreaterEqual

View File

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""greater_equal"""
import _akg.tvm
from _akg.ops.math import greater_equal
from _akg.topi.generic import schedule_elemwise
def GreaterEqual(x, y):
"""GreaterEqual."""
return greater_equal.greater_equal(x, y)
def gpu_schedule_GreaterEqual(outs):
"""
GPU schedule for GreaterEqual.
Args:
outs (tvm.tensor.Tensor): Outputs of compute.
Returns:
sch (schedule.Schedule): The created schedule.
"""
device = 'cuda'
ctx = _akg.tvm.context(device, 0)
if not ctx.exist:
raise SystemError("Skip because %s is not enabled" % device)
with _akg.tvm.target.create(device):
sch = schedule_elemwise(outs)
return sch

View File

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""notequal"""
import _akg.tvm
from _akg.ops.math import notequal
from _akg.topi.generic import schedule_elemwise
def NotEqual(x, y):
"""notequal."""
return notequal.notequal(x, y)
def gpu_schedule_NotEqual(outs):
"""
GPU schedule for NotEqual.
Args:
outs (tvm.tensor.Tensor): outputs of compute.
Returns:
sch (schedule.Schedule): The created schedule.
"""
device = 'cuda'
ctx = _akg.tvm.context(device, 0)
if not ctx.exist:
raise SystemError("Skip because %s is not enabled" % device)
with _akg.tvm.target.create(device):
sch = schedule_elemwise(outs)
return sch

View File

@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: greaterequal"""
import _akg.tvm
import _akg.topi
from _akg.utils.dsl_create import produce_shapes
from _akg.utils import validation_check as vc_util
@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor)
def greater_equal(input1, input2):
"""
Check whether input1 is greater than or equal to input2.
Args:
input1 (tvm.tensor.Tensor): Tensor.
input2 (tvm.tensor.Tensor): Tensor.
Returns:
tvm.tensor.Tensor. True where input1 is greater than or equal to input2, otherwise False.
"""
shape1 = [x.value for x in input1.shape]
shape2 = [x.value for x in input2.shape]
vc_util.check_shape(shape1)
vc_util.check_shape(shape2)
shape1, shape2, shape = produce_shapes(shape1, shape2)
vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
dtype = input1.dtype
# build the greater-equal compute
t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T")
f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F")
input1_bro = _akg.topi.broadcast_to(input1, shape)
input2_bro = _akg.topi.broadcast_to(input2, shape)
c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] >= input2_bro[indice],
t_value[indice], f_value[indice]), name="C")
res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")
return res

View File

@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: notequal"""
import _akg.tvm
import _akg.topi
from _akg.utils.dsl_create import produce_shapes
from _akg.utils import validation_check as vc_util
@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor)
def notequal(input1, input2):
"""
Check whether input1 is not equal to input2.
Args:
input1 (tvm.tensor.Tensor): Tensor.
input2 (tvm.tensor.Tensor): Tensor.
Returns:
tvm.tensor.Tensor. True where input1 is not equal to input2, otherwise False.
"""
shape1 = [x.value for x in input1.shape]
shape2 = [x.value for x in input2.shape]
vc_util.check_shape(shape1)
vc_util.check_shape(shape2)
shape1, shape2, shape = produce_shapes(shape1, shape2)
vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
dtype = input1.dtype
# build the not-equal compute
t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T")
f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F")
input1_bro = _akg.topi.broadcast_to(input1, shape)
input2_bro = _akg.topi.broadcast_to(input2, shape)
c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] != input2_bro[indice],
t_value[indice], f_value[indice]), name="C")
res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")
return res

View File

@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""builtin_operations"""
import functools
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
@ -114,6 +113,24 @@ def bool_or(x, y):
"""Implement `bool_or`."""
return x or y
def vm_compare(*args):
"""Implement `vm_compare` for tensor."""
obj_str = args[-1]
if obj_str == "shape":
fn = getattr(args[0].asnumpy(), obj_str)
return fn
if len(args) == 2:
fn = getattr(args[0].asnumpy(), obj_str)
return Tensor(fn())
if isinstance(args[0], Tensor):
fn = getattr(args[0].asnumpy(), obj_str)
y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
else:
obj_str = "__r" + obj_str[2:]
fn = getattr(args[1].asnumpy(), obj_str)
y = args[0]
return Tensor(np.array(fn(y)))
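For example, `vm_compare` resolves `"__ge__"` by dispatching to the NumPy view of the tensors; a small sketch, assuming this module's namespace and a MindSpore install:
```
import numpy as np
from mindspore import Tensor

a, b = Tensor(np.array([1, 3])), Tensor(np.array([2, 2]))
out = vm_compare(a, b, "__ge__")
print(out)  # Tensor holding [False, True]
```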
def make_list(*xs):
"""Implement `make_list`."""
@ -124,17 +141,8 @@ def list_len(x):
"""Implement `list_len`."""
return len(x)
# only used in PyNative mode
def partial(*args):
"""Implement `partial`."""
func = args[0].__call__
partial_func = functools.partial(func, *args[1:])
return partial_func
# only used in PyNative mode
def depend(value, expr):
def Depend(value, expr):
"""Implement `Depend`."""
return value
# only used in PyNative mode

View File

@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,35 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing akg compile with json"""
import sys
def run_compiler(op_json):
"""
Run the AKG compiler to compile an op; if the compilation fails, an exception is raised.
Args:
op_json (str): json string of the op
Returns:
None
"""
p = __import__("akg", globals(), locals(), ['ms'], 0)
func = getattr(p.ms, "compilewithjson")
res = func(op_json)
if not res:
raise ValueError("Compile error")
if __name__ == "__main__":
run_compiler(sys.argv[1])

View File

@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing multi process compile with json"""
import os
import subprocess
import sys
from multiprocessing import Pool, cpu_count
def _compile_akg_task(*json_strs):
"""
Compile function called in a single process.
Parameters:
json_strs: list. A list of kernel infos, each suitable for the json compile API.
"""
akg_compiler = os.path.join(os.path.split(
os.path.realpath(__file__))[0], "compiler.py")
for json_str in json_strs:
res = subprocess.run(
[sys.executable, akg_compiler, json_str], text=True)
if res.returncode != 0:
raise ValueError("Failed, args: {}!".format(json_str))
def compile_akg_kernel_parallel(json_infos, process, waitime):
"""
Compile kernels using multiple processes.
Parameters:
json_infos: list. List of kernel infos (task id and json str).
process: int. Number of processes to use.
waitime: int. Maximum time (seconds) the call may block.
Returns:
True when all compilations succeed; a failure raises ValueError.
"""
if not isinstance(json_infos, list):
raise ValueError("json_infos must be a list")
if not isinstance(process, int):
raise ValueError("process must be a num")
if not isinstance(waitime, int):
raise ValueError("waittime must be a num")
if process == 0 and json_infos:
process = 1
cpu_proc_num = cpu_count()
max_proc_num = 16
process = min([cpu_proc_num, max_proc_num, process])
args = [[] for _ in range(process)]
for p, info in enumerate(json_infos):
args[p % process].append(info)
with Pool(processes=process) as pool:
res = pool.starmap_async(_compile_akg_task, args)
res.get(timeout=waitime)
return True

View File

@ -1,107 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing multi process compile with json"""
import json
import math
import os
import subprocess
import sys
from multiprocessing import Pool
def _compiletask(platform, *jsons):
"""
Compile function called in a single process.
Parameters:
platform: str. AKG platform or TBE platform.
*jsons: str. json strings containing kernel info, suitable for the json compile api.
"""
if platform == "AKG":
p = __import__("_akg", globals(), locals(), ['ms'], 0)
func = getattr(p.ms, "compilewithjson")
for json_item in jsons:
res = func(json_item)
if not res:
raise ValueError("Compile error")
if platform == "TBE":
tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "tbe_compiler", "compiler.py")
for json_item in jsons:
res = subprocess.run([sys.executable, tbe_compiler], input=json_item, text=True)
if res.returncode != 0:
raise ValueError("Tbe compile error")
def compilekernelparallel(jsons, process, waitime):
"""
Compile kernels using multiple processes.
Parameters:
jsons: list. List of json strings containing kernel info.
process: int. Number of processes to use.
waitime: int. Maximum time the call may block.
"""
if not isinstance(jsons, list):
raise ValueError("jsons must be a list")
if not isinstance(process, int):
raise ValueError("process must be a num")
if not isinstance(waitime, int):
raise ValueError("waittime must be a num")
jsons_akg = []
jsons_tbe = []
for json_ in jsons:
j = json.loads(json_)
if j["platform"] == "TBE":
jsons_tbe.append(json_)
continue
if j["platform"] == "AKG":
jsons_akg.append(json_)
continue
raise RuntimeError(
"not support this platform {0}".format(j["platform"]))
if jsons_akg:
process_akg = math.floor(len(jsons)/len(jsons_akg)*process)
else:
process_akg = 0
if process_akg == 0 and jsons_akg:
process_akg = 1
process_tbe = process-process_akg
if process_tbe == 0 and jsons_tbe:
process_tbe = 1
raise RuntimeWarning("we add a process for compile more operator")
args = [[] for _ in range(process_akg+process_tbe)]
args_lens = len(args)
for p in range(args_lens):
if p < process_tbe:
args[p].append("TBE")
else:
args[p].append("AKG")
jsons_tbe_lens = len(jsons_tbe)
for p in range(jsons_tbe_lens):
args[p % process_tbe].append(jsons_tbe[p])
jsons_akg_lens = len(jsons_akg)
for p in range(jsons_akg_lens):
args[process-p % process_akg-1].append(jsons_akg[p])
for p in range(args_lens):
args[p] = tuple(args[p])
with Pool(processes=process) as pool:
res = pool.starmap_async(_compiletask, args)
res.get(timeout=waitime)
return True

View File

@ -15,13 +15,6 @@
"""tbe common"""
import json
import os
from attrdict import AttrDict
class ParamType(AttrDict):
Required = "required"
Dynamic = "dynamic"
Optional = "optional"
class TBEException(Exception):
"""tbe exception class"""
@ -112,7 +105,7 @@ def get_input_output(io_info, args):
if len(item) > 1:
arg.append(info)
else:
if info['param_type'] == ParamType.Dynamic:
if info['param_type'] == 'dynamic':
arg.append(info)
args.append(arg)
else:

View File

@ -28,7 +28,8 @@ build_in_impl_path = get_build_in_impl_path()
# op function list
op_build = "compile"
op_pre_build = "pre_build"
fusion_pattern_start_flag = "fusion_pattern_start"
fusion_pattern_end_flag = "fusion_pattern_end"
def _initialize(impl_path):
"""Initialize"""
@ -42,7 +43,6 @@ def _initialize(impl_path):
sys.path.insert(0, op_module_name)
def build_op(build_type, json_str):
"""
call op functions with function name and input args json_str
@ -108,7 +108,7 @@ def build_op(build_type, json_str):
# pre build
if build_type == op_pre_build:
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name)
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
# disable only pattern configuration
op_build_cfg_en()
return get_op_pattern()
@ -159,11 +159,14 @@ def compile_with_json(json_str):
json_info = json.loads(json_str)
if "fusion_op" in json_info:
ret = compile_fusion_op(json_str)
elif "compile_type" in json_info:
ret = build_op(op_pre_build, json_str)
else:
ret = build_op(op_build, json_str)
return ret
if __name__ == "__main__":
in_args = sys.stdin.readline()
compile_with_json(in_args)
result = compile_with_json(in_args)
sys.stdout.write(fusion_pattern_start_flag + str(result) + fusion_pattern_end_flag)
sys.stdout.flush()

View File

@ -75,7 +75,6 @@ def check_supported(op_json: str):
return ret
def run_compiler(op_json):
"""
run compiler to compile op with subprocess
@ -88,15 +87,16 @@ def run_compiler(op_json):
"""
try:
tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "compiler.py")
subprocess.run([sys.executable, tbe_compiler], input=op_json, timeout=300,
text=True, capture_output=True, check=True)
return "Success", "Success"
completed_object = subprocess.run([sys.executable, tbe_compiler], input=op_json, timeout=300,
text=True, capture_output=True, check=True)
if completed_object:
out = completed_object.stdout
return "Success", out
except subprocess.TimeoutExpired:
tb = traceback.format_exc()
return "TBEException", "CompileTimeOut: " + tb + "\ninput_args: " + op_json
return "TBEException", "PreCompileTimeOut: " + tb + "\ninput_args: " + op_json
except subprocess.CalledProcessError as e:
return "TBEException", "CompileProcessFailed:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json
return "TBEException", "PreCompileProcessFailed:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json
class CompilerPool:
"""compiler pool"""
@ -154,11 +154,11 @@ class CompilerPool:
task_id, task_future = self.__running_tasks.pop(0)
ret_type, result = task_future.get(330)
if ret_type == "Success":
ret = task_id, "Success"
ret = task_id, "Success", result
elif ret_type in ("Exception", "TBEException"):
ret = task_id, ret_type + ":" + result
ret = task_id, ret_type + ":" + result, "_"
else:
ret = task_id, "Exception: Not support return type:" + str(ret_type)
ret = task_id, "Exception: Not support return type:" + str(ret_type), "_"
return ret
def reset_task_info(self):

View File

@ -19,14 +19,15 @@ Interfaces for parser module in c++.
from .parser import (Parser, create_obj_instance, generate_scope,
get_bprop_method_of_class, get_class_instance_type,
get_class_member_namespace_symbol, create_slice_obj,
get_dataclass_attributes, get_dataclass_methods,
get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key,
get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj)
get_default_input, get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol)
from .serialize import *
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
'get_dataclass_methods', 'get_scope_name', 'create_slice_obj', 'create_ellipsis_obj']
'get_object_key', 'get_default_input', 'get_class_instance_type', 'is_class_member',
'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
'get_class_member_namespace_symbol', 'Parser', 'get_dataclass_attributes',
'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_scope_name',
'create_slice_obj']

View File

@ -29,7 +29,6 @@ from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.api import _MindSporeFunction
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
from ..utils import Slice, Ellipsis_
# define return value
RET_SUCCESS = 0
@ -70,14 +69,9 @@ parse_expr_statement_white_list = (
"append",
)
def create_ellipsis_obj():
"""Create Slice object"""
return Ellipsis_()
def create_slice_obj(start, end, step):
"""Create Slice object"""
return Slice(start, end, step)
"""Create slice object"""
return slice(start, end, step)
def parse_cb(func, parse_method=None):
@ -209,6 +203,14 @@ def get_object_key(obj):
obj_id = instance_id + obj_id
return obj_id, obj_key
def get_default_input(obj):
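"""Get the default input of a parameter, or of each parameter in a tuple; other objects are returned unchanged."""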
if hasattr(obj, '__parameter__'):
return obj.default_input
if isinstance(obj, tuple):
convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
args = tuple(convert(x) for x in obj)
return args
return obj
def is_class_member(node):
"""Check the attr is class member variable."""
@ -221,6 +223,9 @@ def is_class_member(node):
return True
return False
def get_obj_id(obj):
"""Get the obj id."""
return str(id(obj))
def get_obj_type(obj):
"""Get the obj type."""

View File

@ -126,7 +126,7 @@ convert_object_map = {
T.make_list: F.make_list,
T.make_slice: F.make_slice,
T.range: F.make_range,
T.while_cond: M.while_cond,
# lib function
math.floor: NO_IMPLEMENT,
math.trunc: NO_IMPLEMENT,

View File

@ -16,8 +16,10 @@
# ============================================================================
"""standard_method"""
from dataclasses import dataclass
from mindspore.common import dtype as mstype
from ...ops import functional as F
from ...ops import operations as P
from ...ops.primitive import constexpr
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
zeros_like, ones_like
from ...ops.composite.base import _append
@ -102,11 +104,44 @@ def bool_(x):
return x.__bool__()
def tensor_bool(x):
"""return immedate x, x is a tensor of bool value"""
def while_cond(x):
"""For while condtion, if the condition is a tensor, the loop will not be unrolled"""
if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
is_cond = check_is_tensor_bool_cond(F.shape(x))
if is_cond:
return F.cast(x, mstype.bool_)
return x
@constexpr
def check_is_tensor_bool_cond(shp):
"""check if tensor is a bool condition"""
if shp in ((), (1,)):
return True
raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp)
@constexpr
def const_tensor_to_bool(x):
"""convert bool tensor to bool condition"""
if x is None:
raise ValueError("Only constant tensor bool can be converted to bool")
x = x.asnumpy()
if x.shape not in ((), (1,)):
raise ValueError("Tensor to bool should input shape () or (1), but got ", x.shape)
if x.shape == ():
value = bool(x)
else:
value = bool(x[0])
return value
def tensor_bool(x):
"""tensor as conditon, if is constant, return immediate bool value"""
is_cond = check_is_tensor_bool_cond(F.shape(x))
if is_cond and F.isconstant(x):
return const_tensor_to_bool(x)
return F.cast(x, mstype.bool_)
def and_(x, y):
"""Implementation of `and` (`&`)."""
return x.__and__(y)
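The intent of the new tensor-as-condition helpers can be sketched with plain NumPy (a hedged illustration; in graph mode the real path goes through F.isconstant, F.shape and F.cast as shown above):

import numpy as np

def const_tensor_to_bool_sketch(x):
    # Mirrors const_tensor_to_bool: only tensors of shape () or (1,) may be
    # folded into a Python bool when the value is a compile-time constant.
    if x.shape not in ((), (1,)):
        raise ValueError("Tensor used as a bool condition must have shape () or (1,), got", x.shape)
    return bool(x) if x.shape == () else bool(x[0])

print(const_tensor_to_bool_sketch(np.array(True)))           # True
print(const_tensor_to_bool_sketch(np.array([0], np.bool_)))  # False

Non-constant tensors are not folded; tensor_bool instead casts them to mstype.bool_ so the condition stays in the graph.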

View File

@ -91,3 +91,7 @@ def to_array(x): # pragma: no cover
def not_contains(x): # pragma: no cover
"""Not in function."""
raise RuntimeError('This operation is not meant to be called directly.')
def while_cond(x): # pragma: no cover
"""Not in function."""
raise RuntimeError('This operation is not meant to be called directly.')

View File

@ -19,7 +19,6 @@ import logging
import os
import inspect
from functools import wraps
from dataclasses import dataclass
def cal_sha256(file_path):
@ -100,20 +99,3 @@ def cell_attr_register(fn=None, attrs=None):
if fn is not None:
return wrap_cell(fn)
return wrap_cell
@dataclass
class Slice:
"""
Slice class
"""
start: int
end: int
step: int
@dataclass
class Ellipsis_:
"""
Ellipsis class
"""

View File

@ -8,6 +8,10 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows")
add_compile_definitions(BUILDING_DLL)
endif()
if (ENABLE_MPI)
add_compile_definitions(ENABLE_MPI)
endif ()
if(ENABLE_GPU)
find_package(CUDA REQUIRED)
find_package(Threads)
@ -35,7 +39,7 @@ if(ENABLE_GPU)
"device/gpu/*.cu"
"kernel/gpu/*.cu"
"kernel/akg/gpu/*.cc"
"kernel/akg/akgkernelbuild.cc"
"kernel/akg/akg_kernel_build.cc"
"kernel/akg/akg_kernel_attrs_process.cc"
)
@ -75,7 +79,9 @@ if (ENABLE_DUMP_PROTO)
file(GLOB_RECURSE PROTO_PY RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"utils/anf_ir.proto"
"utils/summary.proto"
"utils/lineage.proto"
"utils/checkpoint.proto"
"utils/print.proto"
)
ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY})
@ -120,7 +126,11 @@ endforeach ()
set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME)
add_library(mindspore STATIC ${SUB_OBJECTS_SRC})
target_link_libraries(mindspore proto_input)
target_link_libraries(mindspore securec mindspore::flatbuffers)
if (ENABLE_CPU AND ENABLE_MPI)
target_link_libraries(mindspore securec mindspore::flatbuffers mindspore::ompi)
else ()
target_link_libraries(mindspore securec mindspore::flatbuffers)
endif ()
if (NOT WIN32)
target_link_libraries(mindspore dl)
endif()
@ -227,3 +237,29 @@ if (ENABLE_MINDDATA)
add_subdirectory(mindrecord)
add_subdirectory(dataset)
endif ()
# build inference
set(LOAD_ONNX_SRC
${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_converter.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_model_parser.cc
)
add_library(inference SHARED
${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc
${LOAD_ONNX_SRC}
)
target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
if (ENABLE_CPU)
target_link_libraries(inference PRIVATE mindspore::dnnl mindspore::mkldnn)
endif ()
if (USE_GLOG)
target_link_libraries(inference PRIVATE mindspore::glog)
else()
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
target_link_options(inference PRIVATE -Wl,-init,mindspore_log_init)
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
endif ()
endif()

View File

@ -14,11 +14,9 @@
* limitations under the License.
*/
#include "common/trans.h"
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include "./securec.h"
#include "common/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel.h"
@ -29,34 +27,7 @@
namespace mindspore {
namespace trans {
namespace {
std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
std::vector<size_t> shape_4d(4, 1);
switch (shape.size()) {
case 0:
return shape_4d;
case 1:
shape_4d[1] = shape[0];
break;
case 2:
shape_4d[1] = shape[0];
shape_4d[2] = shape[1];
break;
case 3:
shape_4d[1] = shape[0];
shape_4d[2] = shape[1];
shape_4d[3] = shape[2];
break;
case 4:
std::copy(shape.begin(), shape.end(), shape_4d.begin());
break;
default:
MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
}
return shape_4d;
}
} // namespace
const size_t kNchwDims = 4;
enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc };
const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1},
{kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
{kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2},
@ -84,7 +55,10 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx,
template <typename T>
T DivCeil(T n1, T n2) {
return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0;
if (n2 != 0) {
return (n1 - 1) / n2 + 1;
}
return 0;
}
enum DataTypeTransMode {
@ -226,8 +200,7 @@ size_t CubeSizeByType(const TypeId data_type) {
}
size_t ShapeSize(const std::vector<size_t> &shape) {
size_t product = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
return product;
return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies<size_t>());
}
size_t TypeIdSize(const TypeId data_type) {
@ -239,27 +212,177 @@ size_t TypeIdSize(const TypeId data_type) {
return unsupported_type_error;
}
namespace {
bool CheckDims(const std::vector<size_t> &shape) {
if (shape.size() != kNchwDims) {
MS_LOG(ERROR) << "Host shape dims shoud be 4";
return false;
}
return true;
}
std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
return shape;
}
std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Ccheck dims failed.";
}
std::vector<size_t> device_shape;
device_shape.push_back(shape[kN]);
device_shape.push_back(shape[kH]);
device_shape.push_back(shape[kW]);
device_shape.push_back(shape[kC]);
return device_shape;
}
std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
device_shape.push_back(shape[kH]);
device_shape.push_back(shape[kW]);
device_shape.push_back(shape[kC]);
device_shape.push_back(shape[kN]);
return device_shape;
}
std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
device_shape.push_back(cout16 / kCubeSize);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
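A worked example of FracZDeviceShape, assuming kCubeSize = 16 (the usual cube size): an NCHW weight of shape (32, 3, 5, 5) rounds C up to 16 and N up to 32, giving a device shape of (25, 2, 16, 16). A small sketch:

CUBE = 16  # assumed kCubeSize

def frac_z_shape(n, c, h, w):
    # Mirrors FracZDeviceShape: round N and C up to multiples of the cube size.
    cout16 = ((n + CUBE - 1) // CUBE) * CUBE
    cin16 = ((c + CUBE - 1) // CUBE) * CUBE
    return [h * w * cin16 // CUBE, cout16 // CUBE, CUBE, CUBE]

print(frac_z_shape(32, 3, 5, 5))  # [25, 2, 16, 16]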
std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize;
const size_t C0 = kCubeSize;
device_shape.push_back(shape[kN]);
device_shape.push_back(C1);
device_shape.push_back(shape[kH]);
device_shape.push_back(shape[kW]);
device_shape.push_back(C0);
return device_shape;
}
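Likewise for Nc1hwc0DeviceShape (same kCubeSize = 16 assumption): channels are split into C1 groups of C0 = 16, so an NCHW input of (32, 3, 224, 224) maps to (32, 1, 224, 224, 16):

CUBE = 16  # assumed kCubeSize

def nc1hwc0_shape(n, c, h, w):
    # Mirrors Nc1hwc0DeviceShape: C1 = ceil(C / C0), C0 = cube size.
    c1 = (c + CUBE - 1) // CUBE
    return [n, c1, h, w, CUBE]

print(nc1hwc0_shape(32, 3, 224, 224))  # [32, 1, 224, 224, 16]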
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
device_shape.push_back(shape[kH]);
device_shape.push_back(shape[kW]);
device_shape.push_back(shape[kN]);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
const size_t c0 = 4;
auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize);
auto no = DivCeil(shape.at(kN), kCubeSize);
device_shape.push_back(first_dim);
device_shape.push_back(no);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
const size_t C1 = 1;
const size_t C0 = 4;
device_shape.push_back(shape[kN]);
device_shape.push_back(C1);
device_shape.push_back(shape[kH]);
device_shape.push_back(shape[kW]);
device_shape.push_back(C0);
return device_shape;
}
std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
if (shape.size() < kNdhwc) {
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
}
return shape;
}
std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
std::vector<size_t> shape_4d(kNchwDims, 1);
switch (shape.size()) {
case 0:
return shape_4d;
case 1:
shape_4d[kC] = shape[kN];
break;
case 2:
shape_4d[kC] = shape[kN];
shape_4d[kH] = shape[kC];
break;
case 3:
shape_4d[kC] = shape[kN];
shape_4d[kH] = shape[kC];
shape_4d[kW] = shape[kH];
break;
case 4:
std::copy(shape.begin(), shape.end(), shape_4d.begin());
break;
default:
MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
}
return shape_4d;
}
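The default 4-D padding now indexes with the kAxis enum, placing lower-rank shapes at the C, H, W positions while N stays 1. A hedged sketch of the resulting behaviour:

def pad_to_4d(shape):
    # Mirrors PaddingShapeTo4dByDefault: up to three dims go to the C, H, W
    # slots (N stays 1); a 4-dim shape is returned unchanged.
    if len(shape) == 4:
        return list(shape)
    if len(shape) > 4:
        raise ValueError("Unexpected shape size", len(shape))
    out = [1, 1, 1, 1]
    out[1:1 + len(shape)] = shape
    return out

print(pad_to_4d([5]))        # [1, 5, 1, 1]
print(pad_to_4d([5, 7]))     # [1, 5, 7, 1]
print(pad_to_4d([5, 7, 9]))  # [1, 5, 7, 9]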
} // namespace
bool IsNeedPadding(const std::string &format, const size_t shape_size) {
if (shape_size == 0) {
return false;
}
if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
return false;
} else if (shape_size < 4) {
} else if (shape_size < kNchwDims) {
return true;
}
return false;
}
std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
std::vector<int> shape;
std::vector<size_t> host_shape;
if (node->isa<ValueNode>()) {
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto node_value = value_node->value();
MS_EXCEPTION_IF_NULL(node_value);
auto tensor = node_value->cast<tensor::TensorPtr>();
if (tensor == nullptr) {
MS_LOG(EXCEPTION) << " the node[ " << node->DebugString() << "]'s cannot convert ";
MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
}
auto shape_temp = tensor->shape();
(void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), IntToSize);
@ -280,134 +403,13 @@ std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std
if (padding_axis.empty() || shape.size() != padding_axis.size()) {
return PaddingShapeTo4dByDefault(shape);
}
std::vector<size_t> shape_4d(4, 1);
std::vector<size_t> shape_4d(kNchwDims, 1);
for (size_t index = 0; index < padding_axis.size(); index++) {
shape_4d[padding_axis[index]] = shape[index];
}
return shape_4d;
}
namespace {
bool CheckDims(const std::vector<size_t> &shape) {
if (shape.size() != 4) {
MS_LOG(ERROR) << "Host shape dims shoud be 4";
return false;
}
return true;
}
std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
return shape;
}
std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Ccheck dims failed.";
}
std::vector<size_t> device_shape;
device_shape.push_back(shape[0]);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[1]);
return device_shape;
}
std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[1]);
device_shape.push_back(shape[0]);
return device_shape;
}
std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
device_shape.push_back(cout16 / kCubeSize);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
size_t C0 = kCubeSize;
device_shape.push_back(shape[0]);
device_shape.push_back(C1);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(C0);
return device_shape;
}
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
device_shape.push_back((shape[1] - 1) / kCubeSize + 1);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(shape[0]);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t c0 = 4;
auto first_dim = DivCeil(c0 * shape.at(2) * shape.at(3), kCubeSize);
auto no = DivCeil(shape.at(0), kCubeSize);
device_shape.push_back(first_dim);
device_shape.push_back(no);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t C1 = 1;
size_t C0 = 4;
device_shape.push_back(shape[0]);
device_shape.push_back(C1);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(C0);
return device_shape;
}
std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
if (shape.size() < 5) {
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
}
return shape;
}
} // namespace
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
@ -426,6 +428,10 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
auto temp_shape = shape;
std::vector<size_t> device_shape;
if (format == kOpFormat_FRAC_NZ) {
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
// For shapes like [1] and [1024] we can treat them as NZ shapes directly
return shape;
}
if (shape.size() < 2) {
MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
} else {
@ -439,7 +445,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
device_shape.push_back(kCubeSize);
return device_shape;
}
if (shape.size() != 4) {
if (shape.size() != kNchwDims) {
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
temp_shape = PaddingShapeTo4dByDefault(shape);
}
@ -455,6 +461,8 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
MS_EXCEPTION_IF_NULL(size);
MS_EXCEPTION_IF_NULL(total_size);
*size = TypeIdSize(args.src_data_type);
if (*size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
@ -540,10 +548,10 @@ bool NchwTo4D(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
size_t n = args.host_shape[0];
size_t c = args.host_shape[1];
size_t h = args.host_shape[2];
size_t w = args.host_shape[3];
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
for (size_t ni = 0; ni < n; ni++) {
for (size_t ci = 0; ci < c; ci++) {
for (size_t hi = 0; hi < h; hi++) {
@ -572,10 +580,10 @@ bool ToNchw(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
size_t n = args.host_shape[0];
size_t c = args.host_shape[1];
size_t h = args.host_shape[2];
size_t w = args.host_shape[3];
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
for (size_t ni = 0; ni < n; ni++) {
for (size_t ci = 0; ci < c; ci++) {
for (size_t hi = 0; hi < h; hi++) {
@ -602,32 +610,32 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
size_t size = TypeIdSize(args.src_data_type);
auto size = TypeIdSize(args.src_data_type);
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
size_t c0 = CubeSizeByType(args.src_data_type);
auto c0 = CubeSizeByType(args.src_data_type);
if (c0 < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t c1 = DivCeil(c, c0);
size_t hw = h * w;
size_t chw = c * hw;
size_t hwc0 = hw * c0;
size_t nchw = n * chw;
auto c1 = DivCeil(c, c0);
auto hw = h * w;
auto chw = c * hw;
auto hwc0 = hw * c0;
auto nchw = n * chw;
size_t hf_cnt = DivCeil(n, kCubeSize);
size_t vf_cnt = c1 * hw;
size_t fractal_ele_cnt = c0 * kCubeSize;
size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
size_t dst_size = total_ele_cnt * size;
auto hf_cnt = DivCeil(n, kCubeSize);
auto vf_cnt = c1 * hw;
auto fractal_ele_cnt = c0 * kCubeSize;
auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
auto dst_size = total_ele_cnt * size;
if (dst_size != args.device_size) {
MS_LOG(ERROR) << "Illegal total data size."
<< "dst size is :" << dst_size << "device size is :" << args.device_size;
@ -647,7 +655,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
auto src_ni = hfi * kCubeSize + col;
auto src_idx = src_row_offset + chw * col;
auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? true : false;
auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
SetData(size, pad_zero, src_idx, dst_idx, args, result);
}
}
@ -663,12 +671,12 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
size_t size = TypeIdSize(args.src_data_type);
auto size = TypeIdSize(args.src_data_type);
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t total_size = ShapeSize(args.device_shape) * size;
auto total_size = ShapeSize(args.device_shape) * size;
if (total_size != args.device_size) {
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
return false;
@ -677,18 +685,16 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
auto n0 = args.device_shape.at(1);
auto ni = args.device_shape.at(2);
auto c0 = args.device_shape.at(3);
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
size_t nc = ni * n0;
size_t ncc0 = nc * c0;
size_t wncc0 = w * ncc0;
size_t hwncc0 = h * wncc0;
size_t hw = h * w;
size_t chw = c * hw;
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
auto nc = ni * n0;
auto ncc0 = nc * c0;
auto wncc0 = w * ncc0;
auto hwncc0 = h * wncc0;
auto hw = h * w;
auto chw = c * hw;
for (size_t n_idx = 0; n_idx < n; n_idx++) {
size_t n_head_addr = n_idx * chw;
@ -720,20 +726,18 @@ bool NchwToFracZc04(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
size_t cube = kCubeSize;
size_t n = args.host_shape[0];
size_t c = args.host_shape[1];
size_t h = args.host_shape[2];
size_t w = args.host_shape[3];
size_t c0 = 4;
size_t c1 = DivCeil(c, c0);
size_t hwc0 = h * w * c0;
size_t hwc = h * w * c;
size_t nhwc = n * h * w * c;
size_t n_cnt = DivCeil(n, cube);
size_t v_cnt = DivCeil(h * w * c0 * c1, cube);
auto cube = kCubeSize;
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
const size_t c0 = 4;
auto c1 = DivCeil(c, c0);
auto hwc0 = h * w * c0;
auto hwc = h * w * c;
auto nhwc = n * h * w * c;
auto n_cnt = DivCeil(n, cube);
auto v_cnt = DivCeil(h * w * c0 * c1, cube);
size_t dst_idx = 0;
for (size_t vi = 0; vi < v_cnt; vi++) {
@ -929,7 +933,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
size_t size = TypeIdSize(args.src_data_type);
auto size = TypeIdSize(args.src_data_type);
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
@ -940,20 +944,23 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
size_t c0 = CubeSizeByType(args.src_data_type);
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
auto c0 = CubeSizeByType(args.src_data_type);
if (c0 < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t c1 = DivCeil(c, c0);
size_t hw = h * w;
size_t chw = c * hw;
size_t c1hwc0 = c1 * hw * c0;
size_t wc0 = w * c0;
if (args.device_format == kOpFormat_NC1HWC0_C04) {
c0 = 4;
}
auto c1 = DivCeil(c, c0);
auto hw = h * w;
auto chw = c * hw;
auto c1hwc0 = c1 * hw * c0;
auto wc0 = w * c0;
for (size_t n_idx = 0; n_idx < n; n_idx++) {
size_t n_head_addr = n_idx * c1hwc0;
@ -967,7 +974,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
size_t dst_idx = c0_idx + w_head_addr;
size_t c_idx = c0_idx + c1_idx * c0;
size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
auto pad_zero = (c_idx < c) ? false : true;
auto pad_zero = c_idx >= c;
SetData(size, pad_zero, src_idx, dst_idx, args, result);
}
}
@ -984,29 +991,29 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
return false;
}
size_t size = TypeIdSize(args.src_data_type);
auto size = TypeIdSize(args.src_data_type);
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t total_size = ShapeSize(args.device_shape) * size;
auto total_size = ShapeSize(args.device_shape) * size;
if (total_size != args.device_size) {
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
auto c1 = args.device_shape[1];
auto c0 = args.device_shape[4];
size_t hw = h * w;
size_t chw = c * hw;
size_t wc0 = w * c0;
size_t hwc0 = h * wc0;
size_t c1hwc0 = c1 * hwc0;
auto hw = h * w;
auto chw = c * hw;
auto wc0 = w * c0;
auto hwc0 = h * wc0;
auto c1hwc0 = c1 * hwc0;
for (size_t n_idx = 0; n_idx < n; n_idx++) {
size_t n_head_addr = n_idx * chw;
@ -1037,13 +1044,15 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
const int co_idx = 4;
const int c0_idx = 5;
auto c1 = args.device_shape[0];
auto co = args.device_shape[4];
auto c0 = args.device_shape[5];
auto co = args.device_shape[co_idx];
auto c0 = args.device_shape[c0_idx];
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
for (size_t h_i = 0; h_i < h; h_i++) {
@ -1055,7 +1064,7 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
co_i * c0 + c0_i;
size_t c_i = c0_i + c1_i * c0;
size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
auto pad_zero = (c_i < c && c0_i == co_i) ? false : true;
auto pad_zero = !(c_i < c && c0_i == co_i);
SetData(size, pad_zero, src_idx, dst_idx, args, result);
}
}
@ -1076,12 +1085,14 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Check args failed.";
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto h = args.host_shape[2];
auto w = args.host_shape[3];
auto co = args.device_shape[4];
auto c0 = args.device_shape[5];
auto n = args.host_shape[kN];
auto c = args.host_shape[kC];
auto h = args.host_shape[kH];
auto w = args.host_shape[kW];
const int co_idx = 4;
const int c0_idx = 5;
auto co = args.device_shape[co_idx];
auto c0 = args.device_shape[c0_idx];
for (size_t n_i = 0; n_i < n; n_i++) {
for (size_t c_i = 0; c_i < c; c_i++) {
for (size_t h_i = 0; h_i < h; h_i++) {

View File

@ -62,6 +62,7 @@ add_dependencies(engine-datasetops-source core)
add_dependencies(engine-datasetops-source-sampler core)
add_dependencies(engine-datasetops core)
add_dependencies(engine-opt core)
add_dependencies(engine-perf core)
add_dependencies(engine-gnn core)
add_dependencies(engine core)
add_dependencies(text core)
@ -81,6 +82,7 @@ set(submodules
$<TARGET_OBJECTS:engine-datasetops-source>
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
$<TARGET_OBJECTS:engine-gnn>
$<TARGET_OBJECTS:engine-perf>
$<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine>
@ -106,10 +108,11 @@ target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar)
if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY})
else()
set(ICU_LIB mindspore::icuuc mindspore::icudata mindspore::icui18n)
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY})
endif()
target_link_libraries(_c_dataengine PUBLIC mindspore::jpeg_turbo mindspore::opencv_core mindspore::opencv_imgcodecs
mindspore::opencv_imgproc mindspore::tinyxml2)
mindspore::opencv_imgproc mindspore::tinyxml2 ${ICU_LIB})
if (ENABLE_GPUQUE)
target_link_libraries(_c_dataengine PRIVATE gpu_queue
${CUDNN_PATH}/lib64/libcudnn.so

View File

@ -19,56 +19,64 @@
#include <map>
#include "common/utils.h"
#include "dataset/kernels/py_func_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/bucket_batch_by_length_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_shuffle.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/kernels/py_func_op.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"
#include "utils/log_adapter.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_distributed_sample.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_shuffle.h"
#include "pybind11/stl.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *);
static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &DEPipeline::ParseStorageOp},
{kShuffle, &DEPipeline::ParseShuffleOp},
{kMindrecord, &DEPipeline::ParseMindRecordOp},
{kMap, &DEPipeline::ParseMapOp},
{kFilter, &DEPipeline::ParseFilterOp},
{kBatch, &DEPipeline::ParseBatchOp},
{kBarrier, &DEPipeline::ParseBarrierOp},
{kRepeat, &DEPipeline::ParseRepeatOp},
{kSkip, &DEPipeline::ParseSkipOp},
{kZip, &DEPipeline::ParseZipOp},
{kConcat, &DEPipeline::ParseConcatOp},
{kRename, &DEPipeline::ParseRenameOp},
{kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
{kGenerator, &DEPipeline::ParseGeneratorOp},
{kTfReader, &DEPipeline::ParseTFReaderOp},
{kProject, &DEPipeline::ParseProjectOp},
{kTake, &DEPipeline::ParseTakeOp},
{kImageFolder, &DEPipeline::ParseImageFolderOp},
{kMnist, &DEPipeline::ParseMnistOp},
{kManifest, &DEPipeline::ParseManifestOp},
{kVoc, &DEPipeline::ParseVOCOp},
{kCifar10, &DEPipeline::ParseCifar10Op},
{kCifar100, &DEPipeline::ParseCifar100Op},
{kCelebA, &DEPipeline::ParseCelebAOp},
{kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp}};
static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
{kShuffle, &DEPipeline::ParseShuffleOp},
{kMindrecord, &DEPipeline::ParseMindRecordOp},
{kMap, &DEPipeline::ParseMapOp},
{kFilter, &DEPipeline::ParseFilterOp},
{kBatch, &DEPipeline::ParseBatchOp},
{kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp},
{kBarrier, &DEPipeline::ParseBarrierOp},
{kRepeat, &DEPipeline::ParseRepeatOp},
{kSkip, &DEPipeline::ParseSkipOp},
{kZip, &DEPipeline::ParseZipOp},
{kConcat, &DEPipeline::ParseConcatOp},
{kRename, &DEPipeline::ParseRenameOp},
{kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
{kGenerator, &DEPipeline::ParseGeneratorOp},
{kTfReader, &DEPipeline::ParseTFReaderOp},
{kProject, &DEPipeline::ParseProjectOp},
{kTake, &DEPipeline::ParseTakeOp},
{kImageFolder, &DEPipeline::ParseImageFolderOp},
{kMnist, &DEPipeline::ParseMnistOp},
{kManifest, &DEPipeline::ParseManifestOp},
{kVoc, &DEPipeline::ParseVOCOp},
{kCoco, &DEPipeline::ParseCocoOp},
{kCifar10, &DEPipeline::ParseCifar10Op},
{kCifar100, &DEPipeline::ParseCifar100Op},
{kCelebA, &DEPipeline::ParseCelebAOp},
{kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp},
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) {
try {
@ -292,70 +300,6 @@ Status DEPipeline::SetBatchParameters(const py::dict &args) {
return Status::OK();
}
Status DEPipeline::ValidateArgStorageOp(const py::dict &args) {
// Required arguments
if (((args.contains("dataset_files") && args["dataset_files"].is_none()) || args["schema"].is_none()) &&
((args.contains("dataset_dir") && args["dataset_dir"].is_none()) ||
(args["schema"].is_none() && args["schema_json_string"].is_none()))) {
std::string err_msg = "Error: at least one of dataset_files or schema_file is missing";
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();
}
Status DEPipeline::ParseStorageOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
RETURN_IF_NOT_OK(ValidateArgStorageOp(args));
std::shared_ptr<StorageOp::Builder> builder;
if (args.contains("dataset_files") && !args["dataset_files"].is_none()) {
builder = std::make_shared<StorageOp::Builder>();
(void)builder->SetDatasetFileList(ToStringVector(args["dataset_files"]));
(void)builder->SetSchemaFile(ToString(args["schema"]));
} else if (args.contains("dataset_dir") && !args["dataset_dir"].is_none()) {
builder = std::make_shared<StorageOp::Builder>();
(void)builder->SetDatasetFilesDir(ToString(args["dataset_dir"]));
if (!args["schema"].is_none()) {
(void)builder->SetSchemaFile(ToString(args["schema"]));
} else if (!args["schema_json_string"].is_none()) {
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
std::string s = ToString(args["schema_json_string"]);
RETURN_IF_NOT_OK(schema->LoadSchemaString(s, std::vector<std::string>()));
(void)builder->SetNumRows(schema->num_rows());
(void)builder->SetSchema(std::move(schema));
}
}
// Optional arguments
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "prefetch_size") {
(void)builder->SetOpConnectorSize(ToInt(value));
} else if (key == "columns_list") {
(void)builder->SetColumnsToLoad(ToStringVector(value));
} else if (key == "distribution") {
(void)builder->SetDataDistributionFile(ToString(value));
} else if (key == "labels_filename") {
(void)builder->setLabelsFileName(ToString(value));
} else if (key == "dataset_usage") {
(void)builder->SetDatasetUsage(ToString(value));
}
}
}
(void)builder->SetBatchSize(temp_batch_size_);
(void)builder->SetDropRemainder(temp_drop_remainder_);
std::shared_ptr<StorageOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
num_rows_ = op->num_rows();
num_classes_ = op->num_classes();
*ptr = op;
return Status::OK();
}
Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
if (!args["buffer_size"].is_none()) {
@ -382,35 +326,27 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
return Status::OK();
}
Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *in_partitions) {
if (args["partitions"].is_none()) {
std::string err_msg = "Error: partitions is not set (None)";
RETURN_STATUS_UNEXPECTED(err_msg);
}
py::list list = py::reinterpret_borrow<py::list>(args["partitions"]);
for (auto l : list) {
if (!l.is_none()) {
in_partitions->push_back(ToInt(l));
Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle,
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
int num_padded) {
auto sampler = py::reinterpret_borrow<py::object>(handle);
auto create = sampler.attr("create_for_minddataset");
auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
while (op != nullptr) {
auto sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
if (sampler_op && num_padded > 0) {
sampler_op->SetNumPaddedSamples(num_padded);
stack_ops.push(sampler_op);
} else {
stack_ops.push(op);
}
op = op->GetChildOp();
}
if (in_partitions->size() != 2) {
std::string err_msg = "Error: partitions is invalid or not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
while (!stack_ops.empty()) {
operators->push_back(stack_ops.top());
stack_ops.pop();
}
constexpr int kMaxPartitions = 64;
if (in_partitions->at(0) <= 0 || in_partitions->at(0) > kMaxPartitions) {
std::string err_msg = "Error: partitions is invalid or not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (in_partitions->at(1) < 0 || in_partitions->at(1) >= in_partitions->at(0)) {
std::string err_msg = "Error: partitions is invalid or not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
return Status::OK();
}
@ -438,6 +374,10 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
(void)builder->SetColumnsToLoad(in_col_names);
}
if (!args["padded_sample"].is_none()) {
(void)builder->SetPaddedSample(args["padded_sample"]);
(void)builder->SetNumToPadSamples(ToInt(args["num_padded"]));
}
std::vector<std::shared_ptr<mindrecord::ShardOperator>> operators;
for (auto arg : args) {
std::string key = py::str(arg.first);
@ -447,27 +387,16 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
(void)builder->SetNumMindRecordWorkers(ToInt(value));
} else if (key == "block_reader" && ToBool(value) == true) {
(void)builder->SetBlockReader();
} else if (key == "global_shuffle" && ToBool(value) == true) {
uint32_t seed = args["partitions"].is_none() ? GetSeed() : 0;
operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("_create_for_minddataset");
std::shared_ptr<mindrecord::ShardOperator> sample_op =
create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
operators.push_back(sample_op);
int num_padded = 0;
if (!args["num_padded"].is_none()) {
num_padded = ToInt(args["num_padded"]);
}
RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded));
}
}
}
std::vector<int> in_partitions;
if (!args["partitions"].is_none()) {
auto ret = CheckMindRecordPartitionInfo(args, &in_partitions);
if (Status::OK() != ret) {
return ret;
}
operators.push_back(std::make_shared<mindrecord::ShardSample>(1, in_partitions[0], in_partitions[1]));
}
if (!operators.empty()) {
(void)builder->SetOperators(operators);
}
@ -493,6 +422,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
(void)builder->SetInColNames(in_col_names);
} else if (key == "output_columns") {
(void)builder->SetOutColNames(ToStringVector(value));
} else if (key == "columns_order") {
(void)builder->SetColOrder(ToStringVector(value));
} else if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "prefetch_size") {
@ -642,18 +573,8 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
(void)builder->SetColumnsToMap(ToStringVector(value));
}
if (key == "pad_info") {
std::map<std::string, std::pair<TensorShape, float>> pad_info;
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
if (!p.second.is_none()) {
py::tuple tp = py::reinterpret_borrow<py::tuple>(p.second);
CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)");
TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]);
float pad_val = tp[1].is_none() ? 0 : ToFloat(tp[1]);
(void)pad_info.insert({ToString(p.first), {shape, pad_val}});
} else { // tuple is None
(void)pad_info.insert({ToString(p.first), {TensorShape({}), 0}});
}
}
PadInfo pad_info;
RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info));
(void)builder->SetPaddingMap(pad_info, true);
}
}
@ -665,6 +586,56 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
return Status::OK();
}
Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::vector<std::string> mandatory_arguments = {"length_dependent_columns", "bucket_boundaries",
"bucket_batch_sizes"};
for (auto name : mandatory_arguments) {
if (args[name.c_str()].is_none()) {
std::string err_msg = "Error: " + name + " is not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
std::shared_ptr<BucketBatchByLengthOp::Builder> builder = std::make_shared<BucketBatchByLengthOp::Builder>(
ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]),
ToIntVector(args[mandatory_arguments[2].c_str()]));
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "length_dependent_columns") {
(void)builder->SetLengthDependentColumns(ToStringVector(value));
}
if (key == "bucket_boundaries") {
(void)builder->SetBucketBoundaries(ToIntVector(value));
}
if (key == "bucket_batch_sizes") {
(void)builder->SetBucketBatchSizes(ToIntVector(value));
}
if (key == "element_length_function") {
(void)builder->SetElementLengthFunction(value.cast<py::function>());
}
if (key == "pad_info") {
PadInfo pad_info;
RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info));
(void)builder->SetPadInfo(pad_info);
}
if (key == "pad_to_bucket_boundary") {
(void)builder->SetPadToBucketBoundary(ToBool(value));
}
if (key == "drop_remainder") {
(void)builder->SetDropRemainder(ToBool(value));
}
}
}
std::shared_ptr<BucketBatchByLengthOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
return Status::OK();
}
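For reference, a hypothetical Python-side argument dict that this parser would accept; the keys are exactly the branches checked above, the values are illustrative only:

# Illustrative only: keys match the branches ParseBucketBatchByLengthOp handles.
bucket_batch_args = {
    "length_dependent_columns": ["text"],
    "bucket_boundaries": [10, 20, 40],       # boundaries between buckets
    "bucket_batch_sizes": [64, 32, 16, 8],   # one batch size per bucket (len(boundaries) + 1)
    "element_length_function": lambda col: len(col),
    "pad_info": {"text": ([None], 0)},       # pad "text" along its last dim, pad value 0
    "pad_to_bucket_boundary": False,
    "drop_remainder": False,
}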
Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>();
// Right now barrier should only take num_rows_per_buffer = 1
@ -801,6 +772,8 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetColumnsToLoad(columns_to_load);
} else if (key == "shuffle_files") {
(void)builder->SetShuffleFiles(ToBool(value));
} else if (key == "shuffle_global") {
(void)builder->SetShuffleGlobal(ToBool(value));
} else if (key == "schema_file_path" || key == "schema_json_string") {
schema_exists = true;
} else if (key == "num_samples") {
@ -856,9 +829,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -893,9 +864,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -922,6 +891,16 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (args["task"].is_none()) {
std::string err_msg = "Error: No task specified";
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (args["mode"].is_none()) {
std::string err_msg = "Error: No mode specified";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
(void)builder->SetDir(ToString(args["dataset_dir"]));
(void)builder->SetTask(ToString(args["task"]));
@ -930,9 +909,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -951,6 +928,47 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
return Status::OK();
}
Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
if (args["dataset_dir"].is_none()) {
std::string err_msg = "Error: No dataset path specified";
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (args["annotation_file"].is_none()) {
std::string err_msg = "Error: No annotation_file specified";
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (args["task"].is_none()) {
std::string err_msg = "Error: No task specified";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::shared_ptr<CocoOp::Builder> builder = std::make_shared<CocoOp::Builder>();
(void)builder->SetDir(ToString(args["dataset_dir"]));
(void)builder->SetFile(ToString(args["annotation_file"]));
(void)builder->SetTask(ToString(args["task"]));
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
(void)builder->SetSampler(std::move(sampler));
} else if (key == "decode") {
(void)builder->SetDecode(ToBool(value));
}
}
}
std::shared_ptr<CocoOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
return Status::OK();
}
Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
// Required arguments
if (args["dataset_dir"].is_none()) {
@ -966,9 +984,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -1001,9 +1017,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -1039,10 +1053,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
(void)builder.SetNumWorkers(ToInt(value));
} else if (key == "schema_file_path" || key == "schema_json_string") {
schema_exists = true;
} else if (key == "num_samples") {
(void)builder.SetTotalRows(ToInt(value));
} else if (key == "columns_list") {
columns_to_load = ToStringVector(value);
} else if (key == "num_samples") {
// This is not sampling here. The random data op needs to know how much data to
// generate. It does not currently support sampling.
(void)builder.SetTotalRows(ToInt(value));
}
}
if (schema_exists) {
@ -1077,9 +1093,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
@ -1121,8 +1135,6 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
(void)builder->SetDecode(ToBool(value));
} else if (key == "extensions") {
(void)builder->SetExtensions(ToStringSet(value));
} else if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "dataset_type") {
(void)builder->SetDatasetType(ToString(value));
}
@ -1152,8 +1164,10 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "shuffle_files") {
(void)builder->SetShuffleFiles(ToBool(value));
} else if (key == "shuffle_global") {
(void)builder->SetShuffleGlobal(ToBool(value));
} else if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
(void)builder->SetTotalRows(ToInt(value));
} else if (key == "num_shards") {
(void)builder->SetNumDevices(ToInt(value));
} else if (key == "shard_id") {
@ -1166,5 +1180,106 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
*ptr = op;
return Status::OK();
}
Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
if (!p.second.is_none()) {
auto tp = py::reinterpret_borrow<py::tuple>(p.second);
CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list, int), (list, float) or (list, str)");
TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]);
std::shared_ptr<Tensor> pad_val = nullptr;
if (py::isinstance<py::str>(tp[1])) {
std::string pad_val_string = tp[1].is_none() ? "" : ToString(tp[1]);
CHECK_FAIL_RETURN_UNEXPECTED(
Tensor::CreateTensor(&pad_val, std::vector<std::string>{pad_val_string}, TensorShape::CreateScalar()),
"Cannot create pad_value Tensor");
} else {
float pad_val_float = tp[1].is_none() ? 0 : ToFloat(tp[1]);
CHECK_FAIL_RETURN_UNEXPECTED(Tensor::CreateTensor(&pad_val, TensorImpl::kFlexible, TensorShape::CreateScalar(),
DataType(DataType::DE_FLOAT32)),
"Cannot create pad_value Tensor");
pad_val->SetItemAt<float>({}, pad_val_float);
}
(void)pad_info->insert({ToString(p.first), {shape, pad_val}});
} else { // tuple is None
(void)pad_info->insert({ToString(p.first), {TensorShape({}), nullptr}});
}
}
return Status::OK();
}
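The pad_info mapping ParsePadInfo consumes can be sketched as follows (column names are illustrative only; string pad values become string tensors, numbers become float32 scalars, and a None entry falls back to the defaults):

# Illustrative only: per-column (shape, pad_value) pairs as read by ParsePadInfo.
pad_info = {
    "image": ([224, 224, 3], 0.0),  # pad to a fixed shape with 0.0
    "label": (None, -1),            # unknown rank, numeric pad value
    "text": ([None], "<pad>"),      # unknown length in the last dim, string pad value
    "mask": None,                   # None: default shape and pad value
}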
Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<BuildVocabOp::Builder> builder = std::make_shared<BuildVocabOp::Builder>();
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "freq_range") {
py::tuple tp = py::reinterpret_borrow<py::tuple>(value);
if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow<py::int_>(tp[0]));
if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow<py::int_>(tp[1]));
} else if (key == "top_k") {
builder->SetTopK(py::reinterpret_borrow<py::int_>(value));
} else if (key == "columns") {
(void)builder->SetColumnNames(ToStringVector(value));
} else if (key == "vocab") {
(void)builder->SetVocab(value.cast<std::shared_ptr<Vocab>>());
} else if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "special_first") {
(void)builder->SetSpecialFirst(ToBool(value));
} else if (key == "special_tokens") {
(void)builder->SetSpecialTokens(ToStringVector(value));
}
}
}
std::shared_ptr<BuildVocabOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
return Status::OK();
}
Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
if (!args["dataset_files"].is_none()) {
(void)builder->SetClueFilesList(ToStringVector(args["dataset_files"]));
} else {
RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
}
// Optional arguments
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "shuffle_files") {
(void)builder->SetShuffleFiles(ToBool(value));
} else if (key == "shuffle_global") {
(void)builder->SetShuffleGlobal(ToBool(value));
} else if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_shards") {
(void)builder->SetNumDevices(ToInt(value));
} else if (key == "shard_id") {
(void)builder->SetDeviceId(ToInt(value));
} else if (key == "cols_to_keyword") {
std::map<std::string, std::string> map_dict;
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
if (!p.second.is_none()) {
map_dict.insert({ToString(p.first), ToString(p.second)});
} else {
map_dict.insert({ToString(p.first), ToString(p.first)});
}
}
(void)builder->SetColsKeyMap(map_dict);
}
}
}
std::shared_ptr<ClueOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -18,6 +18,7 @@
#include <iostream>
#include <memory>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
@ -36,10 +37,10 @@ using DsOpPtr = std::shared_ptr<DatasetOp>;
// enum for the dataset operator names
enum OpName {
kStorage = 0,
kShuffle,
kMindrecord,
kBatch,
kBucketBatch,
kBarrier,
kCache,
kRepeat,
@ -58,11 +59,14 @@ enum OpName {
kMnist,
kManifest,
kVoc,
kCoco,
kCifar10,
kCifar100,
kCelebA,
kRandomData,
kTextFile
kTextFile,
kBuildVocab,
kClue
};
// The C++ binder class that we expose to the python script.
@ -100,14 +104,14 @@ class DEPipeline {
int GetRepeatCount() const;
Status ParseStorageOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *ptr);
Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status BuildMindrecordSamplerChain(const py::handle &handle,
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
int num_padded);
Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
@ -118,6 +122,8 @@ class DEPipeline {
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
@ -142,6 +148,8 @@ class DEPipeline {
Status ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
@ -160,14 +168,17 @@ class DEPipeline {
Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
private:
// Execution tree that links the dataset operators.
std::shared_ptr<ExecutionTree> tree_;
std::unique_ptr<DatasetIterator> iterator_;
// Validate required args passed to storage op.
Status ValidateArgStorageOp(const py::dict &args);
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
int batch_size_;
int repeat_num_;

View File

@ -16,8 +16,37 @@
#include <exception>
#include "dataset/api/de_pipeline.h"
#include "dataset/kernels/no_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/gnn/graph.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/kernels/data/concatenate_op.h"
#include "dataset/kernels/data/duplicate_op.h"
#include "dataset/kernels/data/fill_op.h"
#include "dataset/kernels/data/mask_op.h"
#include "dataset/kernels/data/one_hot_op.h"
#include "dataset/kernels/data/pad_end_op.h"
#include "dataset/kernels/data/slice_op.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/image/bounding_box_augment_op.h"
#include "dataset/kernels/image/center_crop_op.h"
#include "dataset/kernels/image/cut_out_op.h"
#include "dataset/kernels/image/decode_op.h"
@ -26,51 +55,51 @@
#include "dataset/kernels/image/normalize_op.h"
#include "dataset/kernels/image/pad_op.h"
#include "dataset/kernels/image/random_color_adjust_op.h"
#include "dataset/kernels/image/random_crop_decode_resize_op.h"
#include "dataset/kernels/image/random_crop_and_resize_op.h"
#include "dataset/kernels/image/random_crop_and_resize_with_bbox_op.h"
#include "dataset/kernels/image/random_crop_decode_resize_op.h"
#include "dataset/kernels/image/random_crop_op.h"
#include "dataset/kernels/image/random_crop_with_bbox_op.h"
#include "dataset/kernels/image/random_horizontal_flip_bbox_op.h"
#include "dataset/kernels/image/random_horizontal_flip_op.h"
#include "dataset/kernels/image/random_resize_op.h"
#include "dataset/kernels/image/random_rotation_op.h"
#include "dataset/kernels/image/random_vertical_flip_op.h"
#include "dataset/kernels/image/random_vertical_flip_with_bbox_op.h"
#include "dataset/kernels/image/rescale_op.h"
#include "dataset/kernels/image/resize_bilinear_op.h"
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/gnn/graph.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/kernels/no_op.h"
#include "dataset/text/kernels/jieba_tokenizer_op.h"
#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
#include "dataset/text/vocab.h"
#include "dataset/text/kernels/lookup_op.h"
#include "dataset/text/kernels/ngram_op.h"
#include "dataset/text/kernels/to_number_op.h"
#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
#include "dataset/text/kernels/wordpiece_tokenizer_op.h"
#include "dataset/text/vocab.h"
#include "dataset/util/random.h"
#include "mindrecord/include/shard_distributed_sample.h"
#include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_pk_sample.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_sequential_sample.h"
#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
#ifdef ENABLE_ICU4C
#include "dataset/text/kernels/basic_tokenizer_op.h"
#include "dataset/text/kernels/bert_tokenizer_op.h"
#include "dataset/text/kernels/case_fold_op.h"
#include "dataset/text/kernels/normalize_utf8_op.h"
#include "dataset/text/kernels/regex_replace_op.h"
#include "dataset/text/kernels/regex_tokenizer_op.h"
#include "dataset/text/kernels/unicode_script_tokenizer_op.h"
#include "dataset/text/kernels/whitespace_tokenizer_op.h"
#endif
namespace py = pybind11;
namespace mindspore {
@ -143,51 +172,49 @@ void bindDatasetOps(py::module *m) {
});
(void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
.def_static("get_num_rows", [](const std::string &dir, int64_t numSamples, bool isCifar10) {
.def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
int64_t count = 0;
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, numSamples, isCifar10, &count));
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
return count;
});
(void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
.def_static("get_num_rows_and_classes", [](const std::string &path, int64_t numSamples) {
.def_static("get_num_rows_and_classes", [](const std::string &path) {
int64_t count = 0, num_classes = 0;
THROW_IF_ERROR(
ImageFolderOp::CountRowsAndClasses(path, numSamples, std::set<std::string>{}, &count, &num_classes));
THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
return py::make_tuple(count, num_classes);
});
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
.def_static("get_num_rows",
[](const std::vector<std::string> &paths, bool load_dataset, const py::object &sampler) {
int64_t count = 0;
std::shared_ptr<mindrecord::ShardOperator> op;
if (py::hasattr(sampler, "_create_for_minddataset")) {
auto create = sampler.attr("_create_for_minddataset");
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
}
THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count));
return count;
});
.def_static("get_num_rows", [](const std::vector<std::string> &paths, bool load_dataset, const py::object &sampler,
const int64_t num_padded) {
int64_t count = 0;
std::shared_ptr<mindrecord::ShardOperator> op;
if (py::hasattr(sampler, "create_for_minddataset")) {
auto create = sampler.attr("create_for_minddataset");
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
}
THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded));
return count;
});
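
For context, the binding above now threads a num_padded count through to the MindRecord row counter. A minimal C++ sketch of the same call, using only the CountTotalRows signature the binding exercises (header path as in this file; load_dataset semantics as passed from the Python layer):

#include <memory>
#include <string>
#include <vector>
#include "dataset/engine/datasetops/source/mindrecord_op.h"

// Sketch: count rows of a MindRecord dataset, including num_padded padded samples.
// No sampler chain is attached, so the ShardOperator pointer stays null.
mindspore::dataset::Status CountRows(const std::vector<std::string> &paths, bool load_dataset,
                                     int64_t num_padded, int64_t *count) {
  std::shared_ptr<mindspore::mindrecord::ShardOperator> op;  // optional sampler, left empty here
  return mindspore::dataset::MindRecordOp::CountTotalRows(paths, load_dataset, op, count, num_padded);
}
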
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
.def_static("get_num_rows_and_classes",
[](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
[](const std::string &file, const py::dict &dict, const std::string &usage) {
int64_t count = 0, num_classes = 0;
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, numSamples, dict, usage, &count, &num_classes));
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
return py::make_tuple(count, num_classes);
})
.def_static("get_class_indexing",
[](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, numSamples, dict, usage, &output_class_indexing));
return output_class_indexing;
});
.def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
return output_class_indexing;
});
(void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
.def_static("get_num_rows", [](const std::string &dir, int64_t numSamples) {
.def_static("get_num_rows", [](const std::string &dir) {
int64_t count = 0;
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
return count;
});
@ -201,20 +228,44 @@ void bindDatasetOps(py::module *m) {
THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
return count;
});
(void)py::class_<ClueOp, DatasetOp, std::shared_ptr<ClueOp>>(*m, "ClueOp")
.def_static("get_num_rows", [](const py::list &files) {
int64_t count = 0;
std::vector<std::string> filenames;
for (auto file : files) {
file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
}
THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count));
return count;
});
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
.def_static("get_num_rows",
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, int64_t numSamples) {
int64_t count = 0;
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
return count;
})
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
const std::string &task_mode, const py::dict &dict) {
std::map<std::string, int32_t> output_class_indexing;
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, numSamples, &output_class_indexing));
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
return output_class_indexing;
});
(void)py::class_<CocoOp, DatasetOp, std::shared_ptr<CocoOp>>(*m, "CocoOp")
.def_static("get_class_indexing",
[](const std::string &dir, const std::string &file, const std::string &task) {
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing));
return output_class_indexing;
})
.def_static("get_num_rows", [](const std::string &dir, const std::string &file, const std::string &task) {
int64_t count = 0;
THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count));
return count;
});
}
void bindTensor(py::module *m) {
(void)py::class_<GlobalContext>(*m, "GlobalContext")
@ -227,12 +278,14 @@ void bindTensor(py::module *m) {
.def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
.def("set_op_connector_size", &ConfigManager::set_op_connector_size)
.def("set_seed", &ConfigManager::set_seed)
.def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval)
.def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
.def("get_num_parallel_workers", &ConfigManager::num_parallel_workers)
.def("get_worker_connector_size", &ConfigManager::worker_connector_size)
.def("get_op_connector_size", &ConfigManager::op_connector_size)
.def("get_seed", &ConfigManager::seed)
.def("load", [](ConfigManager &c, std::string s) { (void)c.LoadFile(s); });
.def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
.def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
.def(py::init([](py::array arr) {
@ -300,6 +353,11 @@ void bindTensorOps1(py::module *m) {
.def(py::init<std::vector<std::shared_ptr<TensorOp>>, int32_t>(), py::arg("operations"),
py::arg("NumOps") = UniformAugOp::kDefNumOps);
(void)py::class_<BoundingBoxAugmentOp, TensorOp, std::shared_ptr<BoundingBoxAugmentOp>>(
*m, "BoundingBoxAugmentOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.")
.def(py::init<std::shared_ptr<TensorOp>, float>(), py::arg("transform"),
py::arg("ratio") = BoundingBoxAugmentOp::kDefRatio);
(void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
*m, "ResizeBilinearOp",
"Tensor operation to resize an image using "
@ -314,6 +372,11 @@ void bindTensorOps1(py::module *m) {
(void)py::class_<RandomHorizontalFlipOp, TensorOp, std::shared_ptr<RandomHorizontalFlipOp>>(
*m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.")
.def(py::init<float>(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability);
(void)py::class_<RandomHorizontalFlipWithBBoxOp, TensorOp, std::shared_ptr<RandomHorizontalFlipWithBBoxOp>>(
*m, "RandomHorizontalFlipWithBBoxOp",
"Tensor operation to randomly flip an image horizontally, while flipping bounding boxes.")
.def(py::init<float>(), py::arg("probability") = RandomHorizontalFlipWithBBoxOp::kDefProbability);
}
void bindTensorOps2(py::module *m) {
@ -321,6 +384,12 @@ void bindTensorOps2(py::module *m) {
*m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.")
.def(py::init<float>(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability);
(void)py::class_<RandomVerticalFlipWithBBoxOp, TensorOp, std::shared_ptr<RandomVerticalFlipWithBBoxOp>>(
*m, "RandomVerticalFlipWithBBoxOp",
"Tensor operation to randomly flip an image vertically"
" and adjust bounding boxes.")
.def(py::init<float>(), py::arg("probability") = RandomVerticalFlipWithBBoxOp::kDefProbability);
(void)py::class_<RandomCropOp, TensorOp, std::shared_ptr<RandomCropOp>>(*m, "RandomCropOp",
"Gives random crop of specified size "
"Takes crop size")
@ -332,10 +401,84 @@ void bindTensorOps2(py::module *m) {
py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB);
(void)py::class_<HwcToChwOp, TensorOp, std::shared_ptr<HwcToChwOp>>(*m, "ChannelSwapOp").def(py::init<>());
(void)py::class_<RandomCropWithBBoxOp, TensorOp, std::shared_ptr<RandomCropWithBBoxOp>>(*m, "RandomCropWithBBoxOp",
"Gives random crop of given "
"size + adjusts bboxes "
"Takes crop size")
.def(py::init<int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, BorderType, bool, uint8_t, uint8_t, uint8_t>(),
py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropWithBBoxOp::kDefPadTop,
py::arg("padBottom") = RandomCropWithBBoxOp::kDefPadBottom,
py::arg("padLeft") = RandomCropWithBBoxOp::kDefPadLeft,
py::arg("padRight") = RandomCropWithBBoxOp::kDefPadRight,
py::arg("borderType") = RandomCropWithBBoxOp::kDefBorderType,
py::arg("padIfNeeded") = RandomCropWithBBoxOp::kDefPadIfNeeded,
py::arg("fillR") = RandomCropWithBBoxOp::kDefFillR, py::arg("fillG") = RandomCropWithBBoxOp::kDefFillG,
py::arg("fillB") = RandomCropWithBBoxOp::kDefFillB);
(void)py::class_<OneHotOp, TensorOp, std::shared_ptr<OneHotOp>>(
*m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
.def(py::init<int32_t>());
(void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(
*m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.")
.def(py::init<std::shared_ptr<Tensor>>());
(void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "Tensor slice operation.")
.def(py::init<bool>())
.def(py::init([](const py::list &py_list) {
std::vector<dsize_t> c_list;
for (auto l : py_list) {
if (!l.is_none()) {
c_list.push_back(py::reinterpret_borrow<py::int_>(l));
}
}
return std::make_shared<SliceOp>(c_list);
}))
.def(py::init([](const py::tuple &py_slice) {
if (py_slice.size() != 3) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
}
Slice c_slice;
if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) {
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]), py::reinterpret_borrow<py::int_>(py_slice[1]),
py::reinterpret_borrow<py::int_>(py_slice[2]));
} else if (py_slice[0].is_none() && py_slice[2].is_none()) {
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[1]));
} else if (!py_slice[0].is_none() && !py_slice[1].is_none()) {
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]), py::reinterpret_borrow<py::int_>(py_slice[1]));
}
if (!c_slice.valid()) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
}
return std::make_shared<SliceOp>(c_slice);
}));
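
The lambda above translates a Python slice tuple into the C++ Slice helper, choosing between the (start, stop, step), (stop), and (start, stop) forms. A small sketch of the equivalent direct construction, restricted to the constructors the binding itself uses:

#include <memory>
#include <vector>
#include "dataset/kernels/data/slice_op.h"

void MakeSliceOps() {
  using mindspore::dataset::Slice;
  using mindspore::dataset::SliceOp;
  Slice full(0, 10, 2);              // start, stop, step
  Slice stop_only(5);                // stop only
  Slice start_stop(2, 8);            // start and stop
  auto by_slice = std::make_shared<SliceOp>(full);
  auto by_indices = std::make_shared<SliceOp>(std::vector<mindspore::dataset::dsize_t>{0, 2, 4});
  auto take_all = std::make_shared<SliceOp>(true);   // the bool form bound via py::init<bool>()
  (void)by_slice; (void)by_indices; (void)take_all;
}
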
(void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic())
.value("EQ", RelationalOp::kEqual)
.value("NE", RelationalOp::kNotEqual)
.value("LT", RelationalOp::kLess)
.value("LE", RelationalOp::kLessEqual)
.value("GT", RelationalOp::kGreater)
.value("GE", RelationalOp::kGreaterEqual)
.export_values();
(void)py::class_<MaskOp, TensorOp, std::shared_ptr<MaskOp>>(*m, "MaskOp",
"Tensor mask operation using relational comparator")
.def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>());
(void)py::class_<DuplicateOp, TensorOp, std::shared_ptr<DuplicateOp>>(*m, "DuplicateOp", "Duplicate tensor.")
.def(py::init<>());
(void)py::class_<TruncateSequencePairOp, TensorOp, std::shared_ptr<TruncateSequencePairOp>>(
*m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length")
.def(py::init<int64_t>());
(void)py::class_<ConcatenateOp, TensorOp, std::shared_ptr<ConcatenateOp>>(*m, "ConcatenateOp",
"Tensor operation concatenate tensors.")
.def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>(), py::arg("axis"),
py::arg("prepend").none(true), py::arg("append").none(true));
(void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
*m, "RandomRotationOp",
"Tensor operation to apply RandomRotation."
@ -347,6 +490,10 @@ void bindTensorOps2(py::module *m) {
py::arg("interpolation") = RandomRotationOp::kDefInterpolation,
py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR,
py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
(void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(
*m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.")
.def(py::init<TensorShape, std::shared_ptr<Tensor>>());
}
void bindTensorOps3(py::module *m) {
@ -364,6 +511,20 @@ void bindTensorOps3(py::module *m) {
py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation,
py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter);
(void)py::class_<RandomCropAndResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomCropAndResizeWithBBoxOp>>(
*m, "RandomCropAndResizeWithBBoxOp",
"Tensor operation to randomly crop an image (with BBoxes) and resize to a given size."
"Takes output height and width and"
"optional parameters for lower and upper bound for aspect ratio (h/w) and scale,"
"interpolation mode, and max attempts to crop")
.def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeWithBBoxOp::kDefScaleLb,
py::arg("scaleUb") = RandomCropAndResizeWithBBoxOp::kDefScaleUb,
py::arg("aspectLb") = RandomCropAndResizeWithBBoxOp::kDefAspectLb,
py::arg("aspectUb") = RandomCropAndResizeWithBBoxOp::kDefAspectUb,
py::arg("interpolation") = RandomCropAndResizeWithBBoxOp::kDefInterpolation,
py::arg("maxIter") = RandomCropAndResizeWithBBoxOp::kDefMaxIter);
(void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>(
*m, "RandomColorAdjustOp",
"Tensor operation to adjust an image's color randomly."
@ -418,9 +579,13 @@ void bindTensorOps4(py::module *m) {
.def(py::init<int32_t, int32_t, int32_t, int32_t, BorderType, uint8_t, uint8_t, uint8_t>(), py::arg("padTop"),
py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType,
py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
(void)py::class_<ToNumberOp, TensorOp, std::shared_ptr<ToNumberOp>>(*m, "ToNumberOp",
"TensorOp to convert strings to numbers.")
.def(py::init<DataType>(), py::arg("data_type"))
.def(py::init<std::string>(), py::arg("data_type"));
}
void bindTensorOps5(py::module *m) {
void bindTokenizerOps(py::module *m) {
(void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(*m, "JiebaTokenizerOp", "")
.def(py::init<const std::string, std::string, JiebaMode>(), py::arg("hmm_path"), py::arg("mp_path"),
py::arg("mode") = JiebaMode::kMix)
@ -433,6 +598,60 @@ void bindTensorOps5(py::module *m) {
"Tensor operation to LookUp each word")
.def(py::init<std::shared_ptr<Vocab>, WordIdType>(), py::arg("vocab"), py::arg("unknown"))
.def(py::init<std::shared_ptr<Vocab>>(), py::arg("vocab"));
(void)py::class_<NgramOp, TensorOp, std::shared_ptr<NgramOp>>(*m, "NgramOp", "TensorOp performs ngram mapping")
.def(py::init<const std::vector<int32_t> &, int32_t, int32_t, const std::string &, const std::string &,
const std::string &>(),
py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"), py::arg("r_pad_token"),
py::arg("separator"));
(void)py::class_<WordpieceTokenizerOp, TensorOp, std::shared_ptr<WordpieceTokenizerOp>>(
*m, "WordpieceTokenizerOp", "Tokenize scalar token or 1-D tokens to subword tokens.")
.def(py::init<const std::shared_ptr<Vocab> &, const std::string &, const int &, const std::string &>(),
py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator),
py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken));
}
void bindDependIcuTokenizerOps(py::module *m) {
#ifdef ENABLE_ICU4C
(void)py::class_<WhitespaceTokenizerOp, TensorOp, std::shared_ptr<WhitespaceTokenizerOp>>(
*m, "WhitespaceTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces.")
.def(py::init<>());
(void)py::class_<UnicodeScriptTokenizerOp, TensorOp, std::shared_ptr<UnicodeScriptTokenizerOp>>(
*m, "UnicodeScriptTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.")
.def(py::init<>())
.def(py::init<bool>(), py::arg("keep_whitespace") = UnicodeScriptTokenizerOp::kDefKeepWhitespace);
(void)py::class_<CaseFoldOp, TensorOp, std::shared_ptr<CaseFoldOp>>(
*m, "CaseFoldOp", "Apply case fold operation on utf-8 string tensor")
.def(py::init<>());
(void)py::class_<NormalizeUTF8Op, TensorOp, std::shared_ptr<NormalizeUTF8Op>>(
*m, "NormalizeUTF8Op", "Apply normalize operation on utf-8 string tensor.")
.def(py::init<>())
.def(py::init<NormalizeForm>(), py::arg("normalize_form") = NormalizeUTF8Op::kDefNormalizeForm);
(void)py::class_<RegexReplaceOp, TensorOp, std::shared_ptr<RegexReplaceOp>>(
*m, "RegexReplaceOp", "Replace utf-8 string tensor with 'replace' according to regular expression 'pattern'.")
.def(py::init<const std::string &, const std::string &, bool>(), py::arg("pattern"), py::arg("replace"),
py::arg("replace_all"));
(void)py::class_<RegexTokenizerOp, TensorOp, std::shared_ptr<RegexTokenizerOp>>(
*m, "RegexTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by regex expression pattern.")
.def(py::init<const std::string &, const std::string &>(), py::arg("delim_pattern"), py::arg("keep_delim_pattern"));
(void)py::class_<BasicTokenizerOp, TensorOp, std::shared_ptr<BasicTokenizerOp>>(
*m, "BasicTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by specific rules.")
.def(py::init<bool, bool, NormalizeForm, bool>(), py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase,
py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace,
py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm,
py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken);
(void)py::class_<BertTokenizerOp, TensorOp, std::shared_ptr<BertTokenizerOp>>(*m, "BertTokenizerOp",
"Tokenizer used for Bert text process.")
.def(py::init<const std::shared_ptr<Vocab> &, const std::string &, const int &, const std::string &, bool, bool,
NormalizeForm, bool>(),
py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator),
py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase,
py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace,
py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm,
py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken);
#endif
}
void bindSamplerOps(py::module *m) {
@ -449,32 +668,29 @@ void bindSamplerOps(py::module *m) {
.def("add_child",
[](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator")
.def("add_child", [](std::shared_ptr<mindrecord::ShardOperator> self,
std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
.def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
py::arg("seed"));
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
(void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
.def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
.def(py::init<int64_t, int64_t, bool>());
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
.def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
py::arg("num_samples"))
.def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
.def(py::init<int64_t, bool, bool>());
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
.def(py::init<>());
(void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
.def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
.def(py::init<int64_t, int64_t>());
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));
.def(py::init<int64_t, std::vector<int64_t>>());
(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
*m, "MindrecordSubsetRandomSampler")
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
*m, "MindrecordPkSampler")
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
@ -486,12 +702,27 @@ void bindSamplerOps(py::module *m) {
}
}));
(void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
.def(py::init<int64_t, int64_t, bool, uint32_t>());
(void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
*m, "MindrecordRandomSampler")
.def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
}));
(void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
std::shared_ptr<mindrecord::ShardSequentialSample>>(*m, "MindrecordSequentialSampler")
.def(py::init([](int num_samples, int start_index) {
return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
}));
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
py::arg("replacement"));
.def(py::init<int64_t, std::vector<double>, bool>());
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
.def(py::init<py::object>(), py::arg("pySampler"));
.def(py::init<int64_t, py::object>());
}
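
The ShardOperator binding now exposes add_child, which simply forwards to SetChildOp on the C++ side. A rough sketch of the same chaining in C++, with illustrative argument values; header paths are assumed to follow the include list above (shard_shuffle.h in particular is an assumption):

#include <memory>
#include "mindrecord/include/shard_sequential_sample.h"
#include "mindrecord/include/shard_shuffle.h"

void ChainShardOperators() {
  using namespace mindspore::mindrecord;
  // Shuffle with a fixed seed; the remaining arguments mirror the MindrecordRandomSampler binding.
  auto shuffle = std::make_shared<ShardShuffle>(/*seed=*/0, /*num_samples=*/0,
                                                /*replacement=*/false, /*reshuffle_each_epoch=*/true);
  // Sequentially take 100 samples starting at index 0, as in MindrecordSequentialSampler.
  auto sequential = std::make_shared<ShardSequentialSample>(/*num_samples=*/100, /*start_index=*/0);
  // Attach one operator as a child of the other, like sampler.add_child() in Python.
  shuffle->SetChildOp(sequential);
}
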
void bindInfoObjects(py::module *m) {
@ -503,16 +734,18 @@ void bindInfoObjects(py::module *m) {
void bindVocabObjects(py::module *m) {
(void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
.def(py::init<>())
.def_static("from_list",
[](const py::list &words) {
[](const py::list &words, const py::list &special_tokens, bool special_first) {
std::shared_ptr<Vocab> v;
THROW_IF_ERROR(Vocab::BuildFromPyList(words, &v));
THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v));
return v;
})
.def_static("from_file",
[](const std::string &path, const std::string &dlm, int32_t vocab_size) {
[](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens,
bool special_first) {
std::shared_ptr<Vocab> v;
THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, &v));
THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v));
return v;
})
.def_static("from_dict", [](const py::dict &words) {
@ -529,10 +762,22 @@ void bindGraphData(py::module *m) {
THROW_IF_ERROR(g_out->Init());
return g_out;
}))
.def("get_nodes",
[](gnn::Graph &g, gnn::NodeType node_type, gnn::NodeIdType node_num) {
.def("get_all_nodes",
[](gnn::Graph &g, gnn::NodeType node_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNodes(node_type, node_num, &out));
THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
return out;
})
.def("get_all_edges",
[](gnn::Graph &g, gnn::EdgeType edge_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
return out;
})
.def("get_nodes_from_edges",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
return out;
})
.def("get_all_neighbors",
@ -541,12 +786,38 @@ void bindGraphData(py::module *m) {
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
return out;
})
.def("get_sampled_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
std::vector<gnn::NodeType> neighbor_types) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
return out;
})
.def("get_neg_sampled_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
gnn::NodeType neg_neighbor_type) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
return out;
})
.def("get_node_feature",
[](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
TensorRow out;
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
return out.getRow();
})
.def("graph_info",
[](gnn::Graph &g) {
py::dict out;
THROW_IF_ERROR(g.GraphInfo(&out));
return out;
});
})
.def("random_walk", [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
return out;
});
}
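
The graph bindings above grow several query methods (all nodes/edges, sampled and negatively sampled neighbors, node features, random walks). A hedged C++ sketch exercising a few of them on an already-initialized gnn::Graph; node/edge type ids, sampling counts, and walk parameters are purely illustrative, and include paths follow the layout used in this file:

#include <memory>
#include <vector>
#include "dataset/engine/gnn/graph.h"
#include "dataset/util/status.h"

mindspore::dataset::Status QueryGraph(mindspore::dataset::gnn::Graph *g,
                                      const std::vector<mindspore::dataset::gnn::NodeIdType> &seeds) {
  using mindspore::dataset::Tensor;
  std::shared_ptr<Tensor> nodes;
  RETURN_IF_NOT_OK(g->GetAllNodes(/*node_type=*/1, &nodes));
  std::shared_ptr<Tensor> neighbors;
  // Sample 10 first-hop neighbors of type 1, then 5 second-hop neighbors of type 2.
  RETURN_IF_NOT_OK(g->GetSampledNeighbors(seeds, {10, 5}, {1, 2}, &neighbors));
  std::shared_ptr<Tensor> walks;
  // Random walk along the meta path 1 -> 2 -> 1, with the two step parameters shown in the binding.
  RETURN_IF_NOT_OK(g->RandomWalk(seeds, {1, 2, 1}, /*step_home_param=*/2.0, /*step_away_param=*/0.5,
                                 /*default_node=*/-1, &walks));
  return mindspore::dataset::Status::OK();
}
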
// This is where we externalize the C logic as python modules
@ -555,9 +826,9 @@ PYBIND11_MODULE(_c_dataengine, m) {
(void)py::class_<DatasetOp, std::shared_ptr<DatasetOp>>(m, "DatasetOp");
(void)py::enum_<OpName>(m, "OpName", py::arithmetic())
.value("STORAGE", OpName::kStorage)
.value("SHUFFLE", OpName::kShuffle)
.value("BATCH", OpName::kBatch)
.value("BUCKETBATCH", OpName::kBucketBatch)
.value("BARRIER", OpName::kBarrier)
.value("MINDRECORD", OpName::kMindrecord)
.value("CACHE", OpName::kCache)
@ -578,11 +849,14 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("MNIST", OpName::kMnist)
.value("MANIFEST", OpName::kManifest)
.value("VOC", OpName::kVoc)
.value("COCO", OpName::kCoco)
.value("CIFAR10", OpName::kCifar10)
.value("CIFAR100", OpName::kCifar100)
.value("RANDOMDATA", OpName::kRandomData)
.value("BUILDVOCAB", OpName::kBuildVocab)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile);
.value("TEXTFILE", OpName::kTextFile)
.value("CLUE", OpName::kClue);
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
.value("DE_JIEBA_MIX", JiebaMode::kMix)
@ -590,6 +864,16 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("DE_JIEBA_HMM", JiebaMode::kHmm)
.export_values();
#ifdef ENABLE_ICU4C
(void)py::enum_<NormalizeForm>(m, "NormalizeForm", py::arithmetic())
.value("DE_NORMALIZE_NONE", NormalizeForm::kNone)
.value("DE_NORMALIZE_NFC", NormalizeForm::kNfc)
.value("DE_NORMALIZE_NFKC", NormalizeForm::kNfkc)
.value("DE_NORMALIZE_NFD", NormalizeForm::kNfd)
.value("DE_NORMALIZE_NFKD", NormalizeForm::kNfkd)
.export_values();
#endif
(void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())
.value("DE_INTER_LINEAR", InterpolationMode::kLinear)
.value("DE_INTER_CUBIC", InterpolationMode::kCubic)
@ -609,12 +893,13 @@ PYBIND11_MODULE(_c_dataengine, m) {
bindTensorOps2(&m);
bindTensorOps3(&m);
bindTensorOps4(&m);
bindTensorOps5(&m);
bindTokenizerOps(&m);
bindSamplerOps(&m);
bindDatasetOps(&m);
bindInfoObjects(&m);
bindVocabObjects(&m);
bindGraphData(&m);
bindDependIcuTokenizerOps(&m);
}
} // namespace dataset
} // namespace mindspore

View File

@ -11,6 +11,7 @@ add_library(core OBJECT
data_type.cc
global_context.cc
tensor.cc
tensor_row.cc
tensor_shape.cc
)
add_dependencies(core mindspore::protobuf)

View File

@ -27,6 +27,7 @@
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/barrier_op.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/build_vocab_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
#include "dataset/engine/datasetops/map_op.h"
@ -38,7 +39,6 @@
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/generator_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/datasetops/zip_op.h"

View File

@ -48,7 +48,7 @@ Status ConfigManager::FromJson(const nlohmann::json &j) {
Status ConfigManager::LoadFile(const std::string &settingsFile) {
Status rc;
if (!Path(settingsFile).Exists()) {
RETURN_STATUS_UNEXPECTED("File is not found");
RETURN_STATUS_UNEXPECTED("File is not found.");
}
// Some settings are mandatory, others are not (with default). If a setting
// is optional it will set a default value if the config is missing from the file.
@ -59,14 +59,11 @@ Status ConfigManager::LoadFile(const std::string &settingsFile) {
rc = FromJson(js);
} catch (const nlohmann::json::type_error &e) {
std::ostringstream ss;
ss << "Client settings failed to load:\n" << e.what();
ss << "Client file failed to load:\n" << e.what();
std::string err_msg = ss.str();
RETURN_STATUS_UNEXPECTED(err_msg);
} catch (const std::exception &err) {
std::ostringstream ss;
ss << "Client settings failed to load:\n" << err.what();
std::string err_msg = ss.str();
RETURN_STATUS_UNEXPECTED(err_msg);
RETURN_STATUS_UNEXPECTED("Client file failed to load.");
}
return rc;
}
@ -88,5 +85,7 @@ void ConfigManager::set_op_connector_size(int32_t connector_size) { op_connector
uint32_t ConfigManager::seed() const { return seed_; }
void ConfigManager::set_seed(uint32_t seed) { seed_ = seed; }
void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_sampling_interval_ = interval; }
} // namespace dataset
} // namespace mindspore

View File

@ -111,12 +111,21 @@ class ConfigManager {
// @param seed - The default seed to use
void set_seed(uint32_t seed);
// setter function
// @param interval - The monitor sampling interval to apply to the config
void set_monitor_sampling_interval(uint32_t interval);
// getter function
// @return The interval of monitor sampling
int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; }
private:
int32_t rows_per_buffer_{kCfgRowsPerBuffer};
int32_t num_parallel_workers_{kCfgParallelWorkers};
int32_t worker_connector_size_{kCfgWorkerConnectorSize};
int32_t op_connector_size_{kCfgOpConnectorSize};
uint32_t seed_{kCfgDefaultSeed};
uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval};
// Private helper function that takes a nlohmann json format and populates the settings
// @param j - The json nlohmann json info
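
Alongside the existing knobs, the config manager gains a monitor sampling interval. A minimal sketch of setting and reading it in C++ (header path assumed from the source layout; the interval's unit is not stated in this diff):

#include "dataset/core/config_manager.h"

void TuneConfig(mindspore::dataset::ConfigManager *cfg) {
  cfg->set_seed(42);
  cfg->set_op_connector_size(16);
  cfg->set_monitor_sampling_interval(10);                // new setter added in this change
  int32_t interval = cfg->monitor_sampling_interval();   // new getter, returns the stored value
  (void)interval;
}
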

View File

@ -47,9 +47,13 @@ constexpr uint32_t kCfgParallelWorkers = 4;
constexpr uint32_t kCfgWorkerConnectorSize = 16;
constexpr uint32_t kCfgOpConnectorSize = 16;
constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed;
constexpr uint32_t kCfgMonitorSamplingInterval = 10;
// An invalid OpenCV type marker must not fall in the valid range 0 to 7 (opencv4/opencv2/core/hal/interface.h)
constexpr uint8_t kCVInvalidType = 255;
using connection_id_type = int64_t;
using row_id_type = int64_t;
} // namespace dataset
} // namespace mindspore

View File

@ -138,7 +138,7 @@ DataType DataType::FromNpArray(const py::array &arr) {
return DataType(DataType::DE_FLOAT32);
} else if (py::isinstance<py::array_t<std::double_t>>(arr)) {
return DataType(DataType::DE_FLOAT64);
} else if (arr.dtype().kind() == 'S') {
} else if (arr.dtype().kind() == 'S' || arr.dtype().kind() == 'U') {
return DataType(DataType::DE_STRING);
} else {
MS_LOG(ERROR) << "Cannot convert from numpy type. Unknown data type is returned!";

View File

@ -128,7 +128,9 @@ class DataType {
// @tparam T
// @return true or false
template <typename T>
bool IsCompatible() const;
bool IsCompatible() const {
return type_ == FromCType<T>();
}
// returns true if the template type is the same as the Tensor type_
// @tparam T
@ -146,6 +148,9 @@ class DataType {
return out;
}
template <typename T>
static DataType FromCType();
// Convert from DataType to Pybind type
// @return
py::dtype AsNumpyType() const;
@ -191,68 +196,68 @@ class DataType {
};
template <>
inline bool DataType::IsCompatible<bool>() const {
return type_ == DataType::DE_BOOL;
inline DataType DataType::FromCType<bool>() {
return DataType(DataType::DE_BOOL);
}
template <>
inline bool DataType::IsCompatible<double>() const {
return type_ == DataType::DE_FLOAT64;
inline DataType DataType::FromCType<double>() {
return DataType(DataType::DE_FLOAT64);
}
template <>
inline bool DataType::IsCompatible<float>() const {
return type_ == DataType::DE_FLOAT32;
inline DataType DataType::FromCType<float>() {
return DataType(DataType::DE_FLOAT32);
}
template <>
inline bool DataType::IsCompatible<float16>() const {
return type_ == DataType::DE_FLOAT16;
inline DataType DataType::FromCType<float16>() {
return DataType(DataType::DE_FLOAT16);
}
template <>
inline bool DataType::IsCompatible<int64_t>() const {
return type_ == DataType::DE_INT64;
inline DataType DataType::FromCType<int64_t>() {
return DataType(DataType::DE_INT64);
}
template <>
inline bool DataType::IsCompatible<uint64_t>() const {
return type_ == DataType::DE_UINT64;
inline DataType DataType::FromCType<uint64_t>() {
return DataType(DataType::DE_UINT64);
}
template <>
inline bool DataType::IsCompatible<int32_t>() const {
return type_ == DataType::DE_INT32;
inline DataType DataType::FromCType<int32_t>() {
return DataType(DataType::DE_INT32);
}
template <>
inline bool DataType::IsCompatible<uint32_t>() const {
return type_ == DataType::DE_UINT32;
inline DataType DataType::FromCType<uint32_t>() {
return DataType(DataType::DE_UINT32);
}
template <>
inline bool DataType::IsCompatible<int16_t>() const {
return type_ == DataType::DE_INT16;
inline DataType DataType::FromCType<int16_t>() {
return DataType(DataType::DE_INT16);
}
template <>
inline bool DataType::IsCompatible<uint16_t>() const {
return type_ == DataType::DE_UINT16;
inline DataType DataType::FromCType<uint16_t>() {
return DataType(DataType::DE_UINT16);
}
template <>
inline bool DataType::IsCompatible<int8_t>() const {
return type_ == DataType::DE_INT8;
inline DataType DataType::FromCType<int8_t>() {
return DataType(DataType::DE_INT8);
}
template <>
inline bool DataType::IsCompatible<uint8_t>() const {
return type_ == DataType::DE_UINT8;
inline DataType DataType::FromCType<uint8_t>() {
return DataType(DataType::DE_UINT8);
}
template <>
inline bool DataType::IsCompatible<std::string_view>() const {
return type_ == DataType::DE_STRING;
inline DataType DataType::FromCType<std::string_view>() {
return DataType(DataType::DE_STRING);
}
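
With this refactor, every IsCompatible<T>() specialization collapses into a single inline definition that compares against the new FromCType<T>() mapping. A short sketch of how the two now relate (header path assumed):

#include "dataset/core/data_type.h"

void CheckCompatibility() {
  using mindspore::dataset::DataType;
  DataType t(DataType::DE_FLOAT32);
  bool ok = t.IsCompatible<float>();                   // true: FromCType<float>() is DE_FLOAT32
  bool not_ok = t.IsCompatible<int32_t>();             // false: FromCType<int32_t>() is DE_INT32
  DataType mapped = DataType::FromCType<uint8_t>();    // DE_UINT8, usable on its own as well
  (void)ok; (void)not_ok; (void)mapped;
}
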
template <>

View File

@ -18,6 +18,7 @@
#include <algorithm>
#include <iomanip>
#include <iostream>
#include <fstream>
#include <memory>
#include <vector>
#include <utility>
@ -229,7 +230,12 @@ Status Tensor::CreateTensorFromNumpyString(std::shared_ptr<Tensor> *ptr, py::arr
}
arr.resize({arr.size()}); // flatten the py::array so we can iterate once
std::vector<std::string> strings;
std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast<py::bytes>(s)); });
if (arr.dtype().kind() == 'U') {
std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast<py::str>(s)); });
} else {
std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast<py::bytes>(s)); });
}
arr.resize(shape); // resize arr back to the original shape
@ -306,6 +312,50 @@ Status Tensor::CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::Byte
return Status::OK();
}
Status Tensor::CreateTensor(std::shared_ptr<Tensor> *ptr, const std::string &file_path) {
std::ifstream fs;
fs.open(file_path, std::ios::binary | std::ios::in);
CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + file_path);
int64_t num_bytes = fs.seekg(0, std::ios::end).tellg();
CHECK_FAIL_RETURN_UNEXPECTED(fs.seekg(0, std::ios::beg).good(), "Fail to find size of file");
RETURN_IF_NOT_OK(
Tensor::CreateTensor(ptr, TensorImpl::kFlexible, TensorShape{num_bytes}, DataType(DataType::DE_UINT8)));
int64_t written_bytes = fs.read(reinterpret_cast<char *>((*ptr)->GetMutableBuffer()), num_bytes).gcount();
CHECK_FAIL_RETURN_UNEXPECTED(written_bytes == num_bytes && fs.good(), "Error in writing to tensor");
fs.close();
return Status::OK();
}
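
The new overload above reads an entire file into a flat DE_UINT8 tensor sized to the file's byte count. A minimal usage sketch:

#include <memory>
#include <string>
#include "dataset/core/tensor.h"

mindspore::dataset::Status LoadFileAsTensor(const std::string &path,
                                            std::shared_ptr<mindspore::dataset::Tensor> *out) {
  // Delegates to the overload above: open, size via seekg/tellg, then read into the tensor buffer.
  return mindspore::dataset::Tensor::CreateTensor(out, path);
}
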
Status Tensor::CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::BytesList &bytes_list,
const TensorShape &shape, const DataType &type, dsize_t pad_size) {
RETURN_IF_NOT_OK(Tensor::CreateTensor(ptr, TensorImpl::kFlexible, shape, type));
unsigned char *current_tensor_addr = (*ptr)->GetMutableBuffer();
int64_t tensor_bytes_remaining = bytes_list.value_size() * pad_size;
for (int i = 0; i < bytes_list.value_size(); i++) {
// read string data into tensor
const std::string &current_element = bytes_list.value(i);
int return_code =
memcpy_s(current_tensor_addr, tensor_bytes_remaining, common::SafeCStr(current_element), current_element.size());
CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed when reading bytesList element into Tensor");
current_tensor_addr += current_element.size();
tensor_bytes_remaining -= current_element.size();
// pad
int64_t chars_to_pad = pad_size - current_element.size();
return_code = memset_s(current_tensor_addr, tensor_bytes_remaining, static_cast<int>(' '), chars_to_pad);
CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed when padding Tensor");
current_tensor_addr += chars_to_pad;
tensor_bytes_remaining -= chars_to_pad;
}
return Status::OK();
}
// Memcpy the given strided array's used part to consecutive memory
// Consider a 3-d array
// A[(i * shape[1] + j) * shape[2] + k] = B[i][j][k] = C[i * strides[0] + j * strides[1] + k * strides[2]]
@ -539,11 +589,13 @@ Status Tensor::StartAddrOfIndex(std::vector<dsize_t> ind, uchar **start_addr_of_
if (type() == DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED("StartAddrOfIndex does not support string tensors yet.");
}
dsize_t flat_ind;
std::vector<dsize_t> t_shape = shape().AsVector();
std::vector<dsize_t> r(t_shape.begin() + ind.size(), t_shape.end());
*remaining = TensorShape(r);
ind.resize(this->Rank(), 0); // same as -> while (ind.size() < this->Rank()) ind.push_back(0);
RETURN_IF_NOT_OK(shape_.ToFlatIndex(ind, &flat_ind));
// check if GetBuffer() returns null; we should flag this as an error. This sanity check will only
// be true if the tensor failed to allocate memory.
@ -584,6 +636,39 @@ Status Tensor::InsertTensor(const std::vector<dsize_t> &ind, const std::shared_p
}
}
Status Tensor::Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &tensor) {
std::string err_msg;
err_msg += (index.size() != 1) ? "[Tensor] only supports 1d concatenation \n" : "";
err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot concatenate tensors of type string\n" : "";
err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : "";
err_msg +=
(index.at(0) + tensor->shape().NumOfElements() > this->shape().NumOfElements()) ? "[Tensor] incorrect index\n" : "";
err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : "";
uchar *start_addr_of_ind = nullptr;
TensorShape remaining_shape = tensor->shape();
StartAddrOfIndex(index, &start_addr_of_ind, &remaining_shape);
err_msg += (start_addr_of_ind == nullptr) ? "Failed to create memory for Tensor.\n" : "";
if (!err_msg.empty()) {
MS_LOG(DEBUG) << "Insert tensor message: " << err_msg;
RETURN_STATUS_UNEXPECTED(err_msg);
} else {
int ret_code =
memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes());
if (ret_code == 0) {
return Status::OK();
} else {
err_msg += "[Tensor] error in memcpy_s when inserting tensor\n";
MS_LOG(DEBUG) << "Tensor message: " << err_msg;
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
}
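
Concatenate copies a source tensor into an existing destination at a 1-D offset, so the destination must already be sized to hold it. A rough sketch using the numeric CreateTensor factory added to tensor.h later in this change; values are illustrative:

#include <memory>
#include <vector>
#include "dataset/core/tensor.h"

mindspore::dataset::Status FillRow() {
  using mindspore::dataset::Tensor;
  std::shared_ptr<Tensor> dst, src;
  // Destination: six float32 elements; source: three elements written starting at index 3.
  RETURN_IF_NOT_OK(Tensor::CreateTensor<float>(&dst, std::vector<float>{0.f, 0.f, 0.f, 0.f, 0.f, 0.f}));
  RETURN_IF_NOT_OK(Tensor::CreateTensor<float>(&src, std::vector<float>{7.f, 8.f, 9.f}));
  RETURN_IF_NOT_OK(dst->Concatenate({3}, src));   // dst becomes [0, 0, 0, 7, 8, 9]
  return mindspore::dataset::Status::OK();
}
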
Status Tensor::ExpandDim(const dsize_t &axis) {
if (axis > Rank()) {
std::string err = "Axis is out of bound";
@ -649,7 +734,7 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const {
Status Tensor::GetItemAt(std::string_view *o, const std::vector<dsize_t> &index) const {
RETURN_UNEXPECTED_IF_NULL(data_);
RETURN_UNEXPECTED_IF_NULL(o);
CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Type is not DE_STRING");
CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Tensor type is not a string");
uchar *start = nullptr;
offset_t length = 0;
@ -699,6 +784,8 @@ Status Tensor::GetDataAsNumpyStrings(py::array *data) {
for (; itr != end<std::string_view>(); itr++) {
max = std::max((*itr).length(), max);
}
// if all strings are empty, numpy stores a byte for each string |S1
max = (max == 0 ? 1 : max);
uint64_t total_size = shape_.NumOfElements() * max;
char *tmp_data = reinterpret_cast<char *>(data_allocator_->allocate(total_size));
if (tmp_data == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create temp array.");
@ -708,8 +795,10 @@ Status Tensor::GetDataAsNumpyStrings(py::array *data) {
itr = begin<std::string_view>();
uint64_t i = 0;
for (; itr != end<std::string_view>(); itr++, i++) {
ret_code = memcpy_s(tmp_data + i * max, total_size, (*itr).data(), (*itr).length());
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy string data.");
if (!(*itr).empty()) {
ret_code = memcpy_s(tmp_data + i * max, total_size, (*itr).data(), (*itr).length());
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy string data.");
}
}
auto strides = shape_.Strides();
std::transform(strides.begin(), strides.end(), strides.begin(), [&max](const auto &s) { return s * max; });
@ -847,6 +936,78 @@ Status Tensor::GetStringAt(dsize_t index, uchar **string_start, offset_t *length
*length = offset_ptr[index + 1] - start - 1; // -1 to skip the \0 from the string length
return Status::OK();
}
Status Tensor::CopyLastDimAt(const std::shared_ptr<Tensor> &src, const std::vector<dsize_t> &index) {
CHECK_FAIL_RETURN_UNEXPECTED(src->type() == type_, "Source Tensor has a different type");
CHECK_FAIL_RETURN_UNEXPECTED(index.back() == 0, "Last dim in index should be 0");
uint8_t type_size = type_.SizeInBytes();
size_t len = std::min(src->shape()[-1], shape_[-1]) * type_size;
dsize_t src_flat_ind = 0, dst_flat_ind = 0;
RETURN_IF_NOT_OK(src->shape().ToFlatIndex(index, &src_flat_ind));
RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &dst_flat_ind));
const unsigned char *src_addr = src->GetBuffer() + src_flat_ind * type_size;
unsigned char *dst_addr = GetMutableBuffer() + dst_flat_ind * type_size;
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error");
return Status::OK();
}
Status Tensor::Slice(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) {
CHECK_FAIL_RETURN_UNEXPECTED(shape_.Rank() == 1, "Currently Slice works with rank-1 tensors only.");
CHECK_FAIL_RETURN_UNEXPECTED(!indices.empty(), "Indices are empty, generated tensor would be empty.");
if (type_.IsNumeric()) {
return SliceNumeric(out, indices);
} else {
return SliceString(out, indices);
}
}
Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) {
RETURN_IF_NOT_OK(
CreateTensor(out, TensorImpl::kFlexible, TensorShape({static_cast<dsize_t>(indices.size())}), type_));
(*out)->GetMutableBuffer();
dsize_t out_index = 0;
dsize_t dim_length = shape_[0];
dsize_t type_size = type_.SizeInBytes();
dsize_t src_start = HandleNeg(indices[0], dim_length);
uchar *dst_addr = (*out)->data_;
dsize_t count = 1;
for (dsize_t i = 0; i < indices.size(); i++) {
dsize_t cur_index = HandleNeg(indices[i], dim_length);
CHECK_FAIL_RETURN_UNEXPECTED(
cur_index >= 0 && cur_index < dim_length,
"Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")");
if (i < indices.size() - 1) {
dsize_t next_index = HandleNeg(indices[i + 1], dim_length);
if (next_index == cur_index + 1) {
count++;
continue;
}
}
int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size,
count * type_size);
CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric");
out_index += count;
if (i < indices.size() - 1) {
src_start = HandleNeg(indices[i + 1], dim_length); // next index
}
count = 1;
}
return Status::OK();
}
Status Tensor::SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) {
dsize_t dim_length = shape_[0];
std::vector<std::string> strings;
for (dsize_t index : indices) {
dsize_t cur_index = HandleNeg(index, dim_length);
CHECK_FAIL_RETURN_UNEXPECTED(
cur_index >= 0 && cur_index < dim_length,
"Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")");
std::string_view sv;
GetItemAt(&sv, {cur_index});
strings.emplace_back(sv);
}
return CreateTensor(out, strings);
}
} // namespace dataset
} // namespace mindspore
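
Slice gathers arbitrary (possibly negative) indices from a rank-1 tensor into a new tensor; SliceNumeric coalesces runs of consecutive indices into single memcpy_s calls, while SliceString rebuilds a string tensor element by element. A small usage sketch built on the numeric factory from tensor.h:

#include <memory>
#include <vector>
#include "dataset/core/tensor.h"

mindspore::dataset::Status TakeSome() {
  using mindspore::dataset::Tensor;
  std::shared_ptr<Tensor> t, picked;
  RETURN_IF_NOT_OK(Tensor::CreateTensor<int32_t>(&t, std::vector<int32_t>{10, 11, 12, 13, 14}));
  // Indices 0..2 are consecutive, so SliceNumeric copies them in one shot; -1 wraps to index 4.
  RETURN_IF_NOT_OK(t->Slice(&picked, {0, 1, 2, -1}));   // picked holds [10, 11, 12, 14]
  return mindspore::dataset::Status::OK();
}
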

View File

@ -44,9 +44,6 @@ class Tensor;
using CharAllocPtr = std::unique_ptr<Allocator<unsigned char>>;
using TensorAllocPtr = std::shared_ptr<Allocator<Tensor>>; // An allocator shared_ptr for Tensors
using TensorRow = std::vector<std::shared_ptr<Tensor>>; // A row is a set of Tensor pointers
using TensorTable = std::vector<TensorRow>; // The table of tensors is a vector of rows
using TensorQTable = std::deque<TensorRow>; // A different flavour of tensor table, this one has queue functionality
class Tensor {
public:
@ -118,6 +115,16 @@ class Tensor {
static Status CreateTensor(std::shared_ptr<Tensor> *, TensorImpl tensor_impl, const TensorShape &shape, DataType type,
const unsigned char *data = nullptr);
/// Create a copy of the input tensor
/// \param out [out] output tensor to be generated
/// \param in [in] original tensor to be copied
/// \return Status
static Status CreateTensor(std::shared_ptr<Tensor> *out, const std::shared_ptr<Tensor> &in) {
const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator();
*out = std::allocate_shared<Tensor>(*alloc, in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes());
return Status::OK();
}
// A static factory method to create a Tensor from a given py::array.
// @param ptr output argument to hold the created Tensor
// @param arr py::array
@ -135,9 +142,41 @@ class Tensor {
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::vector<std::string> &strings,
const TensorShape &shape = TensorShape::CreateUnknownRankShape());
// create tensor from protobuf bytelist with strings
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::BytesList &bytes_list,
const TensorShape &shape);
// A static factory method to create a Tensor from a given list of numbers.
// @param ptr output argument to hold the created Tensor
// @param items elements of the tensor
// @param shape shape of the tensor
// @return Status Code
template <typename T>
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::vector<T> &items,
const TensorShape &shape_req = TensorShape::CreateUnknownRankShape()) {
DataType type = DataType::FromCType<T>();
auto items_ptr = reinterpret_cast<const uchar *>(&items[0]);
TensorShape shape = shape_req;
if (!shape.known()) {
shape = TensorShape({static_cast<dsize_t>(items.size())});
}
return CreateTensor(ptr, TensorImpl::kFlexible, shape, type, items_ptr);
}
// A static factory method to create a Tensor from a given number.
// @param ptr output argument to hold the created Tensor
// @param item value
// @return Status Code
template <typename T>
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const T &item) {
return CreateTensor<T>(ptr, {item}, TensorShape::CreateScalar());
}
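
The two factories above make it easy to build numeric tensors directly from C++ values; the vector form defaults the shape to {items.size()} and the scalar form produces a rank-0 tensor. A brief sketch:

#include <memory>
#include <vector>
#include "dataset/core/tensor.h"

void MakeNumericTensors() {
  using mindspore::dataset::Tensor;
  std::shared_ptr<Tensor> vec_t, scalar_t;
  (void)Tensor::CreateTensor<float>(&vec_t, std::vector<float>{1.f, 2.f, 3.f, 4.f});  // shape {4}
  (void)Tensor::CreateTensor<int32_t>(&scalar_t, 7);                                  // scalar tensor
}
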
// Create tensor from protobuf bytelist with uint8 or int8 types
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::BytesList &bytes_list,
const TensorShape &shape, const DataType &type, dsize_t pad_size);
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::string &path);
// Copy raw data of a array based on shape and strides to the destination pointer
// @param dst Pointer to the destination array where the content is to be copied
// @param src Pointer to the source of strided array to be copied
@ -260,11 +299,6 @@ class Tensor {
// @return const unsigned char*
const unsigned char *GetBuffer() const;
// Get the starting memory address for the data of the tensor. This potentially
// drives an allocation if the data area is null.
// @return unsigned char*
unsigned char *GetMutableBuffer();
// Getter of the type
// @return
DataType type() const { return type_; }
@ -323,6 +357,22 @@ class Tensor {
return ss.str();
}
// Handle negative indices.
static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; }
// Slice tensor based on the given indices. Copy the sliced data into the out tensor. Only rank-1 tensors are supported.
// Based on the type of tensor, SliceNumeric or SliceString will be called
// @param out Tensor
// @param indices vector of indices
// @return Status error code
Status Slice(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
// Slice numeric tensors.
Status SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
// Slice string tensors
Status SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
// Constructs numpy array from input tensor
// @param data this data is the location of python data
// @return Status code
@ -332,6 +382,9 @@ class Tensor {
static Status GetBufferInfo(Tensor &t, py::buffer_info *out);
// Concatenate the given tensor into this one; unlike InsertTensor, it can fill the current tensor with a smaller one
Status Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input);
// TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor
// The order of elements follows the memory layout (i.e., row-major): [[1,2,3],[4,5,6]] --> 1,2,3,4,5,6
// @tparam T type of values in the Tensor Iterator
@ -518,6 +571,7 @@ class Tensor {
// @return TensorIterator
template <typename T>
TensorIterator<T> begin() {
AllocateBuffer(SizeInBytes());
return TensorIterator<T>(data_);
}
@ -529,7 +583,18 @@ class Tensor {
return TensorIterator<T>(data_end_);
}
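
begin<T>() now forces the backing buffer to be allocated before handing out the iterator. A brief sketch of row-major iteration with it (the element type is assumed to match the tensor's DataType):

#include <memory>
#include "dataset/core/tensor.h"

float SumAll(const std::shared_ptr<mindspore::dataset::Tensor> &t) {
  float total = 0.0f;
  // Walks the elements in memory (row-major) order, e.g. [[1,2,3],[4,5,6]] -> 1,2,3,4,5,6.
  for (auto it = t->begin<float>(); it != t->end<float>(); it++) {
    total += *it;
  }
  return total;
}
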
// Copies the last dimension at `index` from Tensor `src` to this Tensor.
// @param src Tensor
// @param index vector to the start of the dimension. The last dim should be 0
// @return Status
Status CopyLastDimAt(const std::shared_ptr<Tensor> &src, const std::vector<dsize_t> &index);
protected:
// Get the starting memory address for the data of the tensor. This potentially
// drives an allocation if the data is null.
// @return unsigned char*
unsigned char *GetMutableBuffer();
// A function that prints Tensor recursively, first called by print
// @param out
// @param cur_dim

View File

@ -0,0 +1,75 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>
#include "dataset/core/tensor_row.h"
namespace py = pybind11;
namespace mindspore {
namespace dataset {
TensorRow::TensorRow() noexcept : id_(kDefaultRowId) {}
TensorRow::TensorRow(size_type n, TensorRow::value_type t) noexcept : id_(kDefaultRowId), row_(n, t) {}
TensorRow::TensorRow(const TensorRow::vector_type &v) : id_(kDefaultRowId), row_(v) {}
TensorRow::TensorRow(row_id_type id, const std::initializer_list<value_type> &lst) : id_(id), row_(lst) {}
TensorRow::TensorRow(const TensorRow &tr) : id_(tr.id_), row_(tr.row_) {}
TensorRow &TensorRow::operator=(const TensorRow &tr) {
if (this == &tr) {
return *this;
}
row_ = tr.row_;
id_ = tr.id_;
return *this;
}
TensorRow &TensorRow::operator=(const std::initializer_list<TensorRow::value_type> &lst) {
row_ = lst;
return *this;
}
TensorRow::TensorRow(TensorRow::vector_type &&v) noexcept : id_(kDefaultRowId), row_(std::move(v)) {}
TensorRow::TensorRow(row_id_type id, std::initializer_list<value_type> &&lst) noexcept
: id_(id), row_(std::move(lst)) {}
TensorRow::TensorRow(TensorRow &&tr) noexcept {
id_ = tr.id_;
row_ = std::move(tr.row_);
}
TensorRow &TensorRow::operator=(TensorRow &&tr) noexcept {
if (this == &tr) {
return *this;
}
row_ = std::move(tr.row_);
id_ = tr.id_;
tr.id_ = kDefaultRowId;
return *this;
}
TensorRow &TensorRow::operator=(std::initializer_list<TensorRow::value_type> &&lst) noexcept {
row_ = std::move(lst);
return *this;
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,131 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_CORE_TENSOR_ROW_H_
#define DATASET_CORE_TENSOR_ROW_H_
#include <deque>
#include <memory>
#include <vector>
#include "dataset/core/tensor.h"
namespace mindspore {
namespace dataset {
class TensorRow; // A set of Tensor pointers with an id
using TensorTable = std::vector<TensorRow>; // The table of tensors is a vector of rows
using TensorQTable = std::deque<TensorRow>; // A different flavour of tensor table, this one has queue functionality
class TensorRow {
public:
static constexpr row_id_type kDefaultRowId = -1; // Default row id
// Type definitions
using size_type = dsize_t;
using value_type = std::shared_ptr<Tensor>;
using reference = std::shared_ptr<Tensor> &;
using const_reference = const std::shared_ptr<Tensor> &;
using vector_type = std::vector<std::shared_ptr<Tensor>>;
using iterator = std::vector<std::shared_ptr<Tensor>>::iterator;
using const_iterator = std::vector<std::shared_ptr<Tensor>>::const_iterator;
TensorRow() noexcept;
TensorRow(size_type n, value_type t) noexcept;
// Copy Constructors
explicit TensorRow(const vector_type &v);
TensorRow(row_id_type id, const std::initializer_list<value_type> &lst);
TensorRow(const TensorRow &tr);
TensorRow &operator=(const TensorRow &tr);
TensorRow &operator=(const std::initializer_list<value_type> &lst);
// Move Constructors
explicit TensorRow(vector_type &&v) noexcept;
TensorRow(row_id_type id, std::initializer_list<value_type> &&lst) noexcept;
TensorRow(TensorRow &&tr) noexcept;
TensorRow &operator=(TensorRow &&tr) noexcept;
TensorRow &operator=(std::initializer_list<value_type> &&lst) noexcept;
// Destructor
~TensorRow() = default;
// Functions to fetch/set id/vector
row_id_type getId() const { return id_; }
void setId(row_id_type id) { id_ = id; }
const vector_type &getRow() const { return row_; }
// Wrapper functions to support vector operations
void emplace_back(value_type t) { row_.emplace_back(t); }
void push_back(value_type t) { row_.push_back(t); }
void clear() noexcept { row_.clear(); }
size_type size() const noexcept { return row_.size(); }
void reserve(size_type size) { row_.reserve(size); }
void resize(size_type size) { row_.resize(size); }
bool empty() { return row_.empty(); }
void insert(iterator position, iterator first, iterator last) { row_.insert(position, first, last); }
// Wrapper functions to support vector element access
reference at(size_type index) { return row_.at(index); }
const_reference at(size_type index) const { return row_.at(index); }
reference front() { return row_.front(); }
const_reference front() const { return row_.front(); }
reference back() { return row_.back(); }
const_reference back() const { return row_.back(); }
reference operator[](size_type index) { return row_[index]; }
const_reference operator[](size_type index) const { return row_[index]; }
// Wrapper functions to support vector iteration
iterator begin() { return row_.begin(); }
const_iterator begin() const { return row_.begin(); }
iterator end() { return row_.end(); }
const_iterator end() const { return row_.end(); }
protected:
row_id_type id_;
std::vector<std::shared_ptr<Tensor>> row_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_CORE_TENSOR_ROW_H_
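A brief usage sketch of the new TensorRow wrapper (editorial, not part of the diff): BuildRow, the image/label tensors, and the row id value are placeholders, but only the members declared above are used.

#include <memory>
#include "dataset/core/tensor_row.h"

namespace ds = mindspore::dataset;

// Assembles a two-column row (e.g., image and label) and walks it.
void BuildRow(const std::shared_ptr<ds::Tensor> &image, const std::shared_ptr<ds::Tensor> &label) {
  ds::TensorRow row;                // id_ starts at kDefaultRowId (-1)
  row.setId(0);                     // tag the row with an explicit id
  row.push_back(image);             // vector-style wrappers forward to the internal row_
  row.push_back(label);
  for (const auto &tensor : row) {  // begin()/end() iterate the shared_ptr<Tensor> entries
    (void)tensor;
  }
}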

View File

@ -94,7 +94,7 @@ class TensorShape {
// @return
TensorShape PrependDim(dsize_t dim) const;
// Insert a new dim at the end of the shape. For example, <2,4> --> PrependDim(4) --> <2,4,4>
// Insert a new dim at the end of the shape. For example, <2,4> --> AppendDim(4) --> <2,4,4>
// @param dim
// @return
TensorShape AppendDim(dsize_t dim) const;
@ -118,7 +118,10 @@ class TensorShape {
bool operator!=(const TensorShape &rhs) const { return !(rhs == *this); }
dsize_t operator[](const dsize_t index) const { return raw_shape_[index]; }
dsize_t operator[](const dsize_t index) const {
if (index < 0) return raw_shape_[raw_shape_.size() + index];
return raw_shape_[index];
}
// Return the Shape as a vector
// @return
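The new operator[] overload accepts Python-style negative indices. A quick illustration (editorial, not part of the diff; the initializer-list constructor and the tensor_shape.h include path are assumed from the surrounding codebase):

#include "dataset/core/tensor_shape.h"

// shape[-1] now resolves to the last dimension instead of indexing out of range.
void NegativeIndexExample() {
  mindspore::dataset::TensorShape shape({2, 4, 3});
  mindspore::dataset::dsize_t first = shape[0];   // 2
  mindspore::dataset::dsize_t last = shape[-1];   // 3, i.e. raw_shape_[raw_shape_.size() - 1]
  (void)first;
  (void)last;
}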

View File

@ -1,6 +1,7 @@
add_subdirectory(datasetops)
add_subdirectory(opt)
add_subdirectory(gnn)
add_subdirectory(perf)
if (ENABLE_TDTQUE)
add_subdirectory(tdt)
endif ()
@ -16,7 +17,7 @@ add_library(engine OBJECT
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
if (ENABLE_TDTQUE)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf)
else()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf)
endif ()

View File

@ -152,6 +152,23 @@ class Connector {
return out;
}
// Get the current size of the connector (sum of the sizes of all internal queues).
int32_t size() const {
int32_t size = 0;
for (int32_t i = 0; i < queues_.size(); ++i) {
size += queues_[i]->size();
}
return size;
}
// Get the total capacity of the connector (sum of the capacities of all internal queues).
int32_t capacity() const {
int32_t capacity = 0;
for (int32_t i = 0; i < queues_.size(); ++i) {
capacity += queues_[i]->capacity();
}
return capacity;
}
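One way the aggregated counters above can be consumed, for example by monitoring or the new perf code (editorial sketch, not part of the diff; the connector.h include path, the Connector<T> template usage, and the ConnectorUtilization name are assumptions):

#include <cstdint>
#include "dataset/engine/connector.h"

// Overall occupancy of a connector: both size() and capacity() are summed
// across all internal queues, so the ratio reflects the connector as a whole.
template <typename T>
double ConnectorUtilization(const mindspore::dataset::Connector<T> &conn) {
  int32_t cap = conn.capacity();
  return cap > 0 ? static_cast<double>(conn.size()) / cap : 0.0;
}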
// Register the internal resources with Task group for interruption service.
// @param vg
// @return

View File

@ -17,8 +17,6 @@
#include "dataset/util/allocator.h"
#include "dataset/core/global_context.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/datasetops/source/storage_client.h"
#include "dataset/engine/datasetops/source/tf_buffer.h"
namespace mindspore {
namespace dataset {
@ -26,37 +24,6 @@ namespace dataset {
// Description: This is the main constructor that is used for making a buffer
DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {}
// Name: CreateDataBuffer()
// Description: A static factory method to create the appropriate type of derived class
// buffer. Returns the base class reference for DataBuffer.
Status DataBuffer::CreateDataBuffer(
int32_t id, // In: The id for the new buffer
std::shared_ptr<StorageClient> storage_client, // In: The storage client that is related to this buffer type
std::unique_ptr<DataBuffer> *ptr) {
std::unique_ptr<DataBuffer> new_data_buffer;
try {
DatasetType ds_type = storage_client->schema()->dataset_type();
switch (ds_type) {
case DatasetType::kTf: {
// This type of buffer is for TF record data.
// Allocate derived class version for a TF buffers
new_data_buffer = std::make_unique<TFBuffer>(id, kDeBFlagNone, storage_client);
break;
}
default: {
std::string errMsg("Invalid buffer type");
RETURN_STATUS_UNEXPECTED(errMsg);
}
}
} catch (std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, e.what());
} catch (std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
*ptr = std::move(new_data_buffer);
return Status::OK();
}
// Name: print()
// Description: A function that prints info about the DataBuffer (base class version)
void DataBuffer::Print(std::ostream &out, // In: The output stream to print to
@ -98,7 +65,7 @@ Status DataBuffer::GetTensor(std::shared_ptr<Tensor> *ptr, int32_t row_id, int32
// Remove me!! Callers should fetch rows via pop
Status DataBuffer::GetRow(int32_t row_id, TensorRow *ptr) const {
if (row_id < tensor_table_->size()) {
if (tensor_table_ && !tensor_table_->empty() && row_id < tensor_table_->size()) {
*ptr = tensor_table_->at(row_id);
} else {
std::string err_msg = "rowId for mTensorTable out of range: " + std::to_string(row_id);
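Caller-side view of the guarded lookup (editorial sketch, not part of the diff): with the added check, GetRow fails cleanly on an empty or missing tensor table instead of dereferencing a null tensor_table_. The FetchFirstRow helper, the data_buffer.h include path, and the Status::IsOk() check are assumptions drawn from the surrounding dataset code.

#include <memory>
#include "dataset/engine/data_buffer.h"

namespace ds = mindspore::dataset;

// Fetches row 0 from a buffer, propagating the out-of-range/empty-table error.
ds::Status FetchFirstRow(const std::unique_ptr<ds::DataBuffer> &buffer, ds::TensorRow *row) {
  ds::Status rc = buffer->GetRow(0, row);
  if (!rc.IsOk()) {
    // e.g. "rowId for mTensorTable out of range: 0" when the table is empty
    return rc;
  }
  return ds::Status::OK();
}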

Some files were not shown because too many files have changed in this diff.