forked from mindspore-Ecosystem/mindspore
!242 sync from mindspore to incubator 0623
Merge pull request !242 from yanghaoran/code_sync_0623
commit db5df9db60
@@ -10,6 +10,9 @@
[submodule "third_party/protobuf"]
    path = third_party/protobuf
    url = https://github.com/protocolbuffers/protobuf.git
[submodule "akg"]
    path = akg
    url = https://gitee.com/mindspore/akg.git
[submodule "graphengine"]
    path = graphengine
    url = https://gitee.com/ms-incubator/graphengine.git
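The submodule entries above map onto a plain git workflow. A minimal sketch of fetching them after a fresh clone (the clone URL is an assumption; the build script later in this change runs the equivalent update commands):

```bash
# Sketch: initialize the submodules declared in .gitmodules above.
set -e
git clone https://gitee.com/ms-incubator/mindspore.git   # assumed incubator mirror URL
cd mindspore
git submodule update --init graphengine                  # gitee.com/ms-incubator/graphengine
git submodule update --init --recursive akg              # gitee.com/mindspore/akg, Ascend + AKG builds
git submodule update --init third_party/protobuf         # github.com/protocolbuffers/protobuf
```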
@@ -7,7 +7,7 @@ endif ()

include(${CMAKE_SOURCE_DIR}/cmake/options.cmake)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/modules/")
if (ENABLE_GE)
    if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
        add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
    endif ()
@@ -86,10 +86,18 @@ if (ENABLE_GE OR ENABLE_D OR ENABLE_TESTCASES)
    include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
endif()

if (ENABLE_AKG AND ENABLE_D)
    add_subdirectory("${CMAKE_SOURCE_DIR}/akg")
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
add_subdirectory(mindspore/ccsrc)
if (ENABLE_TESTCASES)
    add_subdirectory(tests)
endif()

include(cmake/package.cmake)
if (ENABLE_SERVING)
    add_subdirectory(serving)
endif()

include(cmake/package.cmake)
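For orientation, a hedged sketch of driving these top-level switches directly with CMake; normally build.sh assembles this command, and the build directory layout here is an assumption:

```bash
# Sketch: configure with the options referenced above.
# ENABLE_AKG only takes effect together with ENABLE_D (Ascend), per the guard above.
mkdir -p build/mindspore && cd build/mindspore
cmake -DENABLE_D=ON -DENABLE_AKG=ON -DENABLE_SERVING=ON ../..
make -j"$(nproc)"
```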
@@ -7,6 +7,7 @@
* DeepFM: a factorization-machine based neural network for CTR prediction on the Criteo dataset.
* DeepLabV3: significantly improves over our previous DeepLab versions without DenseCRF post-processing and attains comparable performance with other state-of-the-art models on the PASCAL VOC 2007 semantic image segmentation benchmark.
* Faster-RCNN: towards real-time object detection with region proposal networks on the COCO 2017 dataset.
* SSD: a single-stage object detection method on the COCO 2017 dataset.
* GoogLeNet: a deep convolutional neural network architecture codenamed Inception V1, for classification and detection on the CIFAR-10 dataset.
* Wide&Deep: jointly trained wide linear models and deep neural networks for recommender systems on the Criteo dataset.
* Frontend and User Interface
@@ -62,7 +63,7 @@
## Contributors
Thanks goes to these wonderful people:
Alexey Shevlyakov, Amir Lashkari, anthony, baihuawei, biffex, buxue, caifubi, candanzg, caojian05, Cathy Wong, changzherui, chenfei, chengxianbin, chenhaozhe, chenzomi, chujinjin, cristoval, dengwentao, eric, etone-chan, fary86, gaojing, gengdongjie, gongchen, guohongzilong, guozhijian, heleiwang, hesham, He Wei, Hoai Linh Tran h00472437, hongxing, huangdongrun, huanghui, Jamie Nisbet, Jesse Lee, jiangjinsheng, jiangzhiwen, jinyaohui, jjfeing, jonwe, jonyguo, Junhan Hu, Kang, kingfo, kswang, laiyongqiang, leopz, lichenever, lihongkang, limingqi107, liubuyu, liuliyan2, liuwenhao4, liuxiao, liuxiao, liyong, lizhenyu, lvliang, Margaret_wangrui, meixiaowei, ms_yan, Nat Sutyanyong, ougongchang, panfengfeng, panyifeng, Peilin Wang, peixu_ren, qianlong, rick_sanchez, seatea, sheng, shijianning, simson, sunsuodong, Tinazhang, VectorSL, wandongdong, wangcong, wanghua, wangnan39, Wei Luning, wenchunjiang, wilfChen, WilliamLian, wsc, wukesong, wuxuejian, Xiaoda Zhang, xiefangqi, xulei2020, Yang, yangjie159, yangruoqi713, yangyongjie, yangzhenzhang, Yanjun Peng, yanzhenxiang2020, yao_yf, Yi Huaijie, yoonlee666, yujianfeng, YuJianfeng, yvetteliu, z00478463, zhangdengcheng, Zhang Qinghua, zhangz0911gm, zhaojichen, zhaoting, zhaozhenlong, zhoufeng, zhouneng, zhousiyi, zhouyuanshen, Zirui Wu, Ziyan, zjun, ZPaC, lihongzhang
Alexey Shevlyakov, Amir Lashkari, anthony, baihuawei, biffex, buxue, caifubi, candanzg, caojian05, Cathy Wong, changzherui, chenfei, chengxianbin, chenhaozhe, chenzomi, chujinjin, cristoval, dengwentao, eric, etone-chan, fary86, gaojing, gengdongjie, gongchen, guohongzilong, guozhijian, heleiwang, hesham, He Wei, Hoai Linh Tran, hongxing, huangdongrun, huanghui, Jamie Nisbet, Jesse Lee, jiangjinsheng, jiangzhiwen, jinyaohui, jjfeing, jonwe, jonyguo, Junhan Hu, Kang, kingfo, kswang, laiyongqiang, leopz, lichenever, lihongkang, limingqi107, liubuyu, liuliyan2, liuwenhao4, liuxiao, liuxiao, liyong, lizhenyu, lvliang, Margaret_wangrui, meixiaowei, ms_yan, Nat Sutyanyong, ougongchang, panfengfeng, panyifeng, Peilin Wang, peixu_ren, qianlong, rick_sanchez, seatea, sheng, shijianning, simson, sunsuodong, Tinazhang, VectorSL, wandongdong, wangcong, wanghua, wangnan39, Wei Luning, wenchunjiang, wilfChen, WilliamLian, wsc, wukesong, wuxuejian, Xiaoda Zhang, xiefangqi, xulei2020, Yang, yangjie159, yangruoqi713, yangyongjie, yangzhenzhang, Yanjun Peng, yanzhenxiang2020, yao_yf, Yi Huaijie, yoonlee666, yujianfeng, YuJianfeng, yvetteliu, zhangdengcheng, Zhang Qinghua, zhangz0911gm, zhaojichen, zhaoting, zhaozhenlong, zhoufeng, zhouneng, zhousiyi, zhouyuanshen, Zirui Wu, Ziyan, zjun, ZPaC, lihongzhang
Contributions of any kind are welcome!
@@ -2245,14 +2245,14 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Please also refer to the file CONTRIBUTING.md, which clarifies licensing of
external contributions to this project including patches, pull requests, etc.

Software: SQLite 3.31.1
Software: SQLite 3.32.2
Copyright notice:
Copyright 2008 D. Richard Hipp and Hipp, Wyrick & Company, Inc.
Copyright (C) 1996, 1997, 1998, 1999, 2000, 2001, 2003, 2004, 2005, 2006, 2007 2008 Free Software Foundation, Inc.
(c) The page number is greater than the largest page that existed in
Copyright (c) 1991-2011 Unicode, Inc.
Copyright 2008 D. Richard Hipp and Hipp, Wyrick & Company, Inc.
Copyright (c) 2002 by David Gravereaux.
Copyright (c) 2006 by Pat Thoyts
(c) The page number is greater than the largest page that existed in
Copyright (C) 1996, 1997, 1998, 1999, 2000, 2001, 2003, 2004, 2005, 2006, 2007 2008 Free Software Foundation, Inc.

License: Public Domain
Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means.
@@ -3053,6 +3053,646 @@ Copyright 2003 Google Inc.
Copyright 2009 Google Inc.
Copyright (C) 1991-2, RSA Data Security, Inc. Created 1991. All

Software: tinyxml2 8.0.0
Copyright 2011, John Resig.
Copyright 2011, The Dojo Foundation.

Software: icu 67.1
Copyright (C) 2000-2004, International Business Machines Corporation
|
||||
Copyright (C) 2002-2014, International Business Machines(C) Copyright IBM Corp. 1998-2011 - All Rights Reserved
|
||||
Copyright (C) 2003-2008, International Business Machines
|
||||
Copyright (C) 2005-2006, International Business Machines
|
||||
Copyright (C) 2016 and later: Unicode, Inc. and others.
|
||||
Copyright (c) 2001-2010 International Business Machines
|
||||
Copyright (C) 2009, International Business Machines
|
||||
Copyright (c) 2010-2015 International Business Machines Corporation and others. All rights reserved.
|
||||
Copyright (C) 2002-2015, International Business Machines verbatim (minus copyright and #include) and copied together into this file.
|
||||
Copyright (c) 1997-2014, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (c) 1997-2008, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2003, International Business Machines Corporation and
|
||||
Copyright (c) 1996-2012, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2016, International Business Machines
|
||||
Copyright (c) 1997-2013 International Business Machines
|
||||
Copyright (c) 1997-2016, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2001, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2012, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2005, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2010, International Business Machines Corporation and
|
||||
Copyright (c) 2011-2016, International Business Machines Corporation
|
||||
Copyright (c) 1997-2009, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2002,2008, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2009,2014, International Business Machines
|
||||
Copyright (C) 2000-2009, International Business Machines
|
||||
Copyright (c) 1997-2015, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2013, International Business Machines Corporation and
|
||||
Copyright (c) 2001-2016, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2016, International Business Machines Corporation
|
||||
Copyright (c) 1997-2003, 2007-2009 International Business Machines Corporation and
|
||||
Copyright (c) 2011-2014, International Business Machines Corporation
|
||||
Copyright (c) 2003-2009, International Business Machines
|
||||
Copyright (c) 2016, International Business Machines Corporation
|
||||
Copyright (c) 1997-2004, International Business Machines Corporation and
|
||||
Copyright (C) 2002-2016, International Business Machines
|
||||
Copyright (C) 1998-2014, International Business Machines Corporation
|
||||
Copyright (c) 2003-2013, International Business Machines Corporation and
|
||||
Copyright (c) 2005-2016, International Business Machines Corporation and
|
||||
Copyright (c) 1999-2013, International Business Machines Corporation and
|
||||
Copyright (c) 2003-2015, International Business Machines Corporation and
|
||||
Copyright (C) 2003-2016, International Business Machines
|
||||
Copyright (C) 2003-2014, International Business Machines
|
||||
Copyright (C) 2003, International Business Machines
|
||||
Copyright (c) 1998-2016, International Business Machines Corporation and
|
||||
Copyright (c) 2004-2015, International Business Machines Corporation and
|
||||
Copyright (c) 2009-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2003-2012, International Business Machines
|
||||
Copyright (c) 2000-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2001-2014, International Business Machines
|
||||
Copyright (C) 2001-2016, International Business Machines
|
||||
Copyright (c) 1997-2014, International Business Machines © 2017 and later: Unicode, Inc. and others.
|
||||
Copyright (C) 2007-2016, International Business Machines © 2018 and later: Unicode, Inc. and others.
|
||||
Copyright (c) 2015, International Business Machines Corporation
|
||||
Copyright (c) 2014-2016, International Business Machines Corporation
|
||||
Copyright (c) 2002-2016, International Business Machines
|
||||
Copyright (c) 2001-2011,2015 International Business Machines
|
||||
Copyright (c) 2001-2016 International Business Machines
|
||||
Copyright (c) 2005-2013, International Business Machines Corporation and
|
||||
Copyright (c) 1998-2014, International Business Machines Corporation and
|
||||
Copyright (C) 1997-2016 International Business Machines
|
||||
Copyright (C) 2009-2014, International Business Machines Corporation and
|
||||
Copyright (c) 2002-2014, International Business Machines Corporation
|
||||
Copyright (c) 2002-2007, International Business Machines Corporation
|
||||
Copyright (C) 1996-2012, International Business Machines Corporation
|
||||
Copyright (C) 1996-2008, International Business Machines Corporation
|
||||
Copyright (C) 2007-2013, International Business Machines Corporation and
|
||||
Copyright (C) 2008-2015, International Business Machines
|
||||
Copyright (C) 2003-2013, International Business Machines Corporation and
|
||||
Copyright (C) 2003-2013, International Business Machines Corporation
|
||||
Copyright (C) 1997-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2001-2011, International Business Machines
|
||||
Copyright (C) 2001-2008, International Business Machines
|
||||
Copyright (C) 2003 - 2009, International Business Machines Corporation and
|
||||
Copyright (C) 2003 - 2008, International Business Machines Corporation and
|
||||
Copyright (C) 2007-2014, International Business Machines Corporation
|
||||
Copyright (C) 2007-2013, International Business Machines Corporation
|
||||
Copyright (C) 1997-2013, International Business Machines Corporation and
|
||||
Copyright (C) 1996-2014, International Business Machines Corporation and
|
||||
Copyright (C) 2010-2014, International Business Machines
|
||||
Copyright (C) 2010-2015, International Business Machines
|
||||
Copyright (C) 2013-2014, International Business Machines
|
||||
Copyright (C) 1996-2015, International Business Machines
|
||||
Copyright (C) 1996-2014, International Business Machines
|
||||
Copyright (C) 2012-2015, International Business Machines
|
||||
Copyright (C) 2012-2014, International Business Machines
|
||||
Copyright (C) 2013-2015, International Business Machines
|
||||
Copyright (C) 2013-2016, International Business Machines
|
||||
Copyright (C) 1999-2016, International Business Machines
|
||||
Copyright (C) 1999-2015, International Business Machines
|
||||
Copyright (C) 1999-2014, International Business Machines
|
||||
Copyright (C) 2015-2016, International Business Machines Corporation and others.
|
||||
Copyright (C) 2003 - 2013, International Business Machines Corporation and
|
||||
Copyright (C) 1999-2011, International Business Machines
|
||||
Copyright (C) 2005-2016, International Business Machines
|
||||
Copyright (C) 2005-2012, International Business Machines
|
||||
Copyright (C) 2005-2015, International Business Machines
|
||||
Copyright (C) 2005-2013, International Business Machines
|
||||
Copyright (C) 2005-2014, International Business Machines
|
||||
Copyright (c) 2004, International Business Machines
|
||||
Copyright (c) 2004-2014 International Business Machines
|
||||
Copyright (c) 2004-2014, International Business Machines
|
||||
Copyright (C) 2013, International Business Machines Corporation
|
||||
Copyright (C) 1997-2015, International Business Machines Corporation and
|
||||
Copyright (C) 2016, International Business Machines
|
||||
Copyright (c) IBM Corporation, 2000-2012. All rights reserved.
|
||||
Copyright (c) IBM Corporation, 2000-2011. All rights reserved.
|
||||
Copyright (c) IBM Corporation, 2000-2014. All rights reserved.
|
||||
Copyright (c) IBM Corporation, 2000-2010. All rights reserved.
|
||||
Copyright (c) IBM Corporation, 2000-2016. All rights reserved.
|
||||
Copyright 2010 the V8 project authors. All rights reserved.
|
||||
Copyright 2006-2008 the V8 project authors. All rights reserved.
|
||||
Copyright 2012 the V8 project authors. All rights reserved.
|
||||
Copyright (C) 2008-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2007-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2007-2012, International Business Machines Corporation and
|
||||
Copyright (c) 2001-2011, International Business Machines
|
||||
Copyright (c) 2001-2007, International Business Machines
|
||||
Copyright (C) 2010-2014, International Business Machines Corporation and
|
||||
Copyright (C) 1997-2010, International Business Machines Corporation and
|
||||
Copyright (C) 1997-2012, International Business Machines Corporation and
|
||||
Copyright (C) 2009-2015, International Business Machines Corporation and
|
||||
Copyright (C) 2009-2012, International Business Machines Corporation and
|
||||
Copyright (c) 2002-2012, International Business Machines Corporation
|
||||
Copyright (c) 2002-2011, International Business Machines Corporation
|
||||
Copyright (C) 2008-2013, International Business Machines Corporation and
|
||||
Copyright (c) 2003-2008, International Business Machines
|
||||
Copyright (C) 2003-2016, International Business Machines Corporation
|
||||
Copyright (C) 2003-2014, International Business Machines Corporation
|
||||
Copyright (C) 2003-2008, International Business Machines Corporation
|
||||
Copyright (C) 2005-2008, International Business Machines
|
||||
Copyright (C) 2003-2015, International Business Machines Corporation
|
||||
Copyright (C) 2003-2009,2012,2016 International Business Machines Corporation and
|
||||
Copyright (c) 2004-2016, International Business Machines © 2020 and later: Unicode, Inc. and others.
|
||||
Copyright (C) 2007-2008, International Business Machines Corporation and
|
||||
Copyright (C) 2001-2007, International Business Machines
|
||||
Copyright (C) 1997-2012, International Business Machines
|
||||
Copyright (C) 1997-2015, International Business Machines
|
||||
Copyright (C) 2001-2010, International Business Machines
|
||||
Copyright (c) 2000-2005, International Business Machines
|
||||
Copyright (c) 2000-2007, International Business Machines © 2019 and later: Unicode, Inc. and others.
|
||||
Copyright (C) 2010-2015, International Business Machines Corporation and
|
||||
Copyright (C) 2015, International Business Machines Corporation and
|
||||
Copyright (c) 2003-2013, International Business Machines
|
||||
Copyright (C) 2001-2012, International Business Machines
|
||||
Copyright (C) 2001-2011, International Business Machines Corporation
|
||||
Copyright (C) 2014-2016, International Business Machines
|
||||
Copyright (C) 1997-2015, International Business Machines Corporation
|
||||
Copyright (C) 1999-2007, International Business Machines
|
||||
Copyright (C) 1999-2007, International Business Machines Corporation
|
||||
Copyright (C) 1999-2011, International Business Machines Corporation
|
||||
Copyright (C) {1999-2001}, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 2002-2016 International Business Machines Corporation and others.
|
||||
Copyright (C) 2002-2016, International Business Machines Corporation and others.
|
||||
Copyright (C) 2002-2016 International Business Machines Corporation
|
||||
Copyright (C) 2002-2015, International Business Machines Corporation and others.
|
||||
Copyright (C) 2012 International Business Machines Corporation
|
||||
Copyright (C) 2002-2015 International Business Machines Corporation
|
||||
Copyright (C) 2004-2015, International Business Machines Corporation and others.
|
||||
Copyright (C) 2003-2010, International Business Machines Corporation and others.
|
||||
Copyright (c) 2008-2011, International Business Machines Corporation and
|
||||
Copyright (c) 2008-2010, International Business Machines Corporation and
|
||||
Copyright (C) 2014-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2013, International Business Machines Corporation and
|
||||
Copyright (c) 2014, International Business Machines
|
||||
Copyright (C) 2014, International Business Machines
|
||||
Copyright (C) 2013, International Business Machines
|
||||
Copyright (C) 2001-2008,2010 IBM and others. All rights reserved.
|
||||
Copyright (C) 2010 , Yahoo! Inc.
|
||||
Copyright (c) 1997-2011, International Business Machines Corporation and
|
||||
Copyright (C) 2013-2014, International Business Machines Corporation and
|
||||
Copyright (C) 2009-2013, International Business Machines Corporation and
|
||||
Copyright (C) 1996-2012, International Business Machines Corporation and
|
||||
Copyright (C) 2015, International Business Machines Corporation
|
||||
Copyright (c) 2001-2012, International Business Machines Corporation
|
||||
Copyright (C) 2001-2014 IBM and others. All rights reserved.
|
||||
Copyright (C) 2008-2014, Google, International Business Machines Corporation and
|
||||
Copyright (C) 2008, Google, International Business Machines Corporation and
|
||||
Copyright (C) 2008-2015, Google, International Business Machines Corporation
|
||||
Copyright (c) 2001-2014, International Business Machines
|
||||
Copyright (c) 2002-2010, International Business Machines Corporation
|
||||
Copyright (C) 2011-2015, International Business Machines Corporation and
|
||||
Copyright (C) 2011-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2011-2012, International Business Machines Corporation and
|
||||
Copyright (C) 1996-2016, International Business Machines
|
||||
Copyright (C) 1998-2014, International Business Machines
|
||||
Copyright (C) 2004-2016, International Business Machines
|
||||
Copyright (C) 2010-2011, International Business Machines
|
||||
Copyright (C) 2009-2015, International Business Machines
|
||||
Copyright (C) 2015, International Business Machines
|
||||
Copyright (C) 2012-2016, International Business Machines
|
||||
Copyright (C) 1999-2012, International Business Machines
|
||||
Copyright (C) 2001, International Business Machines
|
||||
Copyright (C) 2013, International Business Machines Corporation and others.
|
||||
Copyright (C) 2010-2012, International Business Machines
|
||||
Copyright (C) 2004-2015, International Business Machines
|
||||
Copyright (C) 2003-2006, International Business Machines
|
||||
Copyright (C) 2013-2015, International Business Machines Corporation and others.
|
||||
Copyright (C) 2001-2015 IBM and others. All rights reserved.
|
||||
Copyright (C) 2008-2015, International Business Machines Corporation
|
||||
Copyright (C) 2008-2016, International Business Machines
|
||||
Copyright (C) 2008-2013, International Business Machines Corporation
|
||||
Copyright (C) 2004-2012, International Business Machines Corporation and
|
||||
Copyright (C) 1997-2009,2014 International Business Machines
|
||||
Copyright (C) 2009-2011, International Business Machines Corporation and
|
||||
Copyright (C) 2009-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2009-2013, International Business Machines
|
||||
Copyright (C) 2008-2011, International Business Machines
|
||||
Copyright (C) 2007-2014, International Business Machines Corporation and
|
||||
Copyright (C) 2009-2010, International Business Machines Corporation and
|
||||
Copyright (C) 2001-2016 International Business Machines Corporation
|
||||
Copyright (c) 2002-2011, International Business Machines
|
||||
Copyright (C) 2001-2012 IBM, Inc. All Rights Reserved.
|
||||
Copyright (c) 2013-2016 International Business Machines Corporation and others. All rights reserved.
|
||||
Copyright (c) 2013-2015 International Business Machines Corporation and others. All rights reserved.
|
||||
Copyright (c) 2007-2012, International Business Machines Corporation and
|
||||
Copyright (c) 2007-2012, International Business Machines
|
||||
Copyright (C) 2010, International Business Machines
|
||||
Copyright (C) 1997-2011, International Business Machines
|
||||
Copyright (C) 1997-2005, International Business Machines
|
||||
Copyright (C) 2009-2011, International Business Machines
|
||||
Copyright (C) 2003-2015, International Business Machines
|
||||
Copyright (C) 2009-2016, International Business Machines
|
||||
Copyright (C) 2008-2012, International Business Machines
|
||||
Copyright (C) 2008, International Business Machines
|
||||
Copyright (C) 2011-2014, International Business Machines
|
||||
Copyright (C) 2011-2013, International Business Machines
|
||||
Copyright (C) 2005, International Business Machines
|
||||
Copyright (C) 1999-2013, International Business Machines
|
||||
Copyright (C) 1998-2016, International Business Machines
|
||||
Copyright (c) 2007-2014, International Business Machines Corporation and
|
||||
Copyright (C) 2003-2013, International Business Machines
|
||||
Copyright (c) 2007-2016, International Business Machines Corporation and
|
||||
Copyright (c) 2008-2015, International Business Machines
|
||||
Copyright (C) 1999-2010, International Business Machines
|
||||
Copyright (C) 2000-2015, International Business Machines
|
||||
Copyright (C) 2000-2011, International Business Machines
|
||||
Copyright (C) 2000-2012, International Business Machines
|
||||
Copyright (C) 2000-2010, International Business Machines
|
||||
Copyright (C) 2004-2010, International Business Machines
|
||||
Copyright (C) 2004-2005, International Business Machines
|
||||
Copyright (c) 2013-2014, International Business Machines
|
||||
Copyright (c) 1991-2013 Unicode, Inc. © 2019 Unicode®, Inc.
|
||||
Copyright (C) 2018 and later: Unicode, Inc. and others.
|
||||
Copyright (c) 2008-2013 International Business Machines
|
||||
Copyright (C) 2002-2010, International Business Machines
|
||||
Copyright (c) 2012-2015 International Business Machines © 2020 Unicode®, Inc.
|
||||
Copyright (c) 2005-2013 IBM Corporation and others. All rights reserved
|
||||
Copyright (c) 2011-2012, International Business Machines Corporation and
|
||||
Copyright (C) 1998-2000, International Business Machines © 2017 Unicode®, Inc.
|
||||
Copyright (c) 2007-2015 International Business Machines
|
||||
Copyright (C) 2004-2006, International Business Machines
|
||||
Copyright (C) 2003-2005, International Business Machines
|
||||
Copyright (c) 1999-2014 International Business Machines
|
||||
Copyright (c) 2003, International Business Machines
|
||||
Copyright (C) 2014 International Business Machines
|
||||
Copyright (c) 2001-2003 International Business Machines
|
||||
Copyright (c) 2004-2011 International Business Machines
|
||||
Copyright (C) 2015-2016, International Business Machines
|
||||
Copyright (c) 2001-2015 International Business Machines
|
||||
Copyright (C) 2003-2012, International Business Machines Corporation and COPYRIGHT AND PERMISSION NOTICE
|
||||
Copyright (c) 2003 National Electronics and Computer Technology Center and others
|
||||
Copyright (C) 2005-2010, International Business Machines
|
||||
Copyright (c) 2007-2009 IBM Corporation and others. All rights reserved
|
||||
Copyright (C) 2004-2016 International Business Machines
|
||||
Copyright (C) 1998-2013, International Business Machines
|
||||
Copyright (C) 1998-2010, International Business Machines
|
||||
Copyright (c) 1999-2004, International Business Machines
|
||||
Copyright (C) 2002-2006 International Business Machines Corporation
|
||||
Copyright (C) 1999-2006, International Business Machines
|
||||
Copyright (C) 2002-2016 IBM, Inc. All Rights Reserved.
|
||||
Copyright (c) 2002-2006, International Business Machines(C) Copyright IBM Corp. 1998-2007 - All Rights Reserved
|
||||
Copyright (C) 1999-2003, International Business Machines
|
||||
Copyright (C) 1998-2006, International Business Machines Corporation and
|
||||
Copyright (C) 1998-2003, International Business Machines Corporation and
|
||||
Copyright (C) 2003 - 2008, International Business Machines
|
||||
Copyright (C) 1999-2008, International Business Machines
|
||||
Copyright (C) 1999-2001, International Business Machines
|
||||
Copyright (C) 1999-2005, International Business Machines
|
||||
Copyright (C) 2016 and later: Unicode, Inc. and others.
|
||||
Copyright (c) 2001-2010 IBM Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 1998-2005, International Business Machines Corporation and
|
||||
Copyright (C) 1998-2001, International Business Machines Corporation and
|
||||
Copyright (c) 2002-2005, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 2000-2014, International Business Machines
|
||||
Copyright (C) 1996-2013, International Business Machines
|
||||
Copyright (c) 2002-2006, International Business Machines Corporation and
|
||||
Copyright (c) 2004-2010, International Business Machines Corporation and
|
||||
Copyright (C) 2004-2011, International Business Machines
|
||||
Copyright (c) 2002-2005, International Business Machines Corporation and
|
||||
Copyright (c) 2002-2014, International Business Machines
|
||||
Copyright (c) 1997-2012, International Business Machines
|
||||
Copyright (c) 2002-2008, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 2011-2013, Apple Inc.; Unicode, Inc.; and others. All Rights Reserved.
|
||||
Copyright (C) 2011-2013, Apple Inc. and others. All Rights Reserved.
|
||||
Copyright (c) 2005-2007,2010 Apple Inc., Unicode Inc.,and others. All Rights Reserved.
|
||||
Copyright (c) 1999-2003, International Business Machines Corporation and
|
||||
Copyright (c) 2003-2014, International Business Machines
|
||||
Copyright (c) 2002-2010, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (c) 1999-2010, International Business Machines Corporation and
|
||||
Copyright (c) 1999-2002, International Business Machines Corporation and
|
||||
Copyright (C) 2002-2003, International Business Machines
|
||||
Copyright (C) 2002, International Business Machines
|
||||
Copyright (c) 2007, International Business Machines Corporation and
|
||||
Copyright (C) 2007, International Business Machines
|
||||
Copyright (C) 2001-2006, International Business Machines
|
||||
Copyright (C) 2010-2014, International Business Machines Corporation and others.
|
||||
Copyright (C) 2005-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2015-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2008-2012, International Business Machines Corporation
|
||||
Copyright (c) 2006-2015 International Business Machines Corporation and others. All rights reserved.
|
||||
Copyright (c) 2014-2015 International Business Machines Corporation and others. All rights reserved.
|
||||
Copyright (C) 2002-2011, International Business Machines
|
||||
Copyright (c) 2003-2010, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 2012 IBM Corporation and Others. All Rights Reserved.
|
||||
Copyright (C) 1998-2012, International Business Machines Corporation
|
||||
Copyright (c) 2009, International Business Machines Corporation and
|
||||
Copyright (C) The Internet Society (2002). All Rights Reserved.
|
||||
Copyright (c) 2015, International Business Machines Corporation and
|
||||
Copyright (c) 2002, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 1998-2016, International Business Machines Corporation
|
||||
Copyright (c) 2011-2016,International Business Machines
|
||||
Copyright (C) 2012 International Business Machines Corporation and Others. All Rights Reserved.
|
||||
Copyright (C) 2011, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 2011, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (c) 2011-2012,International Business Machines
|
||||
Copyright (c) 2007, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 2007-2007, International Business Machines(C) Copyright IBM Corp. 1998-2014 - All Rights Reserved
|
||||
Copyright (C) 1998-2002, International Business Machines
|
||||
Copyright (c) 2001-2007, International Business Machines Corporation and others. All Rights Reserved.(C) Copyright IBM Corp. 1998-2013 - All Rights Reserved
|
||||
Copyright (C) 1998-2015, International Business Machines
|
||||
Copyright (C) 2001-2014 International Business Machines
|
||||
Copyright (C) 2011-2016, International Business Machines
|
||||
Copyright (C) 2011-2015, International Business Machines
|
||||
Copyright (c) 1999-2014, International Business Machines Corporation and
|
||||
Copyright (c) 1999-2009, International Business Machines Corporation and
|
||||
Copyright (c) 2010,International Business Machines
|
||||
Copyright (c) 2010-2016,International Business Machines
|
||||
Copyright (c) 2002-2005, International Business Machines
|
||||
Copyright (C) 2000-2003, International Business Machines
|
||||
Copyright (c) 2008-2014, International Business Machines Corporation and
|
||||
Copyright (C) 2001 - 2005, International Business Machines
|
||||
Copyright (C) 2001-2005, International Business Machines
|
||||
Copyright (C) 1995-2014, International Business Machines
|
||||
Copyright (c) 2000-2004 IBM, Inc. and Others.
|
||||
Copyright (c) 2002-2014, International Business Machines Corporation and
|
||||
Copyright (c) 2007-2013, International Business Machines Corporation and
|
||||
Copyright (c) 2002-2012, International Business Machines Corporation and
|
||||
Copyright (C) 2002-2012, International Business Machines
|
||||
Copyright (C) 2009-2011, International Business Machines Corporation, Google and Others.
|
||||
Copyright (c) 2002, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 2009-2014, International Business Machines
|
||||
Copyright (C) 2008, International Business Machines Corporation and others.
|
||||
Copyright (C) 2000-2016, International Business Machines
|
||||
Copyright (C) 2011-2014 International Business Machines
|
||||
Copyright (C) 1997-2014, International Business Machines
|
||||
Copyright (C) 1997-2013, International Business Machines
|
||||
Copyright (c) 2004-2006, International Business Machines
|
||||
Copyright (C) 1997-2016, International Business Machines
|
||||
Copyright (C) 1997-2006, International Business Machines
|
||||
Copyright (C) 1997-2011, International Business Machines Corporation and others.
|
||||
Copyright (C) 1997-2013, International Business Machines Corporation and others.
|
||||
Copyright (c) 2004-2015, International Business Machines
|
||||
Copyright (C) 2009-2017, International Business Machines Corporation,Google, and others. All Rights Reserved.
|
||||
Copyright (C) 1997-2016, International Business Machines Corporation and others.
|
||||
Copyright (C) 2008-2015, International Business Machines Corporation and
|
||||
Copyright (C) 1997-2015, International Business Machines Corporation and others.
|
||||
Copyright (C) 2014-2016, International Business Machines Corporation and others.
|
||||
Copyright (c) 2014-2016, International Business Machines
|
||||
Copyright (C) 2001-2011 IBM and others. All rights reserved.
|
||||
Copyright (C) 1996-2014, International Business Machines Corporation and others.
|
||||
Copyright (C) 1996-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2009-2016, International Business Machines Corporation,
|
||||
Copyright (C) 2009-2010, Google, International Business Machines Corporation and
|
||||
Copyright (C) 2008-2014, Google, International Business Machines Corporation
|
||||
Copyright (C) 1996-2015, International Business Machines Corporation and
|
||||
Copyright (c) 1996-2015, International Business Machines Corporation and others.
|
||||
Copyright (C) 2010-2012,2015 International Business Machines
|
||||
Copyright (C) 2007-2015, International Business Machines
|
||||
Copyright (C) 2013-2014, International Business Machines Corporation and others.
|
||||
Copyright (C) 2010-2013, International Business Machines
|
||||
Copyright (c) 2002-2005, International Business Machines Corporation
|
||||
Copyright (C) 2001-2011,2014 IBM and others. All rights reserved.
|
||||
Copyright (C) 2008-2016, International Business Machines Corporation
|
||||
Copyright (C) 2004 - 2008, International Business Machines Corporation and
|
||||
Copyright (C) 1997-2011,2014-2015 International Business Machines
|
||||
Copyright (C) 2001-2003, International Business Machines
|
||||
Copyright (C) 1999-2009, International Business Machines
|
||||
Copyright (C) 2020 and later: Unicode, Inc. and others.
|
||||
Copyright (c) 2002, International Business Machines Corporation and
|
||||
Copyright (C) 2000-2008, International Business Machines
|
||||
Copyright (C) 1998-2006, International Business Machines
|
||||
Copyright (C) 1998-2001, International Business Machines Corporation
|
||||
Copyright (C) 1998-2004, International Business Machines Corporation
|
||||
Copyright (C) 2000, International Business Machines
|
||||
Copyright (c) 1999-2016, International Business Machines Corporation and
|
||||
Copyright (c) 2015, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (c) 1999-2012, International Business Machines Corporation and
|
||||
Copyright (C) 1998-2011, International Business Machines
|
||||
Copyright (C) 2008-2014, International Business Machines Corporation and
|
||||
Copyright (C) 2003-2004, International Business Machines
|
||||
Copyright (c) 2003-2005, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 2002-2006 IBM, Inc. All Rights Reserved.
|
||||
Copyright (C) 2004-2008, International Business Machines
|
||||
Copyright (c) 2002-2016 International Business Machines Corporation and
|
||||
Copyright (c) 2002-2015, International Business Machines Corporation and
|
||||
Copyright (C) 2002-2016, International Business Machines Corporation
|
||||
Copyright (c) 2002-2010,International Business Machines
|
||||
Copyright (c) 2002-2014,International Business Machines
|
||||
Copyright (c) 2002-2016,International Business Machines
|
||||
Copyright (C) 2016 International Business Machines Corporation
|
||||
Copyright © 2019 and later: Unicode, Inc. and others.
|
||||
Copyright (c) 2016, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (c) 2016 International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (c) 2015-2016, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (c) 2005-2006, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2004, International Business Machines Corporation
|
||||
Copyright (c) 2012-2016, International Business Machines Corporation
|
||||
Copyright (c) 2012-2014, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2014, International Business Machines Corporation
|
||||
Copyright (c) 1996-2016, International Business Machines Corporation and
|
||||
Copyright (c) 2003-2013, International Business Machines Corporation
|
||||
Copyright (c) 2003-2008, International Business Machines Corporation
|
||||
Copyright (c) 1997-2015, International Business Machines Corporation
|
||||
Copyright (c) 2002-2016, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2002, International Business Machines Corporation and
|
||||
Copyright (C) 1996-2012, International Business Machines
|
||||
Copyright (c) 1997-2013 International Business Machines Corporation and
|
||||
Copyright (c) 2010-2012, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2011, International Business Machines Corporation
|
||||
Copyright (c) 1997-2006, International Business Machines Corporation and
|
||||
Copyright (c) 2008-2016 International Business Machines Corporation and
|
||||
Copyright (c) 2008-2016, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2016 International Business Machines Corporation and
|
||||
Copyright (c) 2007-2011, International Business Machines
|
||||
Copyright (c) 2007-2010, International Business Machines
|
||||
Copyright (C) 2001-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2001-2003, International Business Machines Corporation and
|
||||
Copyright (C) 2003-2011, International Business Machines
|
||||
Copyright (c) 1997-2007, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2015, International Business Machines
|
||||
Copyright (C) 2004-2009, International Business Machines Corporation and
|
||||
Copyright (C) 2004, International Business Machines Corporation and
|
||||
Copyright (C) 1996-2009, International Business Machines Corporation and
|
||||
Copyright (C) 1996-2006, International Business Machines Corporation and
|
||||
Copyright (C) 2011-2013, International Business Machines Corporation
|
||||
Copyright (C) 2000-2007, International Business Machines
|
||||
Copyright (c) 2001, International Business Machines Corporation and
|
||||
Copyright (C) 2012-2013, International Business Machines
|
||||
Copyright (c) 2010-2016, International Business Machines Corporation and
|
||||
Copyright (c) 2010-2016, International Business Machines Corporation
|
||||
Copyright (c) 1997-2010, International Business Machines Corporation
|
||||
Copyright (c) 1997-2003, International Business Machines
|
||||
Copyright (C) 2014-2015, International Business Machines Corporation and
|
||||
Copyright (c) 1997-2013, International Business Machines Corporation
|
||||
Copyright (c) 1999-2016, International Business Machines
|
||||
Copyright (c) 1999-2016 International Business Machines Corporation and
|
||||
Copyright (c) 2016, International Business Machines Corporation and
|
||||
Copyright (c) 2016, International Business Machines
|
||||
Copyright (c) 2013-2016, International Business Machines Corporation
|
||||
Copyright (c) 2013, International Business Machines Corporation
|
||||
Copyright (C) 2013-2016, International Business Machines Corporation and
|
||||
Copyright (c) 2001-2010, International Business Machines Corporation and
|
||||
Copyright (C) 2014, International Business Machines Corporation and
|
||||
Copyright (c) 1999-2015, International Business Machines Corporation and
|
||||
Copyright (C) 2001-2016, International Business Machines orporation
|
||||
Copyright (c) 2001-2008, International Business Machines Corporation and others
|
||||
Copyright (C) 2003-2016, International Business Machines Corporation and
|
||||
Copyright (c) 2004, International Business Machines Corporation
|
||||
Copyright (C) 2001-2009, International Business Machines
|
||||
Copyright (c) 2004,2011 International Business Machines
|
||||
Copyright (c) 2004-2011, International Business Machines
|
||||
Copyright (c) 2000-2016, International Business Machines Corporation
|
||||
Copyright (c) 2001-2005, International Business Machines Corporation and
|
||||
Copyright (C) 2001-2004, International Business Machines
|
||||
Copyright (c) 2001-2009, International Business Machines
|
||||
Copyright (c) 1997-2009, International Business Machines Corporation
|
||||
Copyright (c) 1997-2013, International Business Machines
|
||||
Copyright (c) 1997-2012, International Business Machines Corporation
|
||||
Copyright (C) 2007-2015, International Business Machines Corporation and
|
||||
Copyright (C) 2007-2011, International Business Machines Corporation and
|
||||
Copyright (C) 2007, International Business Machines Corporation and
|
||||
Copyright (c) 1998-2005, International Business Machines Corporation and
|
||||
Copyright (c) 2002-2010, International Business Machines Corporation and
|
||||
Copyright (C) 1999-2016 International Business Machines Corporation and
|
||||
Copyright (c) 2004-2011, International Business Machines Corporation and
|
||||
Copyright (c) 2002-2007, International Business Machines Corporation and
|
||||
Copyright (C) 2003, International Business Machines Corporation and
|
||||
Copyright (C) 2005-2011, International Business Machines
|
||||
Copyright (C) 2011-2012, International Business Machines
|
||||
Copyright (C) 2007-2012, International Business Machines
|
||||
Copyright (C) 2006-2016, International Business Machines Corporation
|
||||
Copyright (C) 2006-2012, International Business Machines Corporation and others.
|
||||
Copyright 2007 Google Inc. All Rights Reserved.
|
||||
Copyright (c) 2001-2015, International Business Machines
|
||||
Copyright (C) 2006-2014, International Business Machines Corporation
|
||||
Copyright (C) 2008, International Business Machines Corporation and
|
||||
Copyright (C) 2009-2012, International Business Machines
|
||||
Copyright (C) 2006 International Business Machines Corporation
|
||||
Copyright (C) 2010-2016, International Business Machines Corporation and
|
||||
Copyright (C) 2002-2014, International Business Machines Corporation and
|
||||
Copyright (C) 2002-2005, International Business Machines Corporation and
|
||||
Copyright (C) 2011, International Business Machines
|
||||
Copyright (c) 2003-2010 International Business Machines
|
||||
Copyright (C) 2003-2003, International Business Machines
|
||||
Copyright (C) 1999-2016 International Business Machines Corporation
|
||||
Copyright (C) 1999-2014 International Business Machines Corporation
|
||||
Copyright (C) 1999-2014 International Business Machines
|
||||
Copyright (C) 2002-2011, International Business Machines Corporation and others.
|
||||
Copyright (C) 2002-2008, International Business Machines Corporation and others.
|
||||
Copyright (C) 2002-2008 International Business Machines Corporation
|
||||
Copyright (c) 2001-2005, International Business Machines
|
||||
Copyright (C) 2002-2014 International Business Machines Corporation
|
||||
Copyright (c) 2003-2011, International Business Machines
|
||||
Copyright (C) 1998-2012, International Business Machines Corporation and
|
||||
Copyright (C) 2001-2014, International Business Machines Corporation.
|
||||
Copyright (C) 2001-2011, International Business Machines Corporation.
|
||||
Copyright (C) 2001-2014, International Business Machines Corporation and
|
||||
Copyright (C) 2001-2011, International Business Machines Corporation and
|
||||
Copyright (C) 2001-2012, International Business Machines Corporation and
|
||||
Copyright 2004 and onwards Google Inc.
|
||||
Copyright (C) 2004-2014, International Business Machines
|
||||
Copyright (C) 2006, International Business Machines
|
||||
Copyright (C) 2004-2012, International Business Machines
|
||||
Copyright (C) 2001-2013, International Business Machines
|
||||
Copyright (C) 1998-2004, International Business Machines
|
||||
Copyright (C) 2000-2013, International Business Machines
|
||||
Copyright (C) 1999-2015 International Business Machines
|
||||
Copyright (C) 2000-2006, International Business Machines
|
||||
Copyright (C) 1999-2004, International Business Machines
|
||||
Copyright (C) 2003-2007, International Business Machines
|
||||
Copyright (C) 2002-2006, International Business Machines
|
||||
Copyright (C) 2001-2015, International Business Machines
|
||||
Copyright (c) 2001-2012, International Business Machines
|
||||
Copyright (c) 2002-2004, International Business Machines
|
||||
Copyright (C) 1999-2016, International Business Machines Corporation and
|
||||
Copyright (c) 1996-2014, International Business Machines
|
||||
Copyright (C) 1999-2016, International Business Machines Corporation
|
||||
Copyright (C) 2009-2014 International Business Machines
|
||||
Copyright (C) 2004-2007, International Business Machines
|
||||
Copyright (c) 2001-2016, International Business Machines
|
||||
Copyright (C) 2003-2009, International Business Machines
|
||||
Copyright (C) 1999-2013, International Business Machines Corporation and
|
||||
Copyright (C) 1999-2015, International Business Machines Corporation and
|
||||
Copyright (c) 2002-2011, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 2001-2016 IBM, Inc. All Rights Reserved.
|
||||
Copyright (C) 1999-2016 International Business Machines
|
||||
Copyright (C) 2009-2010 IBM Corporation and Others. All Rights Reserved.
|
||||
Copyright (C) 1998-2012, International Business Machines
|
||||
Copyright (C) 1991 and later: Unicode, Inc. and others.
|
||||
Copyright (C) 1997-2000, International Business Machines
|
||||
Copyright (c) 1999-2007, International Business Machines Corporation and
|
||||
Copyright (c) 2000 IBM, Inc. and Others.
|
||||
Copyright (C) 2008-2013, International Business Machines
|
||||
Copyright (C) 1998-2003, 2006, International Business Machines Corporation
|
||||
Copyright (c) 2002-2003,International Business Machines
|
||||
Copyright (C) 2009 International Business Machines
|
||||
Copyright (C) 2010-2016 International Business Machines
|
||||
Copyright (C) 2008-2012 IBM, Inc. All Rights Reserved.
|
||||
Copyright (C) 1998-2008, International Business Machines
|
||||
Copyright (C) 2010-2016, International Business Machines
|
||||
Copyright (C) 1999-2006,2013 IBM Corp. All rights reserved.
|
||||
Copyright (C) 2008-2009, International Business Machines Corporation and
|
||||
Copyright (C) 2012,2014 International Business Machines
|
||||
Copyright (c) 1996-2015, International Business Machines Corporation and
|
||||
Copyright (C) 1997-2005, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 1999-2012, International Business Machines Corporation and
|
||||
Copyright (C) 1996-2013, International Business Machines Corporation
|
||||
Copyright (C) 1998-2005, International Business Machines
|
||||
Copyright 2001 and onwards Google Inc.
|
||||
Copyright (C) 2010-2012,2014, International Business Machines
|
||||
Copyright (C) 1996-2015, International Business Machines Corporation and others.
|
||||
Copyright (c) 2003-2004, International Business Machines
|
||||
Copyright (C) 2000-2004, International Business Machines
|
||||
Copyright (C) 2002-2013, International Business Machines
|
||||
Copyright (C) 2002-2011 International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (C) 1999-2010, International Business Machines Corporation and others.
|
||||
Copyright (C) 2001-2005, International Business Machines Corporation and others. All Rights Reserved.
|
||||
Copyright (c) 1996-2016, International Business Machines Corporation
|
||||
Copyright (C) 1997-2010, International Business Machines

Software: libtiff 4.1.0
Copyright notice:
Copyright © 2015 Open Microscopy Environment / University of Dundee
Copyright (c) 2004, Andrey Kiselev <dron@ak4719.spb.edu>
Copyright (c) 1990-1997 Sam Leffler
Copyright (c) 1991-1997 Silicon Graphics, Inc.
Copyright (c) 1988-1997 Sam Leffler
Copyright (c) 1991-1997 Sam Leffler
Use and Copyright
Copyright (C) 1990, 1995 Frank D. Cringle.
Copyright (c) 1994-1997 Sam Leffler
Copyright (c) 1994-1997 Silicon Graphics, Inc.
Copyright (c) 1997 Greg Ward Larson
Copyright (c) 1997 Silicon Graphics, Inc.
Copyright (c) 2010, Andrey Kiselev <dron@ak4719.spb.edu>
Copyright (c) Joris Van Damme <info@awaresystems.be>
Copyright (c) AWare Systems <http://www.awaresystems.be/>
Copyright (c) 1996-1997 Sam Leffler
Copyright (c) 1996 Pixar
Copyright (c) 1995-1997 Sam Leffler
Copyright (c) 1995-1997 Silicon Graphics, Inc.
Copyright (c) 1988-1996 Sam Leffler
Copyright (c) 1991-1996 Silicon Graphics, Inc.
Copyright (c) 1992-1997 Sam Leffler
Copyright (c) 1992-1997 Silicon Graphics, Inc.
Copyright (c) 2018, Mapbox
Copyright (c) 2017, Planet Labs
Copyright (c) 1990 by Sun Microsystems, Inc.
Copyright 1990 by Digital Equipment Corporation, Maynard, Massachusetts.
Copyright 1991 by Digital Equipment Corporation, Maynard, Massachusetts.
Copyright (c) 2002, Andrey Kiselev <dron@ak4719.spb.edu>
Copyright (c) 2003 Ross Finlayson
Additions (c) Richard Nolde 2006-2010
Copyright (c) 2003, Andrey Kiselev <dron@ak4719.spb.edu>
Copyright (c) 2000, Frank Warmerdam
Copyright (c) 1987, 1993, 1994
Copyright (c) 1989, 1993
Copyright (c) 2009 Frank Warmerdam
Copyright (c) 1987, 1993
Copyright (c) 2005 The DragonFly Project. All rights reserved.
Copyright (c) 2003 Citrus Project,
All rights reserved.
Copyright (c) 1990, 1993
Copyright (c) 1996 Mike Johnson
Copyright (c) 1996 BancTec AB
Copyright (c) 2004, Andrey Kiselev <dron@ak4719.spb.edu>
Copyright (c) 2012, Frank Warmerdam <warmerdam@pobox.com>
Copyright (c) 2019, Even Rouault <even.rouault at spatialys.com>
Copyright (c) 2007, Frank Warmerdam <warmerdam@pobox.com>
Copyright (c) 2019, Thomas Bernard <miniupnp@free.fr>
Copyright (c) 2008, Andrey Kiselev <dron@ak4719.spb.edu>
Copyright (c) 1999, Frank Warmerdam
Copyright (c) 1991-1996 Sam Leffler
Copyright (c) 1996 USAF Phillips Laboratory

Software: opencv 4.2.0
Copyright notice:
Copyright (C) 2016, NVIDIA Corporation, all rights reserved.
@@ -0,0 +1 @@
Subproject commit c460176523d039c8995f1d71089753725ebc0792

build.sh (27 changed lines)
@@ -49,10 +49,11 @@ usage()
    echo "    -Q Enable dump memory, default off"
    echo "    -D Enable dumping of function graph ir, default on"
    echo "    -z Compile dataset & mindrecord, default on"
    echo "    -M Enable MPI and NCCL for GPU training, default on"
    echo "    -M Enable MPI and NCCL for GPU training, gpu default on"
    echo "    -V Specify the minimum required cuda version, default CUDA 9.2"
    echo "    -I Compile predict, default off"
    echo "    -K Compile with AKG, default off"
    echo "    -s Enable serving module, default off"
}

# check value of input is 'on' or 'off'
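Putting the new flags together, a possible invocation (illustrative only; exact flag combinations depend on the local toolchain):

```bash
# Sketch: GPU build; -e gpu now also enables CPU and MPI/NCCL per the change below.
bash build.sh -e gpu -V 9.2
# Sketch: Ascend build with the new serving module switched on.
bash build.sh -e ascend -s
```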
@@ -86,15 +87,15 @@ checkopts()
    ENABLE_DUMPE2E="off"
    ENABLE_DUMP_IR="on"
    COMPILE_MINDDATA="on"
    ENABLE_MPI="on"
    ENABLE_MPI="off"
    CUDA_VERSION="9.2"
    COMPILE_PREDICT="off"
    USE_GLOG="on"
    PREDICT_PLATFORM=""
    ENABLE_AKG="off"
    ENABLE_AKG="on"
    ENABLE_SERVING="off"
    # Process the options
    while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K' opt
    while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K:s' opt
    do
        OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
        case "${opt}" in
@@ -168,6 +169,7 @@ checkopts()
            if [[ "X$OPTARG" == "Xgpu" ]]; then
                ENABLE_GPU="on"
                ENABLE_CPU="on"
                ENABLE_MPI="on"
            elif [[ "X$OPTARG" == "Xd" || "X$OPTARG" == "Xascend" ]]; then
                ENABLE_D="on"
                ENABLE_CPU="on"
@@ -234,6 +236,10 @@ checkopts()
                ENABLE_AKG="on"
                echo "enable compile with akg"
                ;;
            s)
                ENABLE_SERVING="on"
                echo "enable serving"
                ;;
            *)
                echo "Unknown option ${opt}!"
                usage
@@ -242,9 +248,12 @@ checkopts()
    done
}
checkopts "$@"
echo "---------------- mindspore: build start ----------------"
echo "---------------- MindSpore: build start ----------------"
mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
git submodule update --init graphengine
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
    git submodule update --init --recursive akg
fi

build_exit()
{
@@ -307,9 +316,13 @@ build_mindspore()
    if [[ "X$USE_GLOG" = "Xon" ]]; then
        CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON"
    fi
    if [[ "X$ENABLE_AKG" = "Xon" ]]; then
    if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
        CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
    fi
    if [[ "X$ENABLE_SERVING" = "Xon" ]]; then
        CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_SERVING=ON"
    fi

    echo "${CMAKE_ARGS}"
    if [[ "X$INC_BUILD" = "Xoff" ]]; then
        cmake ${CMAKE_ARGS} ../..
@@ -36,6 +36,7 @@ elseif (DEFINED ENV{D_LINK_PATH})
    find_library(hccl libhccl.so ${GE_LIB_PATH})
    find_library(cce libcce.so ${GE_LIB_PATH})
    find_library(resource libresource.so ${GE_LIB_PATH})
    find_library(error_manager liberror_manager.so ${GE_LIB_PATH})
else()
    # Ascend mode
    if(DEFINED ENV{ASCEND_CUSTOM_PATH})

@@ -54,6 +55,7 @@ else()
    find_library(msprof libmsprof.so ${ASCEND_RUNTIME_PATH})
    find_library(register libregister.so ${ASCEND_RUNTIME_PATH})
    find_library(resource libresource.so ${ASCEND_RUNTIME_PATH})
    find_library(error_manager liberror_manager.so ${ASCEND_RUNTIME_PATH})
    endif()

    # compile libraries from following directories
@@ -1,4 +1,4 @@
set(gtest_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
set(gtest_CXXFLAGS "-D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
set(gtest_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
mindspore_add_pkg(gtest
        VER 1.8.0
@@ -0,0 +1,19 @@
set(LIB_ICU_COMMON icuuc)
set(LIB_ICU_DATA icudata)
set(LIB_ICU_I18N icui18n)
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
    message("icu4c thirdparty do not support windows currently.")
else()
    mindspore_add_pkg(icu4c
            VER 67.1
            LIBS ${LIB_ICU_COMMON} ${LIB_ICU_DATA} ${LIB_ICU_I18N}
            URL https://github.com/unicode-org/icu/archive/release-67-1.tar.gz
            MD5 0c2662a2b0bc80b0eb56495205247c8f
            CONFIGURE_COMMAND ./icu4c/source/runConfigureICU Linux --enable-rpath --disable-tests --disable-samples --disable-icuio --disable-extras ICU_DATA_FILTER_FILE=${CMAKE_SOURCE_DIR}/third_party/icu4c/filter.json
            )
    include_directories(${icu4c_INC})
    add_library(mindspore::icuuc ALIAS icu4c::${LIB_ICU_COMMON})
    add_library(mindspore::icudata ALIAS icu4c::${LIB_ICU_DATA})
    add_library(mindspore::icui18n ALIAS icu4c::${LIB_ICU_I18N})
    add_definitions(-D ENABLE_ICU4C)
endif()
@ -8,7 +8,7 @@ elseif (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
|
|||
set(opencv_CXXFLAGS "${opencv_CXXFLAGS} -Wno-attributes -Wno-unknown-pragmas")
|
||||
set(opencv_CXXFLAGS "${opencv_CXXFLAGS} -Wno-unused-value -Wno-implicit-fallthrough")
|
||||
else()
|
||||
set(opencv_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2")
|
||||
set(opencv_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
|
||||
set(opencv_CFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -D_FORTIFY_SOURCE=2 -O2")
|
||||
set(opencv_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
|
||||
endif()
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
set(protobuf_USE_STATIC_LIBS ON)
|
||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2")
|
||||
else()
|
||||
elseif (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
|
||||
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2")
|
||||
else()
|
||||
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
|
||||
endif()
|
||||
|
||||
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
|
||||
set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS})
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
if (WIN32)
|
||||
mindspore_add_pkg(sqlite
|
||||
VER 3.31.1
|
||||
VER 3.32.2
|
||||
LIBS sqlite3
|
||||
URL https://sqlite.org/2020/sqlite-amalgamation-3310100.zip
|
||||
MD5 2b7bfcdd97dc281903a9aee966213fe4
|
||||
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.windows.patch001 ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.windows.patch002 ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.windows.patch003
|
||||
URL https://sqlite.org/2020/sqlite-amalgamation-3320200.zip
|
||||
MD5 1eccea18d248eb34c7378b2b3f63f1db
|
||||
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.windows.patch001
|
||||
CMAKE_OPTION " "
|
||||
)
|
||||
|
||||
|
@ -18,11 +18,11 @@ else ()
|
|||
endif()
|
||||
set(sqlite_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
|
||||
mindspore_add_pkg(sqlite
|
||||
VER 3.31.1
|
||||
VER 3.32.2
|
||||
LIBS sqlite3
|
||||
URL https://github.com/sqlite/sqlite/archive/version-3.31.1.tar.gz
|
||||
MD5 5f4e7b4016c15f4fb5855615279819da
|
||||
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.patch001 ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.patch002 ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.patch003
|
||||
URL https://github.com/sqlite/sqlite/archive/version-3.32.2.tar.gz
|
||||
MD5 ea6d3b3289b4ac216fb06081a01ef101
|
||||
PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/sqlite/sqlite.patch001
|
||||
CONFIGURE_COMMAND ./configure --enable-shared=no --disable-tcl --disable-editline --enable-json1)
|
||||
endif ()
|
||||
|
||||
|
|
|
@ -26,6 +26,9 @@ include_directories(${Python3_INCLUDE_DIRS})
|
|||
include_directories(${CMAKE_SOURCE_DIR}/third_party)
|
||||
if (ENABLE_CPU)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/mkl_dnn.cmake)
|
||||
if (ENABLE_MPI)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ompi.cmake)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (ENABLE_GPU)
|
||||
|
@ -36,7 +39,6 @@ if (ENABLE_GPU)
|
|||
|
||||
if (ENABLE_MPI)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/nccl.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ompi.cmake)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
@ -52,6 +54,7 @@ elseif(ENABLE_D OR ENABLE_TESTCASES)
|
|||
endif()
|
||||
|
||||
if (ENABLE_MINDDATA)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/icu4c.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/jpeg_turbo.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/libtiff.cmake)
|
||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/opencv.cmake)
|
||||
|
|
|
@ -91,7 +91,20 @@ if (ENABLE_MINDDATA)
|
|||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
message("icu4c does not support windows system temporarily")
|
||||
else()
|
||||
file(GLOB_RECURSE ICU4C_LIB_LIST
|
||||
${icu4c_LIBPATH}/libicuuc*
|
||||
${icu4c_LIBPATH}/libicudata*
|
||||
${icu4c_LIBPATH}/libicui18n*
|
||||
)
|
||||
install(
|
||||
FILES ${ICU4C_LIB_LIST}
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
if (ENABLE_CPU)
|
||||
|
@ -109,19 +122,20 @@ if (ENABLE_CPU)
|
|||
)
|
||||
endif ()
|
||||
|
||||
if (ENABLE_MPI)
|
||||
install(
|
||||
TARGETS _ms_mpi
|
||||
DESTINATION ${INSTALL_BASE_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (ENABLE_GPU)
|
||||
if (ENABLE_MPI)
|
||||
install(
|
||||
TARGETS _ms_mpi
|
||||
DESTINATION ${INSTALL_BASE_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
install(
|
||||
TARGETS gpu_collective
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
endif ()
|
||||
install(
|
||||
TARGETS gpu_queue
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
|
@ -222,6 +236,16 @@ if (ENABLE_GPU)
|
|||
endif ()
|
||||
endif ()
|
||||
|
||||
if (ENABLE_D AND ENABLE_AKG)
|
||||
set (AKG_PATH ${CMAKE_SOURCE_DIR}/build/mindspore/akg)
|
||||
install(
|
||||
DIRECTORY
|
||||
${AKG_PATH}/akg
|
||||
DESTINATION ${INSTALL_PY_DIR}/..
|
||||
COMPONENT mindspore
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset)
|
||||
install(
|
||||
DIRECTORY ${CMAKE_SOURCE_DIR}/mindspore/dataset
|
||||
|
|
|
@ -51,7 +51,7 @@ endif ()
|
|||
# get git commit id
|
||||
set(GIT_COMMIT_ID "")
|
||||
execute_process(
|
||||
COMMAND ${GIT} log --format='[sha1]:%h,[branch]:%d' -1
|
||||
COMMAND ${GIT} log --format='[sha1]:%h,[branch]:%d' --abbrev=8 -1
|
||||
OUTPUT_VARIABLE GIT_COMMIT_ID
|
||||
WORKING_DIRECTORY ${MS_ROOT_DIR}
|
||||
ERROR_QUIET)
|
||||
|
|
|
@ -1,106 +0,0 @@
|
|||
# Googlenet Example
|
||||
|
||||
## Description
|
||||
|
||||
This example is for Googlenet model training and evaluation.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||
|
||||
- Download the CIFAR-10 binary version dataset.
|
||||
|
||||
> Unzip the CIFAR-10 dataset to any path you want and the folder structure should be as follows:
|
||||
> ```
|
||||
> .
|
||||
> ├── cifar-10-batches-bin # train dataset
|
||||
> └── cifar-10-verify-bin # infer dataset
|
||||
> ```
|
||||
|
||||
## Running the Example
|
||||
|
||||
### Training
|
||||
|
||||
```
|
||||
python train.py --data_path=your_data_path --device_id=6 > out.train.log 2>&1 &
|
||||
```
|
||||
The Python command above runs in the background; you can view the results in the file `out.train.log`.
|
||||
|
||||
After training, you'll get some checkpoint files under the script folder by default.
|
||||
|
||||
You will get the loss value as follows:
|
||||
```
|
||||
# grep "loss is " out.train.log
|
||||
epoch: 1 step: 390, loss is 1.4842823
|
||||
epcoh: 2 step: 390, loss is 1.0897788
|
||||
...
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
```
|
||||
python eval.py --data_path=your_data_path --device_id=6 --checkpoint_path=./train_googlenet_cifar10-125-390.ckpt > out.eval.log 2>&1 &
|
||||
```
|
||||
The Python command above runs in the background; you can view the results in the file `out.eval.log`.
|
||||
|
||||
You will get the accuracy as follows:
|
||||
```
|
||||
# grep "result: " out.eval.log
|
||||
result: {'acc': 0.934}
|
||||
```
|
||||
|
||||
### Distribute Training
|
||||
```
|
||||
sh run_distribute_train.sh rank_table.json your_data_path
|
||||
```
|
||||
The shell script above runs distributed training in the background; you can view the results in the file `train_parallel[X]/log`.
|
||||
|
||||
You will get the loss value as follows:
|
||||
```
|
||||
# grep "result: " train_parallel*/log
|
||||
train_parallel0/log:epoch: 1 step: 48, loss is 1.4302931
|
||||
train_parallel0/log:epcoh: 2 step: 48, loss is 1.4023874
|
||||
...
|
||||
train_parallel1/log:epoch: 1 step: 48, loss is 1.3458025
|
||||
train_parallel1/log:epcoh: 2 step: 48, loss is 1.3729336
|
||||
...
|
||||
...
|
||||
```
|
||||
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
|
||||
|
||||
## Usage:
|
||||
|
||||
### Training
|
||||
```
|
||||
usage: train.py [--device_target TARGET][--data_path DATA_PATH]
|
||||
[--device_id DEVICE_ID]
|
||||
|
||||
parameters/options:
|
||||
--device_target the training backend type, default is Ascend.
|
||||
--data_path the storage path of dataset
|
||||
--device_id the device used to train the model.
|
||||
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
```
|
||||
usage: eval.py [--device_target TARGET][--data_path DATA_PATH]
|
||||
[--device_id DEVICE_ID][--checkpoint_path CKPT_PATH]
|
||||
|
||||
parameters/options:
|
||||
--device_target the evaluation backend type, default is Ascend.
|
||||
--data_path the storage path of the dataset
|
||||
--device_id the device used to evaluate the model.
|
||||
--checkpoint_path the checkpoint file path used to evaluate model.
|
||||
```
|
||||
|
||||
### Distribute Training
|
||||
|
||||
```
|
||||
Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH]
|
||||
|
||||
parameters/options:
|
||||
MINDSPORE_HCCL_CONFIG_PATH HCCL configuration file path.
|
||||
DATA_PATH the storage path of dataset.
|
||||
```
|
|
@ -24,9 +24,6 @@ This example provides an efficient way to generate MindRecord. Users only need t
|
|||
|
||||
1. Download and prepare the Cora dataset as required.
|
||||
|
||||
> [Cora dataset download address](https://github.com/jzaldi/datasets/tree/master/cora)
|
||||
|
||||
|
||||
2. Edit write_cora.sh and modify the parameters
|
||||
```
|
||||
--mindrecord_file: output MindRecord file.
|
||||
|
|
|
@ -15,29 +15,27 @@
|
|||
"""
|
||||
User-defined API for MindRecord GNN writer.
|
||||
"""
|
||||
import csv
|
||||
import os
|
||||
|
||||
import pickle as pkl
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
from mindspore import log as logger
|
||||
|
||||
# parse args from command line parameter 'graph_api_args'
|
||||
# args delimiter is ':'
|
||||
args = os.environ['graph_api_args'].split(':')
|
||||
CITESEER_CONTENT_FILE = args[0]
|
||||
CITESEER_CITES_FILE = args[1]
|
||||
CITESEER_MINDRECRD_LABEL_FILE = CITESEER_CONTENT_FILE + "_label_mindrecord"
|
||||
CITESEER_MINDRECRD_ID_MAP_FILE = CITESEER_CONTENT_FILE + "_id_mindrecord"
|
||||
|
||||
node_id_map = {}
|
||||
CITESEER_PATH = args[0]
|
||||
dataset_str = 'citeseer'
|
||||
|
||||
# profile: (num_features, feature_data_types, feature_shapes)
|
||||
node_profile = (2, ["float32", "int64"], [[-1], [-1]])
|
||||
node_profile = (2, ["float32", "int32"], [[-1], [-1]])
|
||||
edge_profile = (0, [], [])
|
||||
|
||||
node_ids = []
|
||||
|
||||
|
||||
def _normalize_citeseer_features(features):
|
||||
features = np.array(features)
|
||||
row_sum = np.array(features.sum(1))
|
||||
r_inv = np.power(row_sum * 1.0, -1).flatten()
|
||||
r_inv[np.isinf(r_inv)] = 0.
|
||||
|
@ -46,6 +44,14 @@ def _normalize_citeseer_features(features):
|
|||
return features
|
||||
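# Added note: the row normalization above divides each feature row by its own
# sum and leaves all-zero rows untouched (their inf reciprocal is reset to 0).
# Assuming the elided middle of the function applies the usual
# sp.diags(r_inv).dot(features) scaling, e.g. [[1., 1., 2.]] -> [[0.25, 0.25, 0.5]].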
|
||||
|
||||
def _parse_index_file(filename):
|
||||
"""Parse index file."""
|
||||
index = []
|
||||
for line in open(filename):
|
||||
index.append(int(line.strip()))
|
||||
return index
|
||||
|
||||
|
||||
def yield_nodes(task_id=0):
|
||||
"""
|
||||
Generate node data
|
||||
|
@ -53,30 +59,47 @@ def yield_nodes(task_id=0):
|
|||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Node task is {}".format(task_id))
|
||||
label_types = {}
|
||||
label_size = 0
|
||||
node_num = 0
|
||||
with open(CITESEER_CONTENT_FILE) as content_file:
|
||||
content_reader = csv.reader(content_file, delimiter='\t')
|
||||
line_count = 0
|
||||
for row in content_reader:
|
||||
if not row[-1] in label_types:
|
||||
label_types[row[-1]] = label_size
|
||||
label_size += 1
|
||||
if not row[0] in node_id_map:
|
||||
node_id_map[row[0]] = node_num
|
||||
node_num += 1
|
||||
raw_features = [[int(x) for x in row[1:-1]]]
|
||||
node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_citeseer_features(raw_features),
|
||||
'feature_2': [label_types[row[-1]]]}
|
||||
yield node
|
||||
line_count += 1
|
||||
print('Processed {} lines for nodes.'.format(line_count))
|
||||
# print('label types {}.'.format(label_types))
|
||||
with open(CITESEER_MINDRECRD_LABEL_FILE, 'w') as f:
|
||||
for k in label_types:
|
||||
print(k + ',' + str(label_types[k]), file=f)
|
||||
logger.info("Node task is {}".format(task_id))
|
||||
names = ['x', 'y', 'tx', 'ty', 'allx', 'ally']
|
||||
objects = []
|
||||
for name in names:
|
||||
with open("{}/ind.{}.{}".format(CITESEER_PATH, dataset_str, name), 'rb') as f:
|
||||
objects.append(pkl.load(f, encoding='latin1'))
|
||||
x, y, tx, ty, allx, ally = tuple(objects)
|
||||
test_idx_reorder = _parse_index_file(
|
||||
"{}/ind.{}.test.index".format(CITESEER_PATH, dataset_str))
|
||||
test_idx_range = np.sort(test_idx_reorder)
|
||||
|
||||
tx = _normalize_citeseer_features(tx)
|
||||
allx = _normalize_citeseer_features(allx)
|
||||
|
||||
# Fix citeseer dataset (there are some isolated nodes in the graph)
|
||||
# Find isolated nodes, add them as zero-vecs into the right position
|
||||
test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
|
||||
tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
|
||||
tx_extended[test_idx_range-min(test_idx_range), :] = tx
|
||||
tx = tx_extended
|
||||
ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
|
||||
ty_extended[test_idx_range-min(test_idx_range), :] = ty
|
||||
ty = ty_extended
|
||||
|
||||
features = sp.vstack((allx, tx)).tolil()
|
||||
features[test_idx_reorder, :] = features[test_idx_range, :]
|
||||
features = features.A
|
||||
|
||||
labels = np.vstack((ally, ty))
|
||||
labels[test_idx_reorder, :] = labels[test_idx_range, :]
|
||||
|
||||
line_count = 0
|
||||
for i, label in enumerate(labels):
|
||||
if not 1 in label.tolist():
|
||||
continue
|
||||
node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(),
|
||||
'feature_2': label.tolist().index(1)}
|
||||
line_count += 1
|
||||
node_ids.append(i)
|
||||
yield node
|
||||
logger.info('Processed {} lines for nodes.'.format(line_count))
|
||||
|
||||
|
||||
def yield_edges(task_id=0):
|
||||
|
@ -86,24 +109,21 @@ def yield_edges(task_id=0):
|
|||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Edge task is {}".format(task_id))
|
||||
# print(map_string_int)
|
||||
with open(CITESEER_CITES_FILE) as cites_file:
|
||||
cites_reader = csv.reader(cites_file, delimiter='\t')
|
||||
logger.info("Edge task is {}".format(task_id))
|
||||
with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f:
|
||||
graph = pkl.load(f, encoding='latin1')
|
||||
line_count = 0
|
||||
for row in cites_reader:
|
||||
if not row[0] in node_id_map:
|
||||
print('Source node {} does not exist.'.format(row[0]))
|
||||
continue
|
||||
if not row[1] in node_id_map:
|
||||
print('Destination node {} does not exist.'.format(row[1]))
|
||||
continue
|
||||
line_count += 1
|
||||
edge = {'id': line_count,
|
||||
'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0}
|
||||
yield edge
|
||||
|
||||
with open(CITESEER_MINDRECRD_ID_MAP_FILE, 'w') as f:
|
||||
for k in node_id_map:
|
||||
print(k + ',' + str(node_id_map[k]), file=f)
|
||||
print('Processed {} lines for edges.'.format(line_count))
|
||||
for i in graph:
|
||||
for dst_id in graph[i]:
|
||||
if not i in node_ids:
|
||||
logger.info('Source node {} does not exist.'.format(i))
|
||||
continue
|
||||
if not dst_id in node_ids:
|
||||
logger.info('Destination node {} does not exist.'.format(
|
||||
dst_id))
|
||||
continue
|
||||
edge = {'id': line_count,
|
||||
'src_id': i, 'dst_id': dst_id, 'type': 0}
|
||||
line_count += 1
|
||||
yield edge
|
||||
logger.info('Processed {} lines for edges.'.format(line_count))
|
||||
|
|
|
@ -15,29 +15,24 @@
|
|||
"""
|
||||
User-defined API for MindRecord GNN writer.
|
||||
"""
|
||||
import csv
|
||||
import os
|
||||
|
||||
import pickle as pkl
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
|
||||
# parse args from command line parameter 'graph_api_args'
|
||||
# args delimiter is ':'
|
||||
args = os.environ['graph_api_args'].split(':')
|
||||
CORA_CONTENT_FILE = args[0]
|
||||
CORA_CITES_FILE = args[1]
|
||||
CORA_MINDRECRD_LABEL_FILE = CORA_CONTENT_FILE + "_label_mindrecord"
|
||||
CORA_CONTENT_ID_MAP_FILE = CORA_CONTENT_FILE + "_id_mindrecord"
|
||||
|
||||
node_id_map = {}
|
||||
CORA_PATH = args[0]
|
||||
dataset_str = 'cora'
|
||||
|
||||
# profile: (num_features, feature_data_types, feature_shapes)
|
||||
node_profile = (2, ["float32", "int64"], [[-1], [-1]])
|
||||
node_profile = (2, ["float32", "int32"], [[-1], [-1]])
|
||||
edge_profile = (0, [], [])
|
||||
|
||||
|
||||
def _normalize_cora_features(features):
|
||||
features = np.array(features)
|
||||
row_sum = np.array(features.sum(1))
|
||||
r_inv = np.power(row_sum * 1.0, -1).flatten()
|
||||
r_inv[np.isinf(r_inv)] = 0.
|
||||
|
@ -46,6 +41,14 @@ def _normalize_cora_features(features):
|
|||
return features
|
||||
|
||||
|
||||
def _parse_index_file(filename):
|
||||
"""Parse index file."""
|
||||
index = []
|
||||
for line in open(filename):
|
||||
index.append(int(line.strip()))
|
||||
return index
|
||||
|
||||
|
||||
def yield_nodes(task_id=0):
|
||||
"""
|
||||
Generate node data
|
||||
|
@ -54,32 +57,32 @@ def yield_nodes(task_id=0):
|
|||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Node task is {}".format(task_id))
|
||||
label_types = {}
|
||||
label_size = 0
|
||||
node_num = 0
|
||||
with open(CORA_CONTENT_FILE) as content_file:
|
||||
content_reader = csv.reader(content_file, delimiter=',')
|
||||
line_count = 0
|
||||
for row in content_reader:
|
||||
if line_count == 0:
|
||||
line_count += 1
|
||||
continue
|
||||
if not row[0] in node_id_map:
|
||||
node_id_map[row[0]] = node_num
|
||||
node_num += 1
|
||||
if not row[-1] in label_types:
|
||||
label_types[row[-1]] = label_size
|
||||
label_size += 1
|
||||
raw_features = [[int(x) for x in row[1:-1]]]
|
||||
node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_cora_features(raw_features),
|
||||
'feature_2': [label_types[row[-1]]]}
|
||||
yield node
|
||||
line_count += 1
|
||||
|
||||
names = ['tx', 'ty', 'allx', 'ally']
|
||||
objects = []
|
||||
for name in names:
|
||||
with open("{}/ind.{}.{}".format(CORA_PATH, dataset_str, name), 'rb') as f:
|
||||
objects.append(pkl.load(f, encoding='latin1'))
|
||||
tx, ty, allx, ally = tuple(objects)
|
||||
test_idx_reorder = _parse_index_file(
|
||||
"{}/ind.{}.test.index".format(CORA_PATH, dataset_str))
|
||||
test_idx_range = np.sort(test_idx_reorder)
|
||||
|
||||
features = sp.vstack((allx, tx)).tolil()
|
||||
features[test_idx_reorder, :] = features[test_idx_range, :]
|
||||
features = _normalize_cora_features(features)
|
||||
features = features.A
|
||||
|
||||
labels = np.vstack((ally, ty))
|
||||
labels[test_idx_reorder, :] = labels[test_idx_range, :]
|
||||
|
||||
line_count = 0
|
||||
for i, label in enumerate(labels):
|
||||
node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(),
|
||||
'feature_2': label.tolist().index(1)}
|
||||
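# Added note: label is a one-hot row from `labels`, so .index(1) turns it into a
# class index, e.g. [0, 0, 1, 0, 0, 0, 0] -> 2 (illustrative values).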
line_count += 1
|
||||
yield node
|
||||
print('Processed {} lines for nodes.'.format(line_count))
|
||||
print('label types {}.'.format(label_types))
|
||||
with open(CORA_MINDRECRD_LABEL_FILE, 'w') as f:
|
||||
for k in label_types:
|
||||
print(k + ',' + str(label_types[k]), file=f)
|
||||
|
||||
|
||||
def yield_edges(task_id=0):
|
||||
|
@ -90,24 +93,13 @@ def yield_edges(task_id=0):
|
|||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Edge task is {}".format(task_id))
|
||||
with open(CORA_CITES_FILE) as cites_file:
|
||||
cites_reader = csv.reader(cites_file, delimiter=',')
|
||||
with open("{}/ind.{}.graph".format(CORA_PATH, dataset_str), 'rb') as f:
|
||||
graph = pkl.load(f, encoding='latin1')
|
||||
line_count = 0
|
||||
for row in cites_reader:
|
||||
if line_count == 0:
|
||||
for i in graph:
|
||||
for dst_id in graph[i]:
|
||||
edge = {'id': line_count,
|
||||
'src_id': i, 'dst_id': dst_id, 'type': 0}
|
||||
line_count += 1
|
||||
continue
|
||||
if not row[0] in node_id_map:
|
||||
print('Source node {} does not exist.'.format(row[0]))
|
||||
continue
|
||||
if not row[1] in node_id_map:
|
||||
print('Destination node {} does not exist.'.format(row[1]))
|
||||
continue
|
||||
edge = {'id': line_count,
|
||||
'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0}
|
||||
yield edge
|
||||
line_count += 1
|
||||
yield edge
|
||||
print('Processed {} lines for edges.'.format(line_count))
|
||||
with open(CORA_CONTENT_ID_MAP_FILE, 'w') as f:
|
||||
for k in node_id_map:
|
||||
print(k + ',' + str(node_id_map[k]), file=f)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
Graph data convert tool for MindRecord.
|
||||
"""
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
|
||||
__all__ = ['GraphMapSchema']
|
||||
|
||||
|
@ -41,6 +42,7 @@ class GraphMapSchema:
|
|||
"edge_feature_index": {"type": "int32", "shape": [-1]}
|
||||
}
|
||||
|
||||
@property
|
||||
def get_schema(self):
|
||||
"""
|
||||
Get schema
|
||||
|
@ -52,6 +54,7 @@ class GraphMapSchema:
|
|||
Set node features profile
|
||||
"""
|
||||
if num_features != len(features_data_type) or num_features != len(features_shape):
|
||||
logger.info("Node feature profile is not match.")
|
||||
raise ValueError("Node feature profile is not match.")
|
||||
|
||||
self.num_node_features = num_features
|
||||
|
@ -66,6 +69,7 @@ class GraphMapSchema:
|
|||
Set edge features profile
|
||||
"""
|
||||
if num_features != len(features_data_type) or num_features != len(features_shape):
|
||||
logger.info("Edge feature profile is not match.")
|
||||
raise ValueError("Edge feature profile is not match.")
|
||||
|
||||
self.num_edge_features = num_features
|
||||
|
@ -83,6 +87,10 @@ class GraphMapSchema:
|
|||
Returns:
|
||||
graph data with union schema
|
||||
"""
|
||||
if node is None:
|
||||
logger.info("node cannot be None.")
|
||||
raise ValueError("node cannot be None.")
|
||||
|
||||
node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"],
|
||||
"node_feature_index": []}
|
||||
for i in range(self.num_node_features):
|
||||
|
@ -117,6 +125,10 @@ class GraphMapSchema:
|
|||
Returns:
|
||||
graph data with union schema
|
||||
"""
|
||||
if edge is None:
|
||||
logger.info("edge cannot be None.")
|
||||
raise ValueError("edge cannot be None.")
|
||||
|
||||
edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e',
|
||||
"type": edge["type"], "edge_feature_index": []}
|
||||
|
||||
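The profile tuples defined by each mr_api module must line up with the length checks above; a small standalone illustration (not the MindSpore class itself):

```python
# Standalone sketch of the (num_features, feature_data_types, feature_shapes) convention
# validated by set_node_feature_profile / set_edge_feature_profile.
node_profile = (2, ["float32", "int32"], [[-1], [-1]])   # as in the cora/citeseer mr_api above
edge_profile = (0, [], [])

def check_profile(profile):
    num_features, data_types, shapes = profile
    if num_features != len(data_types) or num_features != len(shapes):
        raise ValueError("feature profile does not match")
    return num_features

assert check_profile(node_profile) == 2
assert check_profile(edge_profile) == 0
```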
|
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
User-defined API for MindRecord GNN writer.
|
||||
"""
|
||||
social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
|
||||
[348, 336], [348, 337], [348, 338], [348, 340], [348, 341],
|
||||
[348, 342], [348, 343], [348, 344], [348, 345], [348, 346],
|
||||
[348, 347], [347, 351], [347, 327], [347, 329], [347, 331],
|
||||
[347, 335], [347, 341], [347, 345], [347, 346], [346, 335],
|
||||
[346, 340], [346, 339], [346, 349], [346, 353], [346, 354],
|
||||
[346, 341], [346, 345], [345, 335], [345, 336], [345, 341],
|
||||
[344, 338], [344, 342], [343, 332], [343, 338], [343, 342],
|
||||
[342, 332], [340, 349], [334, 349], [333, 349], [330, 349],
|
||||
[328, 349], [359, 349], [358, 352], [358, 349], [358, 354],
|
||||
[358, 356], [357, 350], [357, 354], [357, 356], [356, 350],
|
||||
[355, 352], [353, 350], [352, 349], [351, 349], [350, 349]]
|
||||
|
||||
# profile: (num_features, feature_data_types, feature_shapes)
|
||||
node_profile = (0, [], [])
|
||||
edge_profile = (0, [], [])
|
||||
|
||||
|
||||
def yield_nodes(task_id=0):
|
||||
"""
|
||||
Generate node data
|
||||
|
||||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Node task is {}".format(task_id))
|
||||
node_list = []
|
||||
for edge in social_data:
|
||||
src, dst = edge
|
||||
if src not in node_list:
|
||||
node_list.append(src)
|
||||
if dst not in node_list:
|
||||
node_list.append(dst)
|
||||
node_list.sort()
|
||||
print(node_list)
|
||||
for node_id in node_list:
|
||||
node = {'id': node_id, 'type': 1}
|
||||
yield node
|
||||
|
||||
|
||||
def yield_edges(task_id=0):
|
||||
"""
|
||||
Generate edge data
|
||||
|
||||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Edge task is {}".format(task_id))
|
||||
line_count = 0
|
||||
for undirected_edge in social_data:
|
||||
line_count += 1
|
||||
edge = {
|
||||
'id': line_count,
|
||||
'src_id': undirected_edge[0],
|
||||
'dst_id': undirected_edge[1],
|
||||
'type': 1}
|
||||
yield edge
|
||||
line_count += 1
|
||||
edge = {
|
||||
'id': line_count,
|
||||
'src_id': undirected_edge[1],
|
||||
'dst_id': undirected_edge[0],
|
||||
'type': 1}
|
||||
yield edge
|
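Because yield_edges emits one record per direction, every undirected pair in social_data produces two directed edge rows; a quick standalone check of that expansion:

```python
# Standalone sketch of the undirected -> directed expansion done by yield_edges above.
social_data = [[348, 350], [348, 327]]          # first two pairs from the list above
records = []
line_count = 0
for src, dst in social_data:
    for s, d in ((src, dst), (dst, src)):
        line_count += 1
        records.append({'id': line_count, 'src_id': s, 'dst_id': d, 'type': 1})
assert len(records) == 2 * len(social_data)
```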
|
@ -9,4 +9,4 @@ python writer.py --mindrecord_script citeseer \
|
|||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 18 \
|
||||
--mindrecord_page_size_by_bit 20 \
|
||||
--graph_api_args "$SRC_PATH/citeseer.content:$SRC_PATH/citeseer.cites"
|
||||
--graph_api_args "$SRC_PATH"
|
||||
|
|
|
@ -9,4 +9,4 @@ python writer.py --mindrecord_script cora \
|
|||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 18 \
|
||||
--mindrecord_page_size_by_bit 20 \
|
||||
--graph_api_args "$SRC_PATH/cora_content.csv:$SRC_PATH/cora_cites.csv"
|
||||
--graph_api_args "$SRC_PATH"
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
#!/bin/bash
|
||||
MINDRECORD_PATH=/tmp/sns
|
||||
|
||||
rm -f $MINDRECORD_PATH/*
|
||||
|
||||
python writer.py --mindrecord_script sns \
|
||||
--mindrecord_file "$MINDRECORD_PATH/sns" \
|
||||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 14 \
|
||||
--mindrecord_page_size_by_bit 15
|
|
@ -164,7 +164,7 @@ if __name__ == "__main__":
|
|||
num_features, feature_data_types, feature_shapes = mr_api.edge_profile
|
||||
graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes)
|
||||
|
||||
graph_schema = graph_map_schema.get_schema()
|
||||
graph_schema = graph_map_schema.get_schema
|
||||
|
||||
# init writer
|
||||
writer = init_writer(graph_schema)
|
||||
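The switch from `get_schema()` to `get_schema` matches the new `@property` decorator added to GraphMapSchema above; a minimal reminder of the difference:

```python
# Minimal illustration: with @property, the schema is read as an attribute, not called.
class Demo:
    @property
    def get_schema(self):
        return {"first_id": {"type": "int64"}}

schema = Demo().get_schema        # attribute access, as in the updated writer.py
# Demo().get_schema() would raise TypeError, since the property already returned the dict.
```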
|
|
|
@ -0,0 +1,82 @@
|
|||
# Guideline to Convert Training Data CLUERNER2020 to MindRecord For Bert Fine Tuning
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [What does the example do](#what-does-the-example-do)
|
||||
- [How to use the example to process CLUERNER2020](#how-to-use-the-example-to-process-cluerner2020)
|
||||
- [Download CLUERNER2020 and unzip](#download-cluerner2020-and-unzip)
|
||||
- [Generate MindRecord](#generate-mindrecord)
|
||||
- [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord)
|
||||
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
## What does the example do
|
||||
|
||||
This example takes the [CLUERNER2020](https://www.cluebenchmarks.com/introduce.html) training data, generates MindRecord files, and finally uses them for the BERT fine-tuning process.
|
||||
|
||||
1. run.sh: generate MindRecord entry script
|
||||
2. run_read.sh: entry script that creates a MindDataset from the MindRecord files.
|
||||
- create_dataset.py: use MindDataset to read MindRecord to generate dataset.
|
||||
|
||||
## How to use the example to process CLUERNER2020
|
||||
|
||||
Download CLUERNER2020, convert it to MindRecord, use MindDataset to read MindRecord.
|
||||
|
||||
### Download CLUERNER2020 and unzip
|
||||
|
||||
1. Download the training data zip.
|
||||
> [CLUERNER2020 dataset download address](https://www.cluebenchmarks.com/introduce.html) **-> 任务介绍 -> CLUENER 细粒度命名实体识别 -> cluener下载链接**
|
||||
|
||||
2. Unzip the training data to the directory example/nlp_to_mindrecord/CLUERNER2020/data/cluener_public.
|
||||
```
|
||||
unzip -d {your-mindspore}/example/nlp_to_mindrecord/CLUERNER2020/data/cluener_public cluener_public.zip
|
||||
```
|
||||
|
||||
### Generate MindRecord
|
||||
|
||||
1. Run the run.sh script.
|
||||
```bash
|
||||
bash run.sh
|
||||
```
|
||||
|
||||
2. The output looks like this:
|
||||
```
|
||||
...
|
||||
[INFO] ME(17603:139620983514944,MainProcess):2020-04-28-16:56:12.498.235 [mindspore/mindrecord/filewriter.py:313] The list of mindrecord files created are: ['data/train.mindrecord'], and the list of index files are: ['data/train.mindrecord.db']
|
||||
...
|
||||
[INFO] ME(17603,python):2020-04-28-16:56:13.400.175 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
|
||||
[INFO] ME(17603,python):2020-04-28-16:56:13.400.863 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
|
||||
[INFO] ME(17603,python):2020-04-28-16:56:13.401.534 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
|
||||
[INFO] ME(17603,python):2020-04-28-16:56:13.402.179 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
|
||||
[INFO] ME(17603,python):2020-04-28-16:56:13.402.702 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
|
||||
...
|
||||
[INFO] ME(17603:139620983514944,MainProcess):2020-04-28-16:56:13.431.208 [mindspore/mindrecord/filewriter.py:313] The list of mindrecord files created are: ['data/dev.mindrecord'], and the list of index files are: ['data/dev.mindrecord.db']
|
||||
```
|
||||
|
||||
3. Files like the following are generated:
|
||||
```bash
|
||||
$ ls output/
|
||||
dev.mindrecord dev.mindrecord.db README.md train.mindrecord train.mindrecord.db
|
||||
```
|
||||
|
||||
### Create MindDataset By MindRecord
|
||||
|
||||
1. Run the run_read.sh script.
|
||||
```bash
|
||||
bash run_read.sh
|
||||
```
|
||||
|
||||
2. The output looks like this:
|
||||
```
|
||||
...
|
||||
example 1340: input_ids: [ 101 3173 1290 4852 7676 3949 122 3299 123 126 3189 4510 8020 6381 5442 7357 2590 3636 8021 7676 3949 4294 1166 6121 3124 1277 6121 3124 7270 2135 3295 5789 3326 123 126 3189 1355 6134 1093 1325 3173 2399 6590 6791 8024 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1340: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1340: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1340: label_ids: [ 0 18 19 20 2 4 0 0 0 0 0 0 0 34 36 26 27 28 0 34 35 35 35 35 35 35 35 35 35 36 26 27 28 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1341: input_ids: [ 101 1728 711 4293 3868 1168 2190 2150 3791 934 3633 3428 4638 6237 7025 8024 3297 1400 5310 3362 6206 5023 5401 1744 3297 7770 3791 7368 976 1139 1104 2137 511 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1341: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1341: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 1341: label_ids: [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 18 19 19 19 19 20 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
...
|
||||
```
|
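If you only need to verify how many examples landed in each file, a smaller variant of the create_dataset.py added below is enough (a sketch, assuming the same mindspore.dataset API):

```python
import mindspore.dataset as ds

def count_rows(data_file):
    """Count examples in a MindRecord file without printing every field."""
    data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=4, shuffle=False)
    return sum(1 for _ in data_set.create_dict_iterator())

if __name__ == '__main__':
    print('train rows:', count_rows('output/train.mindrecord'))
    print('dev rows:', count_rows('output/dev.mindrecord'))
```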
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""create MindDataset by MindRecord"""
|
||||
import mindspore.dataset as ds
|
||||
|
||||
def create_dataset(data_file):
|
||||
"""create MindDataset"""
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True)
|
||||
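# Added note: each item from create_dict_iterator() below is a dict keyed by the
# columns written for fine-tuning (input_ids, input_mask, segment_ids, label_ids).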
index = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
# print("example {}: {}".format(index, item))
|
||||
print("example {}: input_ids: {}".format(index, item['input_ids']))
|
||||
print("example {}: input_mask: {}".format(index, item['input_mask']))
|
||||
print("example {}: segment_ids: {}".format(index, item['segment_ids']))
|
||||
print("example {}: label_ids: {}".format(index, item['label_ids']))
|
||||
index += 1
|
||||
if index % 1000 == 0:
|
||||
print("read rows: {}".format(index))
|
||||
print("total rows: {}".format(index))
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_dataset('output/train.mindrecord')
|
||||
create_dataset('output/dev.mindrecord')
|
|
@ -0,0 +1 @@
|
|||
cluener_public
|
|
@ -0,0 +1 @@
|
|||
## The input dataset
|
|
@ -0,0 +1 @@
|
|||
## output dir
|
|
@ -0,0 +1,40 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
rm -f output/train.mindrecord*
|
||||
rm -f output/dev.mindrecord*
|
||||
|
||||
if [ ! -d "../../../third_party/to_mindrecord/CLUERNER2020" ]; then
|
||||
echo "The patch base dir ../../../third_party/to_mindrecord/CLUERNER2020 is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "../../../third_party/patch/to_mindrecord/CLUERNER2020/data_processor_seq.patch" ]; then
|
||||
echo "The patch file ../../../third_party/patch/to_mindrecord/CLUERNER2020/data_processor_seq.patch is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# patch for data_processor_seq.py
|
||||
patch -p0 -d ../../../third_party/to_mindrecord/CLUERNER2020/ -o data_processor_seq_patched.py < ../../../third_party/patch/to_mindrecord/CLUERNER2020/data_processor_seq.patch
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Patch ../../../third_party/to_mindrecord/CLUERNER2020/data_processor_seq.py failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# use patched script
|
||||
python ../../../third_party/to_mindrecord/CLUERNER2020/data_processor_seq_patched.py \
|
||||
--vocab_file=../../../third_party/to_mindrecord/CLUERNER2020/vocab.txt \
|
||||
--label2id_file=../../../third_party/to_mindrecord/CLUERNER2020/label2id.json
|
|
@ -0,0 +1,17 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
python create_dataset.py
|
|
@ -0,0 +1,173 @@
|
|||
# Guideline to Convert Training Data enwiki to MindRecord For Bert Pre Training
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [What does the example do](#what-does-the-example-do)
|
||||
- [How to use the example to process enwiki](#how-to-use-the-example-to-process-enwiki)
|
||||
- [Download enwiki training data](#download-enwiki-training-data)
|
||||
- [Process the enwiki](#process-the-enwiki)
|
||||
- [Generate MindRecord](#generate-mindrecord)
|
||||
- [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord)
|
||||
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
## What does the example do
|
||||
|
||||
This example takes the [enwiki](https://dumps.wikimedia.org/enwiki) training data, generates MindRecord files, and finally uses them for BERT network training.
|
||||
|
||||
1. run.sh: generate MindRecord entry script.
|
||||
2. run_read.sh: entry script that creates a MindDataset from the MindRecord files.
|
||||
- create_dataset.py: use MindDataset to read MindRecord to generate dataset.
|
||||
|
||||
## How to use the example to process enwiki
|
||||
|
||||
Download enwiki data, process it, convert it to MindRecord, use MindDataset to read MindRecord.
|
||||
|
||||
### Download enwiki training data
|
||||
|
||||
> [enwiki dataset download address](https://dumps.wikimedia.org/enwiki) **-> 20200501 -> enwiki-20200501-pages-articles-multistream.xml.bz2**
|
||||
|
||||
### Process the enwiki
|
||||
|
||||
1. Please follow the steps in [process enwiki](https://github.com/mlperf/training/tree/master/language_model/tensorflow/bert)
|
||||
- All rights for this step belong to the linked website.
|
||||
|
||||
### Generate MindRecord
|
||||
|
||||
1. Run the run.sh script.
|
||||
```
|
||||
bash run.sh input_dir output_dir vocab_file
|
||||
```
|
||||
- input_dir: the directory which contains files like 'part-00251-of-00500'.
|
||||
- output_dir: the directory that will store the output MindRecord files.
|
||||
- vocab_file: the vocabulary file, which you can download from another open-source project.
|
||||
|
||||
2. The output looks like this:
|
||||
```
|
||||
...
|
||||
Begin preprocess Wed Jun 10 09:21:23 CST 2020
|
||||
Begin preprocess input file: /mnt/data/results/part-00000-of-00500
|
||||
Begin output file: part-00000-of-00500.mindrecord
|
||||
Total task: 510, processing: 1
|
||||
Begin preprocess input file: /mnt/data/results/part-00001-of-00500
|
||||
Begin output file: part-00001-of-00500.mindrecord
|
||||
Total task: 510, processing: 2
|
||||
Begin preprocess input file: /mnt/data/results/part-00002-of-00500
|
||||
Begin output file: part-00002-of-00500.mindrecord
|
||||
Total task: 510, processing: 3
|
||||
Begin preprocess input file: /mnt/data/results/part-00003-of-00500
|
||||
Begin output file: part-00003-of-00500.mindrecord
|
||||
Total task: 510, processing: 4
|
||||
Begin preprocess input file: /mnt/data/results/part-00004-of-00500
|
||||
Begin output file: part-00004-of-00500.mindrecord
|
||||
Total task: 510, processing: 4
|
||||
...
|
||||
```
|
||||
|
||||
3. Files like the following are generated:
|
||||
```bash
|
||||
$ ls {your_output_dir}/
|
||||
part-00000-of-00500.mindrecord part-00000-of-00500.mindrecord.db part-00001-of-00500.mindrecord part-00001-of-00500.mindrecord.db part-00002-of-00500.mindrecord part-00002-of-00500.mindrecord.db ...
|
||||
```
|
||||
|
||||
### Create MindDataset By MindRecord
|
||||
|
||||
1. Run the run_read.sh script.
|
||||
```bash
|
||||
bash run_read.sh input_dir
|
||||
```
|
||||
- input_dir: the directory which contains mindrecord files.
|
||||
|
||||
2. The output looks like this:
|
||||
```
|
||||
...
|
||||
example 633: input_ids: [ 101 2043 19781 4305 2140 4520 2041 1010 103 2034 2455 2002
|
||||
7879 2003 1996 2455 1997 103 26378 4160 1012 102 7291 2001
|
||||
1996 103 1011 2343 1997 6327 1010 3423 1998 103 4262 2005
|
||||
1996 2118 1997 2329 3996 103 102 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0]
|
||||
example 633: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
|
||||
1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 633: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
|
||||
1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
|
||||
example 633: masked_lm_positions: [ 8 17 20 25 33 41 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0]
|
||||
example 633: masked_lm_ids: [ 1996 16137 1012 3580 2451 1012 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0 0 0 0 0 0 0 0 0
|
||||
0 0 0 0]
|
||||
example 633: masked_lm_weights: [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
|
||||
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
|
||||
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
|
||||
0. 0. 0. 0.]
|
||||
example 633: next_sentence_labels: [1]
|
||||
...
|
||||
```
|
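run_read.sh simply collects every part-*.mindrecord file and hands the whole list to create_dataset.py; the Python equivalent is roughly the following (a sketch with a hypothetical output path, assuming the same API as the script below):

```python
import glob
import mindspore.dataset as ds

# Hypothetical output directory; run.sh writes one part-xxxxx-of-00500.mindrecord per input file.
files = sorted(glob.glob('/mnt/data/output/part-*.mindrecord'))
data_set = ds.MindDataset(dataset_file=files, num_parallel_workers=4, shuffle=True)
print('total rows:', sum(1 for _ in data_set.create_dict_iterator()))
```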
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""create MindDataset by MindRecord"""
|
||||
import argparse
|
||||
import mindspore.dataset as ds
|
||||
|
||||
def create_dataset(data_file):
|
||||
"""create MindDataset"""
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True)
|
||||
index = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
# print("example {}: {}".format(index, item))
|
||||
print("example {}: input_ids: {}".format(index, item['input_ids']))
|
||||
print("example {}: input_mask: {}".format(index, item['input_mask']))
|
||||
print("example {}: segment_ids: {}".format(index, item['segment_ids']))
|
||||
print("example {}: masked_lm_positions: {}".format(index, item['masked_lm_positions']))
|
||||
print("example {}: masked_lm_ids: {}".format(index, item['masked_lm_ids']))
|
||||
print("example {}: masked_lm_weights: {}".format(index, item['masked_lm_weights']))
|
||||
print("example {}: next_sentence_labels: {}".format(index, item['next_sentence_labels']))
|
||||
index += 1
|
||||
if index % 1000 == 0:
|
||||
print("read rows: {}".format(index))
|
||||
print("total rows: {}".format(index))
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-i", "--input_file", nargs='+', type=str, help='Input mindreord file')
|
||||
args = parser.parse_args()
|
||||
|
||||
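# Added note: --input_file gathers one or more MindRecord paths into a list (nargs='+'),
# and MindDataset accepts that list directly, so all files are read as one dataset.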
create_dataset(args.input_file)
|
|
@ -0,0 +1,133 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 3 ]; then
|
||||
echo "Usage: $0 input_dir output_dir vocab_file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $1 ]; then
|
||||
echo "The input dir: $1 is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $2 ]; then
|
||||
echo "The output dir: $2 is not exist."
|
||||
exit 1
|
||||
fi
|
||||
rm -fr $2/*.mindrecord*
|
||||
|
||||
if [ ! -f $3 ]; then
|
||||
echo "The vocab file: $3 is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
data_dir=$1
|
||||
output_dir=$2
|
||||
vocab_file=$3
|
||||
file_list=()
|
||||
output_filename=()
|
||||
file_index=0
|
||||
|
||||
function getdir() {
|
||||
elements=`ls $1`
|
||||
for element in ${elements[*]};
|
||||
do
|
||||
dir_or_file=$1"/"$element
|
||||
if [ -d $dir_or_file ];
|
||||
then
|
||||
getdir $dir_or_file
|
||||
else
|
||||
file_list[$file_index]=$dir_or_file
|
||||
echo "${dir_or_file}" | tr '/' '\n' > dir_file_list.txt # dir dir file to mapfile
|
||||
mapfile parent_dir < dir_file_list.txt
|
||||
rm dir_file_list.txt >/dev/null 2>&1
|
||||
tmp_output_filename=${parent_dir[${#parent_dir[@]}-1]}".mindrecord"
|
||||
output_filename[$file_index]=`echo ${tmp_output_filename} | sed 's/ //g'`
|
||||
file_index=`expr $file_index + 1`
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
getdir "${data_dir}"
|
||||
# echo "The input files: "${file_list[@]}
|
||||
# echo "The output files: "${output_filename[@]}
|
||||
|
||||
if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then
|
||||
echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then
|
||||
echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# patch for create_pretraining_data.py
|
||||
patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# get the cpu core count
|
||||
num_cpu_core=`cat /proc/cpuinfo | grep "processor" | wc -l`
|
||||
avaiable_core_size=`expr $num_cpu_core / 3 \* 2`
|
||||
|
||||
echo "Begin preprocess `date`"
|
||||
|
||||
# using patched script to generate mindrecord
|
||||
file_list_len=`expr ${#file_list[*]} - 1`
|
||||
for index in $(seq 0 $file_list_len); do
|
||||
echo "Begin preprocess input file: ${file_list[$index]}"
|
||||
echo "Begin output file: ${output_filename[$index]}"
|
||||
python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
|
||||
--input_file=${file_list[$index]} \
|
||||
--output_file=${output_dir}/${output_filename[$index]} \
|
||||
--partition_number=1 \
|
||||
--vocab_file=${vocab_file} \
|
||||
--do_lower_case=True \
|
||||
--max_seq_length=512 \
|
||||
--max_predictions_per_seq=76 \
|
||||
--masked_lm_prob=0.15 \
|
||||
--random_seed=12345 \
|
||||
--dupe_factor=10 >/tmp/${output_filename[$index]}.log 2>&1 &
|
||||
process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
|
||||
echo "Total task: ${#file_list[*]}, processing: ${process_count}"
|
||||
if [ $process_count -ge $avaiable_core_size ]; then
|
||||
while [ 1 ]; do
|
||||
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
|
||||
if [ $process_count -gt $process_num ]; then
|
||||
process_count=$process_num
|
||||
break;
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
fi
|
||||
done
|
||||
|
||||
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
|
||||
while [ 1 ]; do
|
||||
if [ $process_num -eq 0 ]; then
|
||||
break;
|
||||
fi
|
||||
echo "There are still ${process_num} preprocess running ..."
|
||||
sleep 2
|
||||
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
|
||||
done
|
||||
|
||||
echo "Preprocess all the data success."
|
||||
echo "End preprocess `date`"
|
|
@ -0,0 +1,44 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 input_dir"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $1 ]; then
|
||||
echo "The input dir: $1 is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
file_list=()
|
||||
file_index=0
|
||||
|
||||
# get all the mindrecord file from output dir
|
||||
function getdir() {
|
||||
elements=`ls $1/part-*.mindrecord`
|
||||
for element in ${elements[*]};
|
||||
do
|
||||
file_list[$file_index]=$element
|
||||
file_index=`expr $file_index + 1`
|
||||
done
|
||||
}
|
||||
|
||||
getdir $1
|
||||
echo "Get all the mindrecord files: "${file_list[*]}
|
||||
|
||||
# create dataset for train
|
||||
python create_dataset.py --input_file ${file_list[*]}
|
|
@ -0,0 +1,113 @@
|
|||
# Guideline to Convert Training Data zhwiki to MindRecord For Bert Pre Training
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [What does the example do](#what-does-the-example-do)
|
||||
- [Run simple test](#run-simple-test)
|
||||
- [How to use the example to process zhwiki](#how-to-use-the-example-to-process-zhwiki)
|
||||
- [Download zhwiki training data](#download-zhwiki-training-data)
|
||||
- [Extract the zhwiki](#extract-the-zhwiki)
|
||||
- [Generate MindRecord](#generate-mindrecord)
|
||||
- [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord)
|
||||
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
## What does the example do
|
||||
|
||||
This example takes the [zhwiki](https://dumps.wikimedia.org/zhwiki) training data, generates MindRecord files, and finally uses them for BERT network training.
|
||||
|
||||
1. run.sh: generate MindRecord entry script.
|
||||
2. run_read.sh: entry script that creates a MindDataset from the MindRecord files.
|
||||
- create_dataset.py: use MindDataset to read MindRecord to generate dataset.
|
||||
|
||||
## Run simple test
|
||||
|
||||
Follow these steps:
|
||||
|
||||
```bash
|
||||
bash run_simple.sh # generate output/simple.mindrecord* by ../../../third_party/to_mindrecord/zhwiki/sample_text.txt
|
||||
bash run_read_simple.sh # use MindDataset to read output/simple.mindrecord*
|
||||
```
|
||||
|
||||
## How to use the example to process zhwiki
|
||||
|
||||
Download zhwiki data, extract it, convert it to MindRecord, use MindDataset to read MindRecord.
|
||||
|
||||
### Download zhwiki training data
|
||||
|
||||
> [zhwiki dataset download address](https://dumps.wikimedia.org/zhwiki) **-> 20200401 -> zhwiki-20200401-pages-articles-multistream.xml.bz2**
|
||||
|
||||
- Put the zhwiki-20200401-pages-articles-multistream.xml.bz2 file in the {your-mindspore}/example/nlp_to_mindrecord/zhwiki/data directory.
|
||||
|
||||
### Extract the zhwiki
|
||||
|
||||
1. Download the [wikiextractor](https://github.com/attardi/wikiextractor) script to the {your-mindspore}/example/nlp_to_mindrecord/zhwiki/data directory.
|
||||
|
||||
```
|
||||
$ ls data/
|
||||
README.md wikiextractor zhwiki-20200401-pages-articles-multistream.xml.bz2
|
||||
```
|
||||
|
||||
2. Extract the zhwiki.
|
||||
```bash
|
||||
python data/wikiextractor/WikiExtractor.py data/zhwiki-20200401-pages-articles-multistream.xml.bz2 --processes 4 --templates data/template --bytes 8M --min_text_length 0 --filter_disambig_pages --output data/extract
|
||||
```
|
||||
|
||||
3. The extraction generates directories like this:
|
||||
```
|
||||
$ ls data/extract
|
||||
AA AB
|
||||
```
|
||||
|
||||
### Generate MindRecord
|
||||
|
||||
1. Run the run.sh script.
|
||||
```
|
||||
bash run.sh
|
||||
```
|
||||
> Caution: This process may be slow; please wait patiently. If you do not have a machine with enough memory and CPU cores, it is recommended that you modify the script to generate the MindRecord files step by step.
|
||||
|
||||
2. The output looks like this:
|
||||
```
|
||||
patching file create_pretraining_data_patched.py (read from create_pretraining_data.py)
|
||||
Begin preprocess input file: ./data/extract/AA/wiki_00
|
||||
Begin output file: AAwiki_00.mindrecord
|
||||
Total task: 5, processing: 1
|
||||
Begin preprocess input file: ./data/extract/AA/wiki_01
|
||||
Begin output file: AAwiki_01.mindrecord
|
||||
Total task: 5, processing: 2
|
||||
Begin preprocess input file: ./data/extract/AA/wiki_02
|
||||
Begin output file: AAwiki_02.mindrecord
|
||||
Total task: 5, processing: 3
|
||||
Begin preprocess input file: ./data/extract/AB/wiki_02
|
||||
Begin output file: ABwiki_02.mindrecord
|
||||
Total task: 5, processing: 4
|
||||
...
|
||||
```
|
||||
|
||||
3. The generated files look like this:
|
||||
```bash
|
||||
$ ls output/
|
||||
AAwiki_00.mindrecord AAwiki_00.mindrecord.db AAwiki_01.mindrecord AAwiki_01.mindrecord.db AAwiki_02.mindrecord AAwiki_02.mindrecord.db ... ABwiki_00.mindrecord ABwiki_00.mindrecord.db ...
|
||||
```
|
||||
|
||||
### Create MindDataset By MindRecord
|
||||
|
||||
1. Run the run_read.sh script.
|
||||
```bash
|
||||
bash run_read.sh
|
||||
```
|
||||
|
||||
2. The output looks like this:
|
||||
```
|
||||
...
|
||||
example 74: input_ids: [ 101 8168 118 12847 8783 9977 15908 117 8256 9245 11643 8168 8847 8588 11575 8154 8228 143 8384 8376 9197 10241 103 10564 11421 8199 12268 112 161 8228 11541 9586 8436 8174 8363 9864 9702 103 103 119 103 9947 10564 103 8436 8806 11479 103 8912 119 103 103 103 12209 8303 103 8757 8824 117 8256 103 8619 8168 11541 102 11684 8196 103 8228 8847 11523 117 9059 9064 12410 8358 8181 10764 117 11167 11706 9920 148 8332 11390 8936 8205 10951 11997 103 8154 117 103 8670 10467 112 161 10951 13139 12413 117 10288 143 10425 8205 152 10795 8472 8196 103 161 12126 9172 13129 12106 8217 8174 12244 8205 143 103 8461 8277 10628 160 8221 119 102]
|
||||
example 74: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
|
||||
example 74: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
|
||||
example 74: masked_lm_positions: [ 6 22 37 38 40 43 47 50 51 52 55 60 67 76 89 92 98 109 120 0]
|
||||
example 74: masked_lm_ids: [ 8118 8165 8329 8890 8554 8458 119 8850 8565 10392 8174 11467 10291 8181 8549 12718 13139 112 158 0]
|
||||
example 74: masked_lm_weights: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]
|
||||
example 74: next_sentence_labels: [0]
|
||||
...
|
||||
```
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""create MindDataset by MindRecord"""
|
||||
import argparse
|
||||
import mindspore.dataset as ds
|
||||
|
||||
def create_dataset(data_file):
|
||||
"""create MindDataset"""
|
||||
num_readers = 4
|
||||
data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True)
|
||||
index = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
# print("example {}: {}".format(index, item))
|
||||
print("example {}: input_ids: {}".format(index, item['input_ids']))
|
||||
print("example {}: input_mask: {}".format(index, item['input_mask']))
|
||||
print("example {}: segment_ids: {}".format(index, item['segment_ids']))
|
||||
print("example {}: masked_lm_positions: {}".format(index, item['masked_lm_positions']))
|
||||
print("example {}: masked_lm_ids: {}".format(index, item['masked_lm_ids']))
|
||||
print("example {}: masked_lm_weights: {}".format(index, item['masked_lm_weights']))
|
||||
print("example {}: next_sentence_labels: {}".format(index, item['next_sentence_labels']))
|
||||
index += 1
|
||||
if index % 1000 == 0:
|
||||
print("read rows: {}".format(index))
|
||||
print("total rows: {}".format(index))
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-i", "--input_file", nargs='+', type=str, help='Input mindreord file')
|
||||
args = parser.parse_args()
|
||||
|
||||
create_dataset(args.input_file)
|
|
@ -0,0 +1,3 @@
|
|||
wikiextractor/
|
||||
zhwiki-20200401-pages-articles-multistream.xml.bz2
|
||||
extract/
|
|
@ -0,0 +1 @@
|
|||
## The input dataset
|
|
@ -0,0 +1 @@
|
|||
## Output the mindrecord
|
|
@ -0,0 +1,112 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
rm -f output/*.mindrecord*
|
||||
|
||||
data_dir="./data/extract"
|
||||
file_list=()
|
||||
output_filename=()
|
||||
file_index=0
|
||||
|
||||
function getdir() {
|
||||
elements=`ls $1`
|
||||
for element in ${elements[*]};
|
||||
do
|
||||
dir_or_file=$1"/"$element
|
||||
if [ -d $dir_or_file ];
|
||||
then
|
||||
getdir $dir_or_file
|
||||
else
|
||||
file_list[$file_index]=$dir_or_file
|
||||
echo "${dir_or_file}" | tr '/' '\n' > dir_file_list.txt # dir dir file to mapfile
|
||||
mapfile parent_dir < dir_file_list.txt
|
||||
rm dir_file_list.txt >/dev/null 2>&1
|
||||
tmp_output_filename=${parent_dir[${#parent_dir[@]}-2]}${parent_dir[${#parent_dir[@]}-1]}".mindrecord"
|
||||
output_filename[$file_index]=`echo ${tmp_output_filename} | sed 's/ //g'`
|
||||
file_index=`expr $file_index + 1`
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
getdir "${data_dir}"
|
||||
# echo "The input files: "${file_list[@]}
|
||||
# echo "The output files: "${output_filename[@]}
|
||||
|
||||
if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then
|
||||
echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then
|
||||
echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# patch for create_pretraining_data.py
|
||||
patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# get the cpu core count
|
||||
num_cpu_core=`cat /proc/cpuinfo | grep "processor" | wc -l`
|
||||
available_core_size=`expr $num_cpu_core / 3 \* 2` # throttle to about two thirds of the CPU cores
|
||||
|
||||
echo "Begin preprocess `date`"
|
||||
|
||||
# using patched script to generate mindrecord
|
||||
file_list_len=`expr ${#file_list[*]} - 1`
|
||||
for index in $(seq 0 $file_list_len); do
|
||||
echo "Begin preprocess input file: ${file_list[$index]}"
|
||||
echo "Begin output file: ${output_filename[$index]}"
|
||||
python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
|
||||
--input_file=${file_list[$index]} \
|
||||
--output_file=output/${output_filename[$index]} \
|
||||
--partition_number=1 \
|
||||
--vocab_file=../../../third_party/to_mindrecord/zhwiki/vocab.txt \
|
||||
--do_lower_case=True \
|
||||
--max_seq_length=128 \
|
||||
--max_predictions_per_seq=20 \
|
||||
--masked_lm_prob=0.15 \
|
||||
--random_seed=12345 \
|
||||
--dupe_factor=10 >/tmp/${output_filename[$index]}.log 2>&1 & # user defined
|
||||
process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
|
||||
echo "Total task: ${#file_list[*]}, processing: ${process_count}"
|
||||
if [ $process_count -ge $available_core_size ]; then
|
||||
while [ 1 ]; do
|
||||
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
|
||||
if [ $process_count -gt $process_num ]; then
|
||||
process_count=$process_num
|
||||
break;
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
fi
|
||||
done
|
||||
|
||||
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
|
||||
while [ 1 ]; do
|
||||
if [ $process_num -eq 0 ]; then
|
||||
break;
|
||||
fi
|
||||
echo "There are still ${process_num} preprocess running ..."
|
||||
sleep 2
|
||||
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
|
||||
done
|
||||
|
||||
echo "Preprocess all the data success."
|
||||
echo "End preprocess `date`"
|
|
@ -0,0 +1,34 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
file_list=()
|
||||
file_index=0
|
||||
|
||||
# get all the mindrecord file from output dir
|
||||
function getdir() {
|
||||
elements=`ls $1/[A-Z]*.mindrecord`
|
||||
for element in ${elements[*]};
|
||||
do
|
||||
file_list[$file_index]=$element
|
||||
file_index=`expr $file_index + 1`
|
||||
done
|
||||
}
|
||||
|
||||
getdir "./output"
|
||||
echo "Get all the mindrecord files: "${file_list[*]}
|
||||
|
||||
# create dataset for train
|
||||
python create_dataset.py --input_file ${file_list[*]}
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# create dataset for train
|
||||
python create_dataset.py --input_file=output/simple.mindrecord0
|
|
@ -0,0 +1,47 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
rm -f output/simple.mindrecord*
|
||||
|
||||
if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then
|
||||
echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then
|
||||
echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch is not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# patch for create_pretraining_data.py
|
||||
patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# using patched script to generate mindrecord
|
||||
python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
|
||||
--input_file=../../../third_party/to_mindrecord/zhwiki/sample_text.txt \
|
||||
--output_file=output/simple.mindrecord \
|
||||
--partition_number=4 \
|
||||
--vocab_file=../../../third_party/to_mindrecord/zhwiki/vocab.txt \
|
||||
--do_lower_case=True \
|
||||
--max_seq_length=128 \
|
||||
--max_predictions_per_seq=20 \
|
||||
--masked_lm_prob=0.15 \
|
||||
--random_seed=12345 \
|
||||
--dupe_factor=10 # user defined
|
|
@ -15,6 +15,7 @@
|
|||
"""train_imagenet."""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from dataset import create_dataset
|
||||
from lr_generator import get_lr
|
||||
from config import config
|
||||
|
@ -45,6 +46,7 @@ if __name__ == '__main__':
|
|||
target = args_opt.device_target
|
||||
ckpt_save_dir = config.save_checkpoint_path
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||
np.random.seed(1)
|
||||
if not args_opt.do_eval and args_opt.run_distribute:
|
||||
if target == "Ascend":
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""train_imagenet."""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from dataset import create_dataset
|
||||
from lr_generator import get_lr
|
||||
from config import config
|
||||
|
@ -48,6 +49,7 @@ if __name__ == '__main__':
|
|||
target = args_opt.device_target
|
||||
ckpt_save_dir = config.save_checkpoint_path
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||
np.random.seed(1)
|
||||
if not args_opt.do_eval and args_opt.run_distribute:
|
||||
if target == "Ascend":
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
|
@ -77,12 +79,12 @@ if __name__ == '__main__':
|
|||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
|
||||
cell.weight.default_input.shape(),
|
||||
cell.weight.default_input.dtype()).to_tensor()
|
||||
cell.weight.default_input.shape,
|
||||
cell.weight.default_input.dtype).to_tensor()
|
||||
if isinstance(cell, nn.Dense):
|
||||
cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
|
||||
cell.weight.default_input.shape(),
|
||||
cell.weight.default_input.dtype()).to_tensor()
|
||||
cell.weight.default_input.shape,
|
||||
cell.weight.default_input.dtype).to_tensor()
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""Dataset help for minddata dataset"""
|
||||
from mindspore._checkparam import check_bool
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
|
||||
from mindspore.train.dataset_helper import _send_data
|
||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
|
||||
_to_full_shapes
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
|
@ -67,7 +68,13 @@ class _DatasetIter:
|
|||
self.loop_size = dataset.get_dataset_size()
|
||||
else:
|
||||
self.loop_size = dataset.__loop_size__
|
||||
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
|
||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
|
||||
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
||||
|
||||
if not hasattr(dataset, '__no_send__'):
|
||||
_send_data(dataset)
|
||||
else:
|
||||
_send_data(dataset)
|
||||
|
||||
self.ind = 0
|
||||
self.dataset = dataset
|
||||
|
|
|
@ -29,7 +29,7 @@ from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
|||
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
||||
from mindspore.train import amp
|
||||
from mindspore.train.callback import _InternalCallbackParam, RunContext, _build_callbacks
|
||||
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
|
||||
from model.dataset_helper import DatasetHelper
|
||||
|
@ -374,7 +374,6 @@ class Model:
|
|||
self._train_network.set_broadcast_flag()
|
||||
|
||||
# build callback list
|
||||
list_callback = _build_callbacks(callbacks)
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.epoch_num = epoch
|
||||
|
@ -385,17 +384,17 @@ class Model:
|
|||
cb_params.parallel_mode = self._parallel_mode
|
||||
cb_params.device_number = self._device_number
|
||||
cb_params.train_dataset = train_dataset
|
||||
cb_params.list_callback = list_callback
|
||||
cb_params.list_callback = callbacks
|
||||
|
||||
if dataset_sink_mode:
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if not dataset_sink_mode:
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
elif context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
logger.warning("The pynative mode cannot support dataset sink mode currently."
|
||||
"So the training process will be performed with dataset not sink.")
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
else:
|
||||
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
|
||||
else:
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
|
||||
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
|
@ -408,7 +407,7 @@ class Model:
|
|||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (_ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
"""
|
||||
iter_first_order = self._frequency - 1
|
||||
|
@ -473,7 +472,7 @@ class Model:
|
|||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (_ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
"""
|
||||
dataset_helper, _ = self._exec_preprocess(self._train_network,
|
||||
|
@ -580,7 +579,7 @@ class Model:
|
|||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
list_callback (ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
|
||||
Returns:
|
||||
|
@ -619,7 +618,7 @@ class Model:
|
|||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
list_callback (ListCallback): Executor of callback list. Default: None.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
|
||||
Returns:
|
||||
|
@ -678,7 +677,6 @@ class Model:
|
|||
if not self._metric_fns:
|
||||
raise ValueError("metric fn can not be None or empty.")
|
||||
|
||||
list_callback = _build_callbacks(callbacks)
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.eval_network = self._eval_network
|
||||
cb_params.valid_dataset = valid_dataset
|
||||
|
@ -691,9 +689,10 @@ class Model:
|
|||
|
||||
self._clear_metrics()
|
||||
|
||||
if dataset_sink_mode:
|
||||
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
||||
return self._eval_process(valid_dataset, list_callback, cb_params)
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if dataset_sink_mode:
|
||||
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
||||
return self._eval_process(valid_dataset, list_callback, cb_params)
|
||||
|
||||
def predict(self, *predict_data):
|
||||
"""
|
||||
|
|
|
@ -151,6 +151,8 @@ class THOR(Optimizer):
|
|||
temp_g = self.mul(temp_g, matrix_G_inv_max)
|
||||
temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i])
|
||||
temp_max = self.mul(temp_max, self.feature_map[i])
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
if i == 53:
|
||||
g = self.cube_matmul_left_fc(temp_g, g)
|
||||
g = self.cube_matmul_right_fc(g, temp_a, temp_max)
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""thor_layer"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import check_bool, twice, check_int_positive
|
||||
|
@ -23,7 +25,6 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.nn.cell import Cell
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from mindspore.ops import operations as P
|
||||
import numpy as np
|
||||
C0 = 16
|
||||
|
||||
def caculate_device_shape(matrix_dim, channel, is_A):
|
||||
|
@ -171,7 +172,6 @@ class Conv2d_Thor(_Conv):
|
|||
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
|
||||
self.fake_G = Tensor(
|
||||
np.reshape(np.identity(self.matrix_G_device_dim).astype(np.float16), self.matrix_G_device_shape))
|
||||
self.fake_G_inv_max = Tensor(np.zeros([1,]).astype(np.float32))
|
||||
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
|
@ -286,7 +286,6 @@ class Conv2d_Thor(_Conv):
|
|||
matrix_A_inv = self.device_shape_pad(matrix_A_inv)
|
||||
matrix_A_inv = self.reshape(matrix_A_inv, self.matrix_A_device_temp_shape)
|
||||
matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))
|
||||
self.G_inv_max = self.fake_G_inv_max
|
||||
self.matrix_A_inv = matrix_A_inv
|
||||
self.matrix_G_inv = self.fake_G
|
||||
out = self.conv2d(x, self.weight)
|
||||
|
@ -339,15 +338,15 @@ class Dense_Thor(Cell):
|
|||
self.has_bias = check_bool(has_bias)
|
||||
self.thor = True
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
|
||||
weight_init.shape()[1] != in_channels:
|
||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||
weight_init.shape[1] != in_channels:
|
||||
raise ValueError("weight_init shape error")
|
||||
|
||||
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
||||
|
||||
if self.has_bias:
|
||||
if isinstance(bias_init, Tensor):
|
||||
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
|
||||
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
|
||||
raise ValueError("bias_init shape error")
|
||||
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
|
||||
|
|
|
@ -17,6 +17,8 @@ import argparse
|
|||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init
|
||||
|
@ -28,7 +30,6 @@ from model.model_thor import Model
|
|||
from model.resnet import resnet50
|
||||
from model.thor import THOR
|
||||
|
||||
import numpy as np
|
||||
from config import config
|
||||
from crossentropy import CrossEntropy
|
||||
from dataset_imagenet import create_dataset
|
||||
|
|
|
@ -1,88 +0,0 @@
|
|||
# SSD Example
|
||||
|
||||
## Description
|
||||
|
||||
SSD network based on MobileNetV2, with support for training and evaluation.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||
|
||||
- Dataset
|
||||
|
||||
We use coco2017 as the training dataset in this example by default, but you can also use your own dataset.
|
||||
|
||||
1. If the coco dataset is used, **set the dataset to coco when running the script.**
|
||||
Install Cython and pycocotools.
|
||||
|
||||
```
|
||||
pip install Cython
|
||||
|
||||
pip install pycocotools
|
||||
```
|
||||
Then change `COCO_ROOT` and any other settings you need in `config.py`. The directory structure is as follows:
|
||||
|
||||
|
||||
```
|
||||
└─coco2017
|
||||
├── annotations # annotation jsons
|
||||
├── train2017 # train dataset
|
||||
└── val2017 # infer dataset
|
||||
```
|
||||
|
||||
2. If your own dataset is used, **set the dataset to other when running the script.**
|
||||
Organize the dataset information into a TXT file; each row in the file is as follows:
|
||||
|
||||
```
|
||||
train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
|
||||
```
|
||||
|
||||
Each row is a space-separated image annotation: the first column is the relative path of the image, and the remaining columns are boxes and classes in the format [xmin,ymin,xmax,ymax,class]. Images are read from the path obtained by joining `IMAGE_DIR` (the dataset directory) with the relative path listed in `ANNO_PATH` (the TXT file path); both `IMAGE_DIR` and `ANNO_PATH` are set in `config.py`. A minimal parsing sketch is shown below.
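A minimal sketch of how one such row can be parsed (plain Python, just to make the format concrete; the line below is the sample from above):

```python
# Parse one annotation row: "<relative image path> <xmin,ymin,xmax,ymax,class> ..."
line = "train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2"
parts = line.strip().split(' ')
image_path = parts[0]                                      # relative path of the image
boxes = [list(map(int, p.split(','))) for p in parts[1:]]  # [xmin, ymin, xmax, ymax, class]
print(image_path, boxes)
```

This mirrors what `anno_parser` and `filter_valid_data` in `dataset.py` do when building the annotation dictionary.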
|
||||
|
||||
|
||||
## Running the example
|
||||
|
||||
### Training
|
||||
|
||||
To train the model, run `train.py`. If `MINDRECORD_DIR` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/converting_datasets.html) files from `COCO_ROOT` (coco dataset) or from `IMAGE_DIR` and `ANNO_PATH` (your own dataset). **Note: if `MINDRECORD_DIR` is not empty, it will use the files in `MINDRECORD_DIR` instead of the raw images.** A sketch of this data pipeline is shown below.
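A hedged sketch of that pipeline, using the helper functions defined in `dataset.py` of this example (the standalone usage and the MindRecord path are assumptions; `train.py` wires these steps together itself):

```python
from dataset import data_to_mindrecord_byte_image, create_ssd_dataset

# Convert the raw COCO images and annotations into 8 sharded MindRecord files
# (the MINDRECORD_DIR configured in config.py must already exist).
data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8)

# Build the training dataset from the generated MindRecord (path assumed from
# MINDRECORD_DIR="MindRecord_COCO" and the prefix above).
ds = create_ssd_dataset("MindRecord_COCO/ssd.mindrecord0", batch_size=32, repeat_num=1, is_training=True)
print("batches per epoch:", ds.get_dataset_size())
```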
|
||||
|
||||
|
||||
- Standalone mode
|
||||
|
||||
```
|
||||
python train.py --dataset coco
|
||||
|
||||
```
|
||||
|
||||
You can run `python train.py -h` to get more information.
|
||||
|
||||
|
||||
- Distributed mode
|
||||
|
||||
```
|
||||
sh run_distribute_train.sh 8 150 coco /data/hccl.json
|
||||
```
|
||||
|
||||
The input parameters are the number of devices, the epoch size, the dataset mode, and the [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute paths.**
|
||||
|
||||
You will get the loss value of each step as follows:
|
||||
|
||||
```
|
||||
epoch: 1 step: 455, loss is 5.8653416
|
||||
epoch: 2 step: 455, loss is 5.4292373
|
||||
epoch: 3 step: 455, loss is 5.458992
|
||||
...
|
||||
epoch: 148 step: 455, loss is 1.8340507
|
||||
epoch: 149 step: 455, loss is 2.0876894
|
||||
epoch: 150 step: 455, loss is 2.239692
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
For evaluation, run `eval.py` with `ckpt_path`. `ckpt_path` is the path of the [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.
|
||||
|
||||
```
|
||||
python eval.py --ckpt_path ssd.ckpt --dataset coco
|
||||
```
|
||||
|
||||
You can run `python eval.py -h` to get more information.
|
|
@ -1,64 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Config parameters for SSD models."""
|
||||
|
||||
|
||||
class ConfigSSD:
|
||||
"""
|
||||
Config parameters for SSD.
|
||||
|
||||
Examples:
|
||||
ConfigSSD().
|
||||
"""
|
||||
IMG_SHAPE = [300, 300]
|
||||
NUM_SSD_BOXES = 1917
|
||||
NEG_PRE_POSITIVE = 3
|
||||
MATCH_THRESHOLD = 0.5
|
||||
|
||||
NUM_DEFAULT = [3, 6, 6, 6, 6, 6]
|
||||
EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256]
|
||||
EXTRAS_OUT_CHANNELS = [576, 1280, 512, 256, 256, 128]
|
||||
EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2]
|
||||
EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25]
|
||||
FEATURE_SIZE = [19, 10, 5, 3, 2, 1]
|
||||
SCALES = [21, 45, 99, 153, 207, 261, 315]
|
||||
ASPECT_RATIOS = [(1,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)]
|
||||
STEPS = (16, 32, 64, 100, 150, 300)
|
||||
PRIOR_SCALING = (0.1, 0.2)
|
||||
|
||||
|
||||
# `MINDRECORD_DIR` and `COCO_ROOT` are better to use absolute path.
|
||||
MINDRECORD_DIR = "MindRecord_COCO"
|
||||
COCO_ROOT = "coco2017"
|
||||
TRAIN_DATA_TYPE = "train2017"
|
||||
VAL_DATA_TYPE = "val2017"
|
||||
INSTANCES_SET = "annotations/instances_{}.json"
|
||||
COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
||||
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
||||
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
|
||||
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
|
||||
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
|
||||
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||
'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
||||
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
|
||||
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
||||
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
|
||||
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||
'keyboard', 'cell phone', 'microwave', 'toaster', 'sink',
|
||||
'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
||||
'teddy bear', 'hair drier', 'toothbrush')
|
||||
NUM_CLASSES = len(COCO_CLASSES)
|
|
@ -1,375 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""SSD dataset"""
|
||||
from __future__ import division
|
||||
|
||||
import os
|
||||
import math
|
||||
import itertools as it
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from config import ConfigSSD
|
||||
|
||||
config = ConfigSSD()
|
||||
|
||||
class GeneratDefaultBoxes():
|
||||
"""
|
||||
Generate default boxes for SSD, following the order of (W, H, anchor_sizes).
|
||||
`self.default_boxes` has a shape of [anchor_sizes, H, W, 4]; the last dimension is [x, y, w, h].
|
||||
`self.default_boxes_ltrb` has the same shape as `self.default_boxes`, with the last dimension being [x1, y1, x2, y2].
|
||||
"""
|
||||
def __init__(self):
|
||||
fk = config.IMG_SHAPE[0] / np.array(config.STEPS)
|
||||
self.default_boxes = []
|
||||
for idex, feature_size in enumerate(config.FEATURE_SIZE):
|
||||
sk1 = config.SCALES[idex] / config.IMG_SHAPE[0]
|
||||
sk2 = config.SCALES[idex + 1] / config.IMG_SHAPE[0]
|
||||
sk3 = math.sqrt(sk1 * sk2)
|
||||
|
||||
if config.NUM_DEFAULT[idex] == 3:
|
||||
all_sizes = [(0.5, 1.0), (1.0, 1.0), (1.0, 0.5)]
|
||||
else:
|
||||
all_sizes = [(sk1, sk1), (sk3, sk3)]
|
||||
for aspect_ratio in config.ASPECT_RATIOS[idex]:
|
||||
w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio)
|
||||
all_sizes.append((w, h))
|
||||
all_sizes.append((h, w))
|
||||
|
||||
assert len(all_sizes) == config.NUM_DEFAULT[idex]
|
||||
|
||||
for i, j in it.product(range(feature_size), repeat=2):
|
||||
for w, h in all_sizes:
|
||||
cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
|
||||
box = [np.clip(k, 0, 1) for k in (cx, cy, w, h)]
|
||||
self.default_boxes.append(box)
|
||||
|
||||
def to_ltrb(cx, cy, w, h):
|
||||
return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
|
||||
|
||||
# For IoU calculation
|
||||
self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32')
|
||||
self.default_boxes = np.array(self.default_boxes, dtype='float32')
|
||||
|
||||
|
||||
default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb
|
||||
default_boxes = GeneratDefaultBoxes().default_boxes
|
||||
x1, y1, x2, y2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
|
||||
vol_anchors = (x2 - x1) * (y2 - y1)
|
||||
matching_threshold = config.MATCH_THRESHOLD
|
||||
|
||||
|
||||
def ssd_bboxes_encode(boxes):
|
||||
"""
|
||||
Labels anchors with ground truth inputs.
|
||||
|
||||
Args:
|
||||
boxes: ground truth with shape [N, 5]; each row stores [x, y, w, h, cls].
|
||||
|
||||
Returns:
|
||||
gt_loc: location ground truth with shape [num_anchors, 4].
|
||||
gt_label: class ground truth with shape [num_anchors, 1].
|
||||
num_matched_boxes: number of positives in an image.
|
||||
"""
|
||||
|
||||
def jaccard_with_anchors(bbox):
|
||||
"""Compute jaccard score a box and the anchors."""
|
||||
# Intersection bbox and volume.
|
||||
xmin = np.maximum(x1, bbox[0])
|
||||
ymin = np.maximum(y1, bbox[1])
|
||||
xmax = np.minimum(x2, bbox[2])
|
||||
ymax = np.minimum(y2, bbox[3])
|
||||
w = np.maximum(xmax - xmin, 0.)
|
||||
h = np.maximum(ymax - ymin, 0.)
|
||||
|
||||
# Volumes.
|
||||
inter_vol = h * w
|
||||
union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol
|
||||
jaccard = inter_vol / union_vol
|
||||
return np.squeeze(jaccard)
|
||||
|
||||
pre_scores = np.zeros((config.NUM_SSD_BOXES), dtype=np.float32)
|
||||
t_boxes = np.zeros((config.NUM_SSD_BOXES, 4), dtype=np.float32)
|
||||
t_label = np.zeros((config.NUM_SSD_BOXES), dtype=np.int64)
|
||||
for bbox in boxes:
|
||||
label = int(bbox[4])
|
||||
scores = jaccard_with_anchors(bbox)
|
||||
mask = (scores > matching_threshold)
|
||||
if not np.any(mask):
|
||||
mask[np.argmax(scores)] = True
|
||||
|
||||
mask = mask & (scores > pre_scores)
|
||||
pre_scores = np.maximum(pre_scores, scores)
|
||||
t_label = mask * label + (1 - mask) * t_label
|
||||
for i in range(4):
|
||||
t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]
|
||||
|
||||
index = np.nonzero(t_label)
|
||||
|
||||
# Transform to ltrb.
|
||||
bboxes = np.zeros((config.NUM_SSD_BOXES, 4), dtype=np.float32)
|
||||
bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
|
||||
bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]
|
||||
|
||||
# Encode features.
|
||||
bboxes_t = bboxes[index]
|
||||
default_boxes_t = default_boxes[index]
|
||||
bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.PRIOR_SCALING[0])
|
||||
bboxes_t[:, 2:4] = np.log(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4]) / config.PRIOR_SCALING[1]
|
||||
bboxes[index] = bboxes_t
|
||||
|
||||
num_match_num = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
|
||||
return bboxes, t_label.astype(np.int32), num_match_num
|
||||
|
||||
def ssd_bboxes_decode(boxes, index):
|
||||
"""Decode predict boxes to [x, y, w, h]"""
|
||||
boxes_t = boxes[index]
|
||||
default_boxes_t = default_boxes[index]
|
||||
boxes_t[:, :2] = boxes_t[:, :2] * config.PRIOR_SCALING[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2]
|
||||
boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.PRIOR_SCALING[1]) * default_boxes_t[:, 2:4]
|
||||
|
||||
bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32)
|
||||
|
||||
bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2
|
||||
bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2
|
||||
|
||||
return bboxes
|
||||
|
||||
def preprocess_fn(image, box, is_training):
|
||||
"""Preprocess function for dataset."""
|
||||
|
||||
def _rand(a=0., b=1.):
|
||||
"""Generate random."""
|
||||
return np.random.rand() * (b - a) + a
|
||||
|
||||
def _infer_data(image, input_shape, box):
|
||||
img_h, img_w, _ = image.shape
|
||||
input_h, input_w = input_shape
|
||||
|
||||
scale = min(float(input_w) / float(img_w), float(input_h) / float(img_h))
|
||||
nw = int(img_w * scale)
|
||||
nh = int(img_h * scale)
|
||||
|
||||
image = cv2.resize(image, (nw, nh))
|
||||
|
||||
new_image = np.zeros((input_h, input_w, 3), np.float32)
|
||||
dh = (input_h - nh) // 2
|
||||
dw = (input_w - nw) // 2
|
||||
new_image[dh: (nh + dh), dw: (nw + dw), :] = image
|
||||
image = new_image
|
||||
|
||||
# When the image has only one channel
|
||||
if len(image.shape) == 2:
|
||||
image = np.expand_dims(image, axis=-1)
|
||||
image = np.concatenate([image, image, image], axis=-1)
|
||||
|
||||
box = box.astype(np.float32)
|
||||
|
||||
box[:, [0, 2]] = (box[:, [0, 2]] * scale + dw) / input_w
|
||||
box[:, [1, 3]] = (box[:, [1, 3]] * scale + dh) / input_h
|
||||
return image, np.array((img_h, img_w), np.float32), box
|
||||
|
||||
def _data_aug(image, box, is_training, image_size=(300, 300)):
|
||||
"""Data augmentation function."""
|
||||
ih, iw, _ = image.shape
|
||||
w, h = image_size
|
||||
|
||||
if not is_training:
|
||||
return _infer_data(image, image_size, box)
|
||||
# Random settings
|
||||
scale_w = _rand(0.75, 1.25)
|
||||
scale_h = _rand(0.75, 1.25)
|
||||
|
||||
flip = _rand() < .5
|
||||
nw = iw * scale_w
|
||||
nh = ih * scale_h
|
||||
scale = min(w / nw, h / nh)
|
||||
nw = int(scale * nw)
|
||||
nh = int(scale * nh)
|
||||
|
||||
# Resize image
|
||||
image = cv2.resize(image, (nw, nh))
|
||||
|
||||
# place image
|
||||
new_image = np.zeros((h, w, 3), dtype=np.float32)
|
||||
dw = (w - nw) // 2
|
||||
dh = (h - nh) // 2
|
||||
new_image[dh:dh + nh, dw:dw + nw, :] = image
|
||||
image = new_image
|
||||
|
||||
# Flip image or not
|
||||
if flip:
|
||||
image = cv2.flip(image, 1, dst=None)
|
||||
|
||||
# Convert image to gray or not
|
||||
gray = _rand() < .25
|
||||
if gray:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# When the image has only one channel
|
||||
if len(image.shape) == 2:
|
||||
image = np.expand_dims(image, axis=-1)
|
||||
image = np.concatenate([image, image, image], axis=-1)
|
||||
|
||||
box = box.astype(np.float32)
|
||||
|
||||
# Transform box with shape[x1, y1, x2, y2].
|
||||
box[:, [0, 2]] = (box[:, [0, 2]] * scale * scale_w + dw) / w
|
||||
box[:, [1, 3]] = (box[:, [1, 3]] * scale * scale_h + dh) / h
|
||||
|
||||
if flip:
|
||||
box[:, [0, 2]] = 1 - box[:, [2, 0]]
|
||||
|
||||
box, label, num_match_num = ssd_bboxes_encode(box)
|
||||
return image, box, label, num_match_num
|
||||
return _data_aug(image, box, is_training, image_size=config.IMG_SHAPE)
|
||||
|
||||
|
||||
def create_coco_label(is_training):
|
||||
"""Get image path and annotation from COCO."""
|
||||
from pycocotools.coco import COCO
|
||||
|
||||
coco_root = config.COCO_ROOT
|
||||
data_type = config.VAL_DATA_TYPE
|
||||
if is_training:
|
||||
data_type = config.TRAIN_DATA_TYPE
|
||||
|
||||
# Classes needed for training or testing.
|
||||
train_cls = config.COCO_CLASSES
|
||||
train_cls_dict = {}
|
||||
for i, cls in enumerate(train_cls):
|
||||
train_cls_dict[cls] = i
|
||||
|
||||
anno_json = os.path.join(coco_root, config.INSTANCES_SET.format(data_type))
|
||||
|
||||
coco = COCO(anno_json)
|
||||
classs_dict = {}
|
||||
cat_ids = coco.loadCats(coco.getCatIds())
|
||||
for cat in cat_ids:
|
||||
classs_dict[cat["id"]] = cat["name"]
|
||||
|
||||
image_ids = coco.getImgIds()
|
||||
image_files = []
|
||||
image_anno_dict = {}
|
||||
|
||||
for img_id in image_ids:
|
||||
image_info = coco.loadImgs(img_id)
|
||||
file_name = image_info[0]["file_name"]
|
||||
anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
|
||||
anno = coco.loadAnns(anno_ids)
|
||||
image_path = os.path.join(coco_root, data_type, file_name)
|
||||
annos = []
|
||||
for label in anno:
|
||||
bbox = label["bbox"]
|
||||
class_name = classs_dict[label["category_id"]]
|
||||
if class_name in train_cls:
|
||||
x_min, x_max = bbox[0], bbox[0] + bbox[2]
|
||||
y_min, y_max = bbox[1], bbox[1] + bbox[3]
|
||||
annos.append(list(map(round, [x_min, y_min, x_max, y_max])) + [train_cls_dict[class_name]])
|
||||
if len(annos) >= 1:
|
||||
image_files.append(image_path)
|
||||
image_anno_dict[image_path] = np.array(annos)
|
||||
return image_files, image_anno_dict
|
||||
|
||||
|
||||
def anno_parser(annos_str):
|
||||
"""Parse annotation from string to list."""
|
||||
annos = []
|
||||
for anno_str in annos_str:
|
||||
anno = list(map(int, anno_str.strip().split(',')))
|
||||
annos.append(anno)
|
||||
return annos
|
||||
|
||||
|
||||
def filter_valid_data(image_dir, anno_path):
|
||||
"""Filter valid image file, which both in image_dir and anno_path."""
|
||||
image_files = []
|
||||
image_anno_dict = {}
|
||||
if not os.path.isdir(image_dir):
|
||||
raise RuntimeError("Path given is not valid.")
|
||||
if not os.path.isfile(anno_path):
|
||||
raise RuntimeError("Annotation file is not valid.")
|
||||
|
||||
with open(anno_path, "rb") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line_str = line.decode("utf-8").strip()
|
||||
line_split = str(line_str).split(' ')
|
||||
file_name = line_split[0]
|
||||
image_path = os.path.join(image_dir, file_name)
|
||||
if os.path.isfile(image_path):
|
||||
image_anno_dict[image_path] = anno_parser(line_split[1:])
|
||||
image_files.append(image_path)
|
||||
return image_files, image_anno_dict
|
||||
|
||||
|
||||
def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.mindrecord", file_num=8):
|
||||
"""Create MindRecord file."""
|
||||
mindrecord_dir = config.MINDRECORD_DIR
|
||||
mindrecord_path = os.path.join(mindrecord_dir, prefix)
|
||||
writer = FileWriter(mindrecord_path, file_num)
|
||||
if dataset == "coco":
|
||||
image_files, image_anno_dict = create_coco_label(is_training)
|
||||
else:
|
||||
image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)
|
||||
|
||||
ssd_json = {
|
||||
"image": {"type": "bytes"},
|
||||
"annotation": {"type": "int32", "shape": [-1, 5]},
|
||||
}
|
||||
writer.add_schema(ssd_json, "ssd_json")
|
||||
|
||||
for image_name in image_files:
|
||||
with open(image_name, 'rb') as f:
|
||||
img = f.read()
|
||||
annos = np.array(image_anno_dict[image_name], dtype=np.int32)
|
||||
row = {"image": img, "annotation": annos}
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
|
||||
def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0,
|
||||
is_training=True, num_parallel_workers=4):
|
||||
"""Creatr SSD dataset with MindDataset."""
|
||||
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
|
||||
num_parallel_workers=num_parallel_workers, shuffle=is_training)
|
||||
decode = C.Decode()
|
||||
ds = ds.map(input_columns=["image"], operations=decode)
|
||||
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
|
||||
|
||||
if is_training:
|
||||
hwc_to_chw = C.HWC2CHW()
|
||||
ds = ds.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "box", "label", "num_match_num"],
|
||||
columns_order=["image", "box", "label", "num_match_num"],
|
||||
operations=compose_map_func, python_multiprocessing=True, num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, python_multiprocessing=True,
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_num)
|
||||
else:
|
||||
hwc_to_chw = C.HWC2CHW()
|
||||
ds = ds.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "image_shape", "annotation"],
|
||||
columns_order=["image", "image_shape", "annotation"],
|
||||
operations=compose_map_func)
|
||||
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(repeat_num)
|
||||
return ds
|
|
@ -1,206 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""metrics utils"""
|
||||
|
||||
import numpy as np
|
||||
from config import ConfigSSD
|
||||
from dataset import ssd_bboxes_decode
|
||||
|
||||
|
||||
def calc_iou(bbox_pred, bbox_ground):
|
||||
"""Calculate iou of predicted bbox and ground truth."""
|
||||
bbox_pred = np.expand_dims(bbox_pred, axis=0)
|
||||
|
||||
pred_w = bbox_pred[:, 2] - bbox_pred[:, 0]
|
||||
pred_h = bbox_pred[:, 3] - bbox_pred[:, 1]
|
||||
pred_area = pred_w * pred_h
|
||||
|
||||
gt_w = bbox_ground[:, 2] - bbox_ground[:, 0]
|
||||
gt_h = bbox_ground[:, 3] - bbox_ground[:, 1]
|
||||
gt_area = gt_w * gt_h
|
||||
|
||||
iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0])
|
||||
ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1])
|
||||
|
||||
iw = np.maximum(iw, 0)
|
||||
ih = np.maximum(ih, 0)
|
||||
intersection_area = iw * ih
|
||||
|
||||
union_area = pred_area + gt_area - intersection_area
|
||||
union_area = np.maximum(union_area, np.finfo(float).eps)
|
||||
|
||||
iou = intersection_area * 1. / union_area
|
||||
return iou
|
||||
|
||||
|
||||
def apply_nms(all_boxes, all_scores, thres, max_boxes):
|
||||
"""Apply NMS to bboxes."""
|
||||
x1 = all_boxes[:, 0]
|
||||
y1 = all_boxes[:, 1]
|
||||
x2 = all_boxes[:, 2]
|
||||
y2 = all_boxes[:, 3]
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
|
||||
order = all_scores.argsort()[::-1]
|
||||
keep = []
|
||||
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
|
||||
if len(keep) >= max_boxes:
|
||||
break
|
||||
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
|
||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
|
||||
inds = np.where(ovr <= thres)[0]
|
||||
|
||||
order = order[inds + 1]
|
||||
return keep
|
||||
|
||||
|
||||
def calc_ap(recall, precision):
|
||||
"""Calculate AP."""
|
||||
correct_recall = np.concatenate(([0.], recall, [1.]))
|
||||
correct_precision = np.concatenate(([0.], precision, [0.]))
|
||||
|
||||
for i in range(correct_recall.size - 1, 0, -1):
|
||||
correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i])
|
||||
|
||||
i = np.where(correct_recall[1:] != correct_recall[:-1])[0]
|
||||
|
||||
ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1])
|
||||
|
||||
return ap
|
||||
|
||||
def metrics(pred_data):
|
||||
"""Calculate mAP of predicted bboxes."""
|
||||
config = ConfigSSD()
|
||||
num_classes = config.NUM_CLASSES
|
||||
|
||||
all_detections = [None for i in range(num_classes)]
|
||||
all_pred_scores = [None for i in range(num_classes)]
|
||||
all_annotations = [None for i in range(num_classes)]
|
||||
average_precisions = {}
|
||||
num = [0 for i in range(num_classes)]
|
||||
accurate_num = [0 for i in range(num_classes)]
|
||||
|
||||
for sample in pred_data:
|
||||
pred_boxes = sample['boxes']
|
||||
boxes_scores = sample['box_scores']
|
||||
annotation = sample['annotation']
|
||||
|
||||
annotation = np.squeeze(annotation, axis=0)
|
||||
|
||||
pred_labels = np.argmax(boxes_scores, axis=-1)
|
||||
index = np.nonzero(pred_labels)
|
||||
pred_boxes = ssd_bboxes_decode(pred_boxes, index)
|
||||
|
||||
pred_boxes = pred_boxes.clip(0, 1)
|
||||
boxes_scores = np.max(boxes_scores, axis=-1)
|
||||
boxes_scores = boxes_scores[index]
|
||||
pred_labels = pred_labels[index]
|
||||
|
||||
top_k = 50
|
||||
|
||||
for c in range(1, num_classes):
|
||||
if len(pred_labels) >= 1:
|
||||
class_box_scores = boxes_scores[pred_labels == c]
|
||||
class_boxes = pred_boxes[pred_labels == c]
|
||||
|
||||
nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k)
|
||||
|
||||
class_boxes = class_boxes[nms_index]
|
||||
class_box_scores = class_box_scores[nms_index]
|
||||
|
||||
cmask = class_box_scores > 0.5
|
||||
class_boxes = class_boxes[cmask]
|
||||
class_box_scores = class_box_scores[cmask]
|
||||
|
||||
all_detections[c] = class_boxes
|
||||
all_pred_scores[c] = class_box_scores
|
||||
|
||||
for c in range(1, num_classes):
|
||||
if len(annotation) >= 1:
|
||||
all_annotations[c] = annotation[annotation[:, 4] == c, :4]
|
||||
|
||||
for c in range(1, num_classes):
|
||||
false_positives = np.zeros((0,))
|
||||
true_positives = np.zeros((0,))
|
||||
scores = np.zeros((0,))
|
||||
num_annotations = 0.0
|
||||
|
||||
annotations = all_annotations[c]
|
||||
num_annotations += annotations.shape[0]
|
||||
detections = all_detections[c]
|
||||
pred_scores = all_pred_scores[c]
|
||||
|
||||
for index, detection in enumerate(detections):
|
||||
scores = np.append(scores, pred_scores[index])
|
||||
if len(annotations) >= 1:
|
||||
IoUs = calc_iou(detection, annotations)
|
||||
assigned_anno = np.argmax(IoUs)
|
||||
max_overlap = IoUs[assigned_anno]
|
||||
|
||||
if max_overlap >= 0.5:
|
||||
false_positives = np.append(false_positives, 0)
|
||||
true_positives = np.append(true_positives, 1)
|
||||
else:
|
||||
false_positives = np.append(false_positives, 1)
|
||||
true_positives = np.append(true_positives, 0)
|
||||
else:
|
||||
false_positives = np.append(false_positives, 1)
|
||||
true_positives = np.append(true_positives, 0)
|
||||
|
||||
if num_annotations == 0:
|
||||
if c not in average_precisions.keys():
|
||||
average_precisions[c] = 0
|
||||
continue
|
||||
accurate_num[c] = 1
|
||||
indices = np.argsort(-scores)
|
||||
false_positives = false_positives[indices]
|
||||
true_positives = true_positives[indices]
|
||||
|
||||
false_positives = np.cumsum(false_positives)
|
||||
true_positives = np.cumsum(true_positives)
|
||||
|
||||
recall = true_positives * 1. / num_annotations
|
||||
precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
|
||||
|
||||
average_precision = calc_ap(recall, precision)
|
||||
|
||||
if c not in average_precisions.keys():
|
||||
average_precisions[c] = average_precision
|
||||
else:
|
||||
average_precisions[c] += average_precision
|
||||
|
||||
num[c] += 1
|
||||
|
||||
count = 0
|
||||
for key in average_precisions:
|
||||
if num[key] != 0:
|
||||
count += (average_precisions[key] / num[key])
|
||||
|
||||
mAP = count * 1. / accurate_num.count(1)
|
||||
return mAP
|
|
@ -1 +1 @@
|
|||
Subproject commit c27e428e9698dd4f9b198008596676bc2d1b49aa
|
||||
Subproject commit 8891f0546c4a250095ff68e1262f58772b938fd9
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INCLUDE_MS_SESSION_H
|
||||
#define MINDSPORE_INCLUDE_MS_SESSION_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "include/ms_tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
class FuncGraph;
|
||||
namespace inference {
|
||||
class MS_API MSSession {
|
||||
public:
|
||||
MSSession() = default;
|
||||
|
||||
static std::shared_ptr<MSSession> CreateSession(const std::string &device, uint32_t device_id);
|
||||
|
||||
virtual uint32_t CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) = 0;
|
||||
|
||||
virtual MultiTensor RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) = 0;
|
||||
};
|
||||
|
||||
std::shared_ptr<FuncGraph> MS_API LoadModel(const char *model_buf, size_t size, const std::string &device);
|
||||
|
||||
void MS_API ExitInference();
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_MS_SESSION_H
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INCLUDE_MS_TENSOR_H_
|
||||
#define MINDSPORE_INCLUDE_MS_TENSOR_H_
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
#define MS_API __attribute__((visibility("default")))
|
||||
namespace inference {
|
||||
class MS_API MSTensor {
|
||||
public:
|
||||
MSTensor() = default;
|
||||
// brief Create a MSTensor pointer.
|
||||
//
|
||||
// param data_type DataTypeId of tensor to be created.
|
||||
// param shape Shape of tensor to be created.
|
||||
// return MSTensor pointer.
|
||||
static MSTensor *CreateTensor(TypeId data_type, const std::vector<int> &shape);
|
||||
|
||||
~MSTensor() = default;
|
||||
|
||||
virtual TypeId data_type() const = 0;
|
||||
|
||||
virtual TypeId set_data_type(const TypeId data_type) = 0;
|
||||
|
||||
virtual std::vector<int> shape() const = 0;
|
||||
|
||||
virtual size_t set_shape(const std::vector<int> &shape) = 0;
|
||||
|
||||
virtual int DimensionSize(size_t index) const = 0;
|
||||
// brief Get number of element in MSTensor.
|
||||
//
|
||||
// return Number of element in MSTensor.
|
||||
virtual int ElementsNum() const = 0;
|
||||
|
||||
virtual std::size_t hash() const = 0;
|
||||
// brief Get byte size of data in MSTensor.
|
||||
//
|
||||
// return Byte size of data in MSTensor.
|
||||
virtual size_t Size() const = 0;
|
||||
// brief Get pointer of data in MSTensor.
|
||||
//
|
||||
// The data pointer can be used to both read and write data in MSTensor.
|
||||
//
|
||||
// return A pointer points to data in MSTensor.
|
||||
virtual void *MutableData() const = 0;
|
||||
};
|
||||
using MultiTensor = std::vector<std::shared_ptr<inference::MSTensor>>;
|
||||
} // namespace inference
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_MS_TENSOR_H_
|
|
@ -35,3 +35,5 @@ from .logical_not import LogicalNot, gpu_schedule_LogicalNot
|
|||
from .logical_and import LogicalAnd, gpu_schedule_LogicalAnd
|
||||
from .sub import Sub, gpu_schedule_Sub
|
||||
from .less_equal import LessEqual, gpu_schedule_LessEqual
|
||||
from .notequal import NotEqual, gpu_schedule_NotEqual
|
||||
from .greater_equal import GreaterEqual, gpu_schedule_GreaterEqual
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""greater_equal"""
|
||||
import _akg.tvm
|
||||
from _akg.ops.math import greater_equal
|
||||
from _akg.topi.generic import schedule_elemwise
|
||||
|
||||
def GreaterEqual(x, y):
|
||||
"""GreaterEqual."""
|
||||
return greater_equal.greater_equal(x, y)
|
||||
|
||||
|
||||
def gpu_schedule_GreaterEqual(outs):
|
||||
"""
|
||||
GPU schedule for GreaterEqual.
|
||||
|
||||
Args:
|
||||
outs (tvm.tensor.Tensor): Outputs of compute.
|
||||
|
||||
Returns:
|
||||
sch (schedule.Schedule): The created schedule.
|
||||
"""
|
||||
device = 'cuda'
|
||||
ctx = _akg.tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
raise SystemError("Skip because %s is not enabled" % device)
|
||||
with _akg.tvm.target.create(device):
|
||||
sch = schedule_elemwise(outs)
|
||||
return sch
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""notequal"""
|
||||
import _akg.tvm
|
||||
from _akg.ops.math import notequal
|
||||
from _akg.topi.generic import schedule_elemwise
|
||||
|
||||
def NotEqual(x, y):
|
||||
"""notequal."""
|
||||
return notequal.notequal(x, y)
|
||||
|
||||
|
||||
def gpu_schedule_NotEqual(outs):
|
||||
"""
|
||||
GPU schedule for NotEqual.
|
||||
|
||||
Args:
|
||||
outs (tvm.tensor.Tensor): outputs of compute.
|
||||
|
||||
Returns:
|
||||
sch (schedule.Schedule): The created schedule.
|
||||
"""
|
||||
device = 'cuda'
|
||||
ctx = _akg.tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
raise SystemError("Skip because %s is not enabled" % device)
|
||||
with _akg.tvm.target.create(device):
|
||||
sch = schedule_elemwise(outs)
|
||||
return sch
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""operator dsl function: greaterequal"""
|
||||
import _akg.tvm
|
||||
import _akg.topi
|
||||
from _akg.utils.dsl_create import produce_shapes
|
||||
from _akg.utils import validation_check as vc_util
|
||||
|
||||
|
||||
@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor)
|
||||
def greater_equal(input1, input2):
|
||||
"""
|
||||
Check whether input1 is greater than or equal to input2.
|
||||
|
||||
Args:
|
||||
input1 (tvm.tensor.Tensor): Tensor.
|
||||
input2 (tvm.tensor.Tensor): Tensor.
|
||||
|
||||
Returns:
|
||||
tvm.tensor.Tensor. True where input1 is greater than or equal to input2, otherwise False.
|
||||
"""
|
||||
shape1 = [x.value for x in input1.shape]
|
||||
shape2 = [x.value for x in input2.shape]
|
||||
vc_util.check_shape(shape1)
|
||||
vc_util.check_shape(shape2)
|
||||
|
||||
shape1, shape2, shape = produce_shapes(shape1, shape2)
|
||||
|
||||
vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
|
||||
dtype = input1.dtype
|
||||
|
||||
# get greater_equal compute
|
||||
t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T")
|
||||
f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F")
|
||||
|
||||
input1_bro = _akg.topi.broadcast_to(input1, shape)
|
||||
input2_bro = _akg.topi.broadcast_to(input2, shape)
|
||||
c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] >= input2_bro[indice],
|
||||
t_value[indice], f_value[indice]), name="C")
|
||||
res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")
|
||||
|
||||
return res
|
|
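For reference, the greater_equal kernel above broadcasts both inputs to a common shape, selects 1/0 per element, and casts the result to bool. Here is a NumPy sketch of the same semantics (shape and dtype validation omitted); greater_equal_ref is a hypothetical helper, not part of akg.

import numpy as np

def greater_equal_ref(input1, input2):
    # Broadcast both inputs to the common shape, then compare elementwise.
    shape = np.broadcast(input1, input2).shape
    input1_bro = np.broadcast_to(input1, shape)
    input2_bro = np.broadcast_to(input2, shape)
    return (input1_bro >= input2_bro).astype(bool)

print(greater_equal_ref(np.array([[1.0, 2.0]]), np.array([[1.5], [0.5]])))
# [[False  True]
#  [ True  True]]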
@ -0,0 +1,54 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""operator dsl function: notequal"""
|
||||
import _akg.tvm
|
||||
import _akg.topi
|
||||
from _akg.utils.dsl_create import produce_shapes
|
||||
from _akg.utils import validation_check as vc_util
|
||||
|
||||
|
||||
@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor)
|
||||
def notequal(input1, input2):
|
||||
"""
|
||||
Check whether input1 is not equal to input2.
|
||||
|
||||
Args:
|
||||
input1 (tvm.tensor.Tensor): Tensor.
|
||||
input2 (tvm.tensor.Tensor): Tensor.
|
||||
|
||||
Returns:
|
||||
tvm.tensor.Tensor. True where input1 is not equal to input2, otherwise False.
|
||||
"""
|
||||
shape1 = [x.value for x in input1.shape]
|
||||
shape2 = [x.value for x in input2.shape]
|
||||
vc_util.check_shape(shape1)
|
||||
vc_util.check_shape(shape2)
|
||||
|
||||
shape1, shape2, shape = produce_shapes(shape1, shape2)
|
||||
|
||||
vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
|
||||
dtype = input1.dtype
|
||||
|
||||
# get notequal compute
|
||||
t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T")
|
||||
f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F")
|
||||
|
||||
input1_bro = _akg.topi.broadcast_to(input1, shape)
|
||||
input2_bro = _akg.topi.broadcast_to(input2, shape)
|
||||
c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] != input2_bro[indice],
|
||||
t_value[indice], f_value[indice]), name="C")
|
||||
res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")
|
||||
|
||||
return res
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""builtin_operations"""
|
||||
import functools
|
||||
import numpy as np
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
|
||||
|
@ -114,6 +113,24 @@ def bool_or(x, y):
|
|||
"""Implement `bool_or`."""
|
||||
return x or y
|
||||
|
||||
def vm_compare(*args):
|
||||
"""Implement `vm_compare` for tensor."""
|
||||
obj_str = args[-1]
|
||||
if obj_str == "shape":
|
||||
fn = getattr(args[0].asnumpy(), obj_str)
|
||||
return fn
|
||||
if len(args) == 2:
|
||||
fn = getattr(args[0].asnumpy(), obj_str)
|
||||
return Tensor(fn())
|
||||
if isinstance(args[0], Tensor):
|
||||
fn = getattr(args[0].asnumpy(), obj_str)
|
||||
y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
|
||||
else:
|
||||
obj_str = "__r" + obj_str[2:]
|
||||
fn = getattr(args[1].asnumpy(), obj_str)
|
||||
y = args[0]
|
||||
return Tensor(np.array(fn(y)))
|
||||
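vm_compare looks up the requested dunder method on the underlying NumPy array so comparisons and arithmetic still work on the VM fallback path, flipping to the reflected method when the left operand is not a Tensor. Below is a self-contained sketch with plain ndarrays standing in for Tensor; vm_compare_sketch is illustrative only and omits the single-operand branch.

import numpy as np

def vm_compare_sketch(*args):
    # Mirror of vm_compare above, with ndarrays standing in for Tensor.
    obj_str = args[-1]
    if obj_str == "shape":
        return args[0].shape
    if isinstance(args[0], np.ndarray):
        fn = getattr(args[0], obj_str)
        y = args[1]
    else:
        # Left operand is a plain scalar: fall back to the reflected method.
        obj_str = "__r" + obj_str[2:]
        fn = getattr(args[1], obj_str)
        y = args[0]
    return np.array(fn(y))

print(vm_compare_sketch(np.array([1, 2, 3]), 2, "__le__"))   # [ True  True False]
print(vm_compare_sketch(5, np.array([1, 2, 3]), "__add__"))  # [6 7 8]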
|
||||
|
||||
def make_list(*xs):
|
||||
"""Implement `make_list`."""
|
||||
|
@ -124,17 +141,8 @@ def list_len(x):
|
|||
"""Implement `list_len`."""
|
||||
return len(x)
|
||||
|
||||
|
||||
# only used in PyNative mode
|
||||
def partial(*args):
|
||||
"""Implement `partial`."""
|
||||
func = args[0].__call__
|
||||
partial_func = functools.partial(func, *args[1:])
|
||||
return partial_func
|
||||
|
||||
|
||||
# only used in PyNative mode
|
||||
def depend(value, expr):
|
||||
def Depend(value, expr):
|
||||
"""Implement `Depend`."""
|
||||
return value
|
||||
|
||||
# only used in PyNative mode
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing akg compile with json"""
|
||||
import sys
|
||||
def run_compiler(op_json):
|
||||
"""
|
||||
Run the AKG compiler to compile an op in a subprocess. If
|
||||
compilation fails, an exception is raised.
|
||||
|
||||
Args:
|
||||
op_json (str): json string of the op
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
p = __import__("akg", globals(), locals(), ['ms'], 0)
|
||||
func = getattr(p.ms, "compilewithjson")
|
||||
res = func(op_json)
|
||||
if not res:
|
||||
raise ValueError("Compile error")
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_compiler(sys.argv[1])
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing multi process compile with json"""
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
|
||||
def _compile_akg_task(*json_strs):
|
||||
"""
|
||||
Compile function called in a single process.
|
||||
|
||||
Parameters:
|
||||
json_strs (list): kernel info json strings, suitable for the json compile API.
|
||||
"""
|
||||
akg_compiler = os.path.join(os.path.split(
|
||||
os.path.realpath(__file__))[0], "compiler.py")
|
||||
for json_str in json_strs:
|
||||
res = subprocess.run(
|
||||
[sys.executable, akg_compiler, json_str], text=True)
|
||||
if res.returncode != 0:
|
||||
raise ValueError("Failed, args: {}!".format(json_str))
|
||||
|
||||
|
||||
def compile_akg_kernel_parallel(json_infos, process, waitime):
|
||||
"""
|
||||
Compile kernels using multiple processes.
|
||||
|
||||
Parameters:
|
||||
json_infos (list): kernel infos (task id and json str).
|
||||
process (int): number of processes.
|
||||
waitime (int): maximum time the function may block.
|
||||
|
||||
Returns:
|
||||
True when all compilations succeed; a failed compilation raises an exception.
|
||||
"""
|
||||
if not isinstance(json_infos, list):
|
||||
raise ValueError("json_infos must be a list")
|
||||
if not isinstance(process, int):
|
||||
raise ValueError("process must be a num")
|
||||
if not isinstance(waitime, int):
|
||||
raise ValueError("waittime must be a num")
|
||||
|
||||
if process == 0 and json_infos:
|
||||
process = 1
|
||||
|
||||
cpu_proc_num = cpu_count()
|
||||
max_proc_num = 16
|
||||
process = min([cpu_proc_num, max_proc_num, process])
|
||||
|
||||
args = [[] for _ in range(process)]
|
||||
for p, info in enumerate(json_infos):
|
||||
args[p % process].append(info)
|
||||
|
||||
with Pool(processes=process) as pool:
|
||||
res = pool.starmap_async(_compile_akg_task, args)
|
||||
res.get(timeout=waitime)
|
||||
return True
|
|
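compile_akg_kernel_parallel spreads the kernel json strings round-robin over at most min(cpu_count(), 16, process) workers and lets each worker shell out to compiler.py. A small sketch of just that distribution step, under the assumption that real json strings come from the kernel build pipeline; split_round_robin is a hypothetical helper.

from multiprocessing import cpu_count

def split_round_robin(json_infos, process):
    # Clamp the worker count and distribute jobs round-robin, mirroring the function above.
    process = min(cpu_count(), 16, max(process, 1))
    args = [[] for _ in range(process)]
    for p, info in enumerate(json_infos):
        args[p % process].append(info)
    return args

print(split_round_robin(["kernel_%d" % i for i in range(5)], process=2))
# [['kernel_0', 'kernel_2', 'kernel_4'], ['kernel_1', 'kernel_3']]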
@ -1,107 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Providing multi process compile with json"""
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from multiprocessing import Pool
|
||||
|
||||
|
||||
def _compiletask(platform, *jsons):
|
||||
"""
|
||||
compile func called in single process
|
||||
|
||||
Parameters:
|
||||
platform: str. AKG platform or TBE platform
|
||||
*jsons: str. json str contain kernel info, suitable for json compile
|
||||
api
|
||||
|
||||
"""
|
||||
if platform == "AKG":
|
||||
p = __import__("_akg", globals(), locals(), ['ms'], 0)
|
||||
func = getattr(p.ms, "compilewithjson")
|
||||
for json_item in jsons:
|
||||
res = func(json_item)
|
||||
if not res:
|
||||
raise ValueError("Compile error")
|
||||
if platform == "TBE":
|
||||
tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "tbe_compiler", "compiler.py")
|
||||
for json_item in jsons:
|
||||
res = subprocess.run([sys.executable, tbe_compiler], input=json_item, text=True)
|
||||
if res.returncode != 0:
|
||||
raise ValueError("Tbe compile error")
|
||||
|
||||
|
||||
def compilekernelparallel(jsons, process, waitime):
|
||||
"""
|
||||
compile kernel use multi processes
|
||||
|
||||
Parameters:
|
||||
jsons: list. json str list contain kernel info
|
||||
process: int. processes num
|
||||
waittime: int. max time the function blocked
|
||||
"""
|
||||
if not isinstance(jsons, list):
|
||||
raise ValueError("jsons must be a list")
|
||||
if not isinstance(process, int):
|
||||
raise ValueError("process must be a num")
|
||||
if not isinstance(waitime, int):
|
||||
raise ValueError("waittime must be a num")
|
||||
|
||||
jsons_akg = []
|
||||
jsons_tbe = []
|
||||
for json_ in jsons:
|
||||
j = json.loads(json_)
|
||||
if j["platform"] == "TBE":
|
||||
jsons_tbe.append(json_)
|
||||
continue
|
||||
if j["platform"] == "AKG":
|
||||
jsons_akg.append(json_)
|
||||
continue
|
||||
raise RuntimeError(
|
||||
"not support this platform {0}".format(j["platform"]))
|
||||
if jsons_akg:
|
||||
process_akg = math.floor(len(jsons)/len(jsons_akg)*process)
|
||||
else:
|
||||
process_akg = 0
|
||||
|
||||
if process_akg == 0 and jsons_akg:
|
||||
process_akg = 1
|
||||
process_tbe = process-process_akg
|
||||
if process_tbe == 0 and jsons_tbe:
|
||||
process_tbe = 1
|
||||
raise RuntimeWarning("we add a process for compile more operator")
|
||||
|
||||
args = [[] for _ in range(process_akg+process_tbe)]
|
||||
args_lens = len(args)
|
||||
for p in range(args_lens):
|
||||
if p < process_tbe:
|
||||
args[p].append("TBE")
|
||||
else:
|
||||
args[p].append("AKG")
|
||||
jsons_tbe_lens = len(jsons_tbe)
|
||||
for p in range(jsons_tbe_lens):
|
||||
args[p % process_tbe].append(jsons_tbe[p])
|
||||
jsons_akg_lens = len(jsons_akg)
|
||||
for p in range(jsons_akg_lens):
|
||||
args[process-p % process_akg-1].append(jsons_akg[p])
|
||||
for p in range(args_lens):
|
||||
args[p] = tuple(args[p])
|
||||
with Pool(processes=process) as pool:
|
||||
res = pool.starmap_async(_compiletask, args)
|
||||
res.get(timeout=waitime)
|
||||
return True
|
|
@ -15,13 +15,6 @@
|
|||
"""tbe common"""
|
||||
import json
|
||||
import os
|
||||
from attrdict import AttrDict
|
||||
|
||||
class ParamType(AttrDict):
|
||||
Required = "required"
|
||||
Dynamic = "dynamic"
|
||||
Optional = "optional"
|
||||
|
||||
|
||||
class TBEException(Exception):
|
||||
"""tbe exception class"""
|
||||
|
@ -112,7 +105,7 @@ def get_input_output(io_info, args):
|
|||
if len(item) > 1:
|
||||
arg.append(info)
|
||||
else:
|
||||
if info['param_type'] == ParamType.Dynamic:
|
||||
if info['param_type'] == 'dynamic':
|
||||
arg.append(info)
|
||||
args.append(arg)
|
||||
else:
|
||||
|
|
|
@ -28,7 +28,8 @@ build_in_impl_path = get_build_in_impl_path()
|
|||
# op function list
|
||||
op_build = "compile"
|
||||
op_pre_build = "pre_build"
|
||||
|
||||
fusion_pattern_start_flag = "fusion_pattern_start"
|
||||
fusion_pattern_end_flag = "fusion_pattern_end"
|
||||
|
||||
def _initialize(impl_path):
|
||||
"""Initialize"""
|
||||
|
@ -42,7 +43,6 @@ def _initialize(impl_path):
|
|||
|
||||
sys.path.insert(0, op_module_name)
|
||||
|
||||
|
||||
def build_op(build_type, json_str):
|
||||
"""
|
||||
call op functions with function name and input args json_str
|
||||
|
@ -108,7 +108,7 @@ def build_op(build_type, json_str):
|
|||
|
||||
# pre build
|
||||
if build_type == op_pre_build:
|
||||
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name)
|
||||
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
# disable only pattern configuration
|
||||
op_build_cfg_en()
|
||||
return get_op_pattern()
|
||||
|
@ -159,11 +159,14 @@ def compile_with_json(json_str):
|
|||
json_info = json.loads(json_str)
|
||||
if "fusion_op" in json_info:
|
||||
ret = compile_fusion_op(json_str)
|
||||
elif "compile_type" in json_info:
|
||||
ret = build_op(op_pre_build, json_str)
|
||||
else:
|
||||
ret = build_op(op_build, json_str)
|
||||
return ret
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
in_args = sys.stdin.readline()
|
||||
compile_with_json(in_args)
|
||||
result = compile_with_json(in_args)
|
||||
sys.stdout.write(fusion_pattern_start_flag + str(result) + fusion_pattern_end_flag)
|
||||
sys.stdout.flush()
|
||||
|
|
|
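With this change the compiler entry point reads one op json from stdin and writes the result wrapped between fusion_pattern_start and fusion_pattern_end on stdout, which the parent process can parse. A hedged sketch of such a caller follows (call_tbe_compiler, compiler_path and op_json are assumptions; the actual caller is run_compiler in the TBE process module, shown further below).

import re
import subprocess
import sys

def call_tbe_compiler(compiler_path, op_json):
    # Feed one op json on stdin; the result is wrapped between the two flags on stdout.
    completed = subprocess.run([sys.executable, compiler_path], input=op_json, timeout=300,
                               text=True, capture_output=True, check=True)
    match = re.search(r"fusion_pattern_start(.*)fusion_pattern_end", completed.stdout, re.S)
    return match.group(1) if match else None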
@ -75,7 +75,6 @@ def check_supported(op_json: str):
|
|||
|
||||
return ret
|
||||
|
||||
|
||||
def run_compiler(op_json):
|
||||
"""
|
||||
run compiler to compile op with subprocess
|
||||
|
@ -88,15 +87,16 @@ def run_compiler(op_json):
|
|||
"""
|
||||
try:
|
||||
tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "compiler.py")
|
||||
subprocess.run([sys.executable, tbe_compiler], input=op_json, timeout=300,
|
||||
text=True, capture_output=True, check=True)
|
||||
return "Success", "Success"
|
||||
completed_object = subprocess.run([sys.executable, tbe_compiler], input=op_json, timeout=300,
|
||||
text=True, capture_output=True, check=True)
|
||||
if completed_object:
|
||||
out = completed_object.stdout
|
||||
return "Success", out
|
||||
except subprocess.TimeoutExpired:
|
||||
tb = traceback.format_exc()
|
||||
return "TBEException", "CompileTimeOut: " + tb + "\ninput_args: " + op_json
|
||||
return "TBEException", "PreCompileTimeOut: " + tb + "\ninput_args: " + op_json
|
||||
except subprocess.CalledProcessError as e:
|
||||
return "TBEException", "CompileProcessFailed:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json
|
||||
|
||||
return "TBEException", "PreCompileProcessFailed:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json
|
||||
|
||||
class CompilerPool:
|
||||
"""compiler pool"""
|
||||
|
@ -154,11 +154,11 @@ class CompilerPool:
|
|||
task_id, task_future = self.__running_tasks.pop(0)
|
||||
ret_type, result = task_future.get(330)
|
||||
if ret_type == "Success":
|
||||
ret = task_id, "Success"
|
||||
ret = task_id, "Success", result
|
||||
elif ret_type in ("Exception", "TBEException"):
|
||||
ret = task_id, ret_type + ":" + result
|
||||
ret = task_id, ret_type + ":" + result, "_"
|
||||
else:
|
||||
ret = task_id, "Exception: Not support return type:" + str(ret_type)
|
||||
ret = task_id, "Exception: Not support return type:" + str(ret_type), "_"
|
||||
return ret
|
||||
|
||||
def reset_task_info(self):
|
||||
|
|
|
@ -19,14 +19,15 @@ Interfaces for parser module in c++.
|
|||
from .parser import (Parser, create_obj_instance, generate_scope,
|
||||
get_bprop_method_of_class, get_class_instance_type,
|
||||
get_class_member_namespace_symbol, create_slice_obj,
|
||||
get_dataclass_attributes, get_dataclass_methods,
|
||||
get_dataclass_attributes, get_dataclass_methods, get_obj_id,
|
||||
get_module_namespace, get_obj_type, get_object_key,
|
||||
get_parse_method_of_class, get_scope_name,
|
||||
is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj)
|
||||
get_default_input, get_parse_method_of_class, get_scope_name,
|
||||
is_class_member, parse_cb, resolve_symbol)
|
||||
from .serialize import *
|
||||
|
||||
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
|
||||
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
|
||||
'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
|
||||
'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
|
||||
'get_dataclass_methods', 'get_scope_name', 'create_slice_obj', 'create_ellipsis_obj']
|
||||
'get_object_key', 'get_default_input', 'get_class_instance_type', 'is_class_member',
|
||||
'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
|
||||
'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
|
||||
'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
|
||||
'create_slice_obj']
|
||||
|
|
|
@ -29,7 +29,6 @@ from mindspore.common.dtype import pytype_to_dtype
|
|||
from mindspore.common.api import _MindSporeFunction
|
||||
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
|
||||
from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
|
||||
from ..utils import Slice, Ellipsis_
|
||||
|
||||
# define return value
|
||||
RET_SUCCESS = 0
|
||||
|
@ -70,14 +69,9 @@ parse_expr_statement_white_list = (
|
|||
"append",
|
||||
)
|
||||
|
||||
def create_ellipsis_obj():
|
||||
"""Create Slice object"""
|
||||
return Ellipsis_()
|
||||
|
||||
|
||||
def create_slice_obj(start, end, step):
|
||||
"""Create Slice object"""
|
||||
return Slice(start, end, step)
|
||||
"""Create slice object"""
|
||||
return slice(start, end, step)
|
||||
|
||||
|
||||
def parse_cb(func, parse_method=None):
|
||||
|
@ -209,6 +203,14 @@ def get_object_key(obj):
|
|||
obj_id = instance_id + obj_id
|
||||
return obj_id, obj_key
|
||||
|
||||
def get_default_input(obj):
|
||||
if hasattr(obj, '__parameter__'):
|
||||
return obj.default_input
|
||||
if isinstance(obj, tuple):
|
||||
convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
|
||||
args = tuple(convert(x) for x in obj)
|
||||
return args
|
||||
return obj
|
||||
|
||||
def is_class_member(node):
|
||||
"""Check the attr is class member variable."""
|
||||
|
@ -221,6 +223,9 @@ def is_class_member(node):
|
|||
return True
|
||||
return False
|
||||
|
||||
def get_obj_id(obj):
|
||||
"""Get the obj id."""
|
||||
return str(id(obj))
|
||||
|
||||
def get_obj_type(obj):
|
||||
"""Get the obj type."""
|
||||
|
|
|
@ -126,7 +126,7 @@ convert_object_map = {
|
|||
T.make_list: F.make_list,
|
||||
T.make_slice: F.make_slice,
|
||||
T.range: F.make_range,
|
||||
|
||||
T.while_cond: M.while_cond,
|
||||
# lib function
|
||||
math.floor: NO_IMPLEMENT,
|
||||
math.trunc: NO_IMPLEMENT,
|
||||
|
|
|
@ -16,8 +16,10 @@
|
|||
# ============================================================================
|
||||
"""standard_method"""
|
||||
from dataclasses import dataclass
|
||||
from mindspore.common import dtype as mstype
|
||||
from ...ops import functional as F
|
||||
from ...ops import operations as P
|
||||
from ...ops.primitive import constexpr
|
||||
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
|
||||
zeros_like, ones_like
|
||||
from ...ops.composite.base import _append
|
||||
|
@ -102,11 +104,44 @@ def bool_(x):
|
|||
return x.__bool__()
|
||||
|
||||
|
||||
def tensor_bool(x):
|
||||
"""return immedate x, x is a tensor of bool value"""
|
||||
def while_cond(x):
|
||||
"""For while condtion, if the condition is a tensor, the loop will not be unrolled"""
|
||||
if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
|
||||
is_cond = check_is_tensor_bool_cond(F.shape(x))
|
||||
if is_cond:
|
||||
return F.cast(x, mstype.bool_)
|
||||
return x
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_is_tensor_bool_cond(shp):
|
||||
"""check if tensor is a bool condition"""
|
||||
if shp in ((), (1,)):
|
||||
return True
|
||||
raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp)
|
||||
|
||||
@constexpr
|
||||
def const_tensor_to_bool(x):
|
||||
"""convert bool tensor to bool condition"""
|
||||
if x is None:
|
||||
raise ValueError("Only constant tensor bool can be converted to bool")
|
||||
x = x.asnumpy()
|
||||
if x.shape not in ((), (1,)):
|
||||
raise ValueError("Tensor to bool should input shape () or (1), but got ", x.shape)
|
||||
if x.shape == ():
|
||||
value = bool(x)
|
||||
else:
|
||||
value = bool(x[0])
|
||||
return value
|
||||
|
||||
def tensor_bool(x):
|
||||
"""tensor as conditon, if is constant, return immediate bool value"""
|
||||
is_cond = check_is_tensor_bool_cond(F.shape(x))
|
||||
if is_cond and F.isconstant(x):
|
||||
return const_tensor_to_bool(x)
|
||||
return F.cast(x, mstype.bool_)
|
||||
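const_tensor_to_bool only accepts constant tensors of shape () or (1,) and collapses them to a Python bool, and check_is_tensor_bool_cond enforces the same shape rule for non-constant tensors. A NumPy sketch of that rule, with an ndarray standing in for a constant Tensor; const_tensor_to_bool_sketch is illustrative only.

import numpy as np

def const_tensor_to_bool_sketch(x):
    # Mirrors const_tensor_to_bool above, with an ndarray standing in for a constant Tensor.
    if x.shape not in ((), (1,)):
        raise ValueError("Tensor as bool condition should have shape () or (1,), but got ", x.shape)
    return bool(x) if x.shape == () else bool(x[0])

print(const_tensor_to_bool_sketch(np.array(True)))   # True
print(const_tensor_to_bool_sketch(np.array([0])))    # False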
|
||||
|
||||
def and_(x, y):
|
||||
"""Implementation of `and` (`&`)."""
|
||||
return x.__and__(y)
|
||||
|
|
|
@ -91,3 +91,7 @@ def to_array(x): # pragma: no cover
|
|||
def not_contains(x): # pragma: no cover
|
||||
"""Not in function."""
|
||||
raise RuntimeError('This operation is not meant to be called directly.')
|
||||
|
||||
def while_cond(x): # pragma: no cover
|
||||
"""Not in function."""
|
||||
raise RuntimeError('This operation is not meant to be called directly.')
|
||||
|
|
|
@ -19,7 +19,6 @@ import logging
|
|||
import os
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def cal_sha256(file_path):
|
||||
|
@ -100,20 +99,3 @@ def cell_attr_register(fn=None, attrs=None):
|
|||
if fn is not None:
|
||||
return wrap_cell(fn)
|
||||
return wrap_cell
|
||||
|
||||
|
||||
@dataclass
|
||||
class Slice:
|
||||
"""
|
||||
Slice class
|
||||
"""
|
||||
start: int
|
||||
end: int
|
||||
step: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ellipsis_:
|
||||
"""
|
||||
Ellipsis class
|
||||
"""
|
||||
|
|
|
@ -8,6 +8,10 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
|||
add_compile_definitions(BUILDING_DLL)
|
||||
endif()
|
||||
|
||||
if (ENABLE_MPI)
|
||||
add_compile_definitions(ENABLE_MPI)
|
||||
endif ()
|
||||
|
||||
if(ENABLE_GPU)
|
||||
find_package(CUDA REQUIRED)
|
||||
find_package(Threads)
|
||||
|
@ -35,7 +39,7 @@ if(ENABLE_GPU)
|
|||
"device/gpu/*.cu"
|
||||
"kernel/gpu/*.cu"
|
||||
"kernel/akg/gpu/*.cc"
|
||||
"kernel/akg/akgkernelbuild.cc"
|
||||
"kernel/akg/akg_kernel_build.cc"
|
||||
"kernel/akg/akg_kernel_attrs_process.cc"
|
||||
)
|
||||
|
||||
|
@ -75,7 +79,9 @@ if (ENABLE_DUMP_PROTO)
|
|||
file(GLOB_RECURSE PROTO_PY RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"utils/anf_ir.proto"
|
||||
"utils/summary.proto"
|
||||
"utils/lineage.proto"
|
||||
"utils/checkpoint.proto"
|
||||
"utils/print.proto"
|
||||
)
|
||||
ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY})
|
||||
|
||||
|
@ -120,7 +126,11 @@ endforeach ()
|
|||
set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME)
|
||||
add_library(mindspore STATIC ${SUB_OBJECTS_SRC})
|
||||
target_link_libraries(mindspore proto_input)
|
||||
target_link_libraries(mindspore securec mindspore::flatbuffers)
|
||||
if (ENABLE_CPU AND ENABLE_MPI)
|
||||
target_link_libraries(mindspore securec mindspore::flatbuffers mindspore::ompi)
|
||||
else ()
|
||||
target_link_libraries(mindspore securec mindspore::flatbuffers)
|
||||
endif ()
|
||||
if (NOT WIN32)
|
||||
target_link_libraries(mindspore dl)
|
||||
endif()
|
||||
|
@ -227,3 +237,29 @@ if (ENABLE_MINDDATA)
|
|||
add_subdirectory(mindrecord)
|
||||
add_subdirectory(dataset)
|
||||
endif ()
|
||||
|
||||
# build inference
|
||||
set(LOAD_ONNX_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_converter.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_model_parser.cc
|
||||
)
|
||||
add_library(inference SHARED
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc
|
||||
${LOAD_ONNX_SRC}
|
||||
)
|
||||
target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
|
||||
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf)
|
||||
|
||||
if (ENABLE_CPU)
|
||||
target_link_libraries(inference PRIVATE mindspore::dnnl mindspore::mkldnn)
|
||||
endif ()
|
||||
|
||||
if (USE_GLOG)
|
||||
target_link_libraries(inference PRIVATE mindspore::glog)
|
||||
else()
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
target_link_options(inference PRIVATE -Wl,-init,mindspore_log_init)
|
||||
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
|
||||
endif ()
|
||||
endif()
|
||||
|
|
|
@ -14,11 +14,9 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "common/trans.h"
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
#include "./securec.h"
|
||||
#include "common/utils.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "kernel/kernel.h"
|
||||
|
@ -29,34 +27,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace trans {
|
||||
namespace {
|
||||
std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
|
||||
std::vector<size_t> shape_4d(4, 1);
|
||||
switch (shape.size()) {
|
||||
case 0:
|
||||
return shape_4d;
|
||||
case 1:
|
||||
shape_4d[1] = shape[0];
|
||||
break;
|
||||
case 2:
|
||||
shape_4d[1] = shape[0];
|
||||
shape_4d[2] = shape[1];
|
||||
break;
|
||||
case 3:
|
||||
shape_4d[1] = shape[0];
|
||||
shape_4d[2] = shape[1];
|
||||
shape_4d[3] = shape[2];
|
||||
break;
|
||||
case 4:
|
||||
std::copy(shape.begin(), shape.end(), shape_4d.begin());
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
|
||||
}
|
||||
return shape_4d;
|
||||
}
|
||||
} // namespace
|
||||
const size_t kNchwDims = 4;
|
||||
enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc };
|
||||
const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1},
|
||||
{kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
|
||||
{kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2},
|
||||
|
@ -84,7 +55,10 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx,
|
|||
|
||||
template <typename T>
|
||||
T DivCeil(T n1, T n2) {
|
||||
return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0;
|
||||
if (n2 != 0) {
|
||||
return (n1 - 1) / n2 + 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
enum DataTypeTransMode {
|
||||
|
@ -226,8 +200,7 @@ size_t CubeSizeByType(const TypeId data_type) {
|
|||
}
|
||||
|
||||
size_t ShapeSize(const std::vector<size_t> &shape) {
|
||||
size_t product = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>());
|
||||
return product;
|
||||
return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies<size_t>());
|
||||
}
|
||||
|
||||
size_t TypeIdSize(const TypeId data_type) {
|
||||
|
@ -239,27 +212,177 @@ size_t TypeIdSize(const TypeId data_type) {
|
|||
return unsupported_type_error;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool CheckDims(const std::vector<size_t> &shape) {
|
||||
if (shape.size() != kNchwDims) {
|
||||
MS_LOG(ERROR) << "Host shape dims shoud be 4";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Ccheck dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
device_shape.push_back(shape[kN]);
|
||||
device_shape.push_back(shape[kH]);
|
||||
device_shape.push_back(shape[kW]);
|
||||
device_shape.push_back(shape[kC]);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
device_shape.push_back(shape[kH]);
|
||||
device_shape.push_back(shape[kW]);
|
||||
device_shape.push_back(shape[kC]);
|
||||
device_shape.push_back(shape[kN]);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
|
||||
const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
|
||||
device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
|
||||
device_shape.push_back(cout16 / kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize;
|
||||
const size_t C0 = kCubeSize;
|
||||
device_shape.push_back(shape[kN]);
|
||||
device_shape.push_back(C1);
|
||||
device_shape.push_back(shape[kH]);
|
||||
device_shape.push_back(shape[kW]);
|
||||
device_shape.push_back(C0);
|
||||
return device_shape;
|
||||
}
|
||||
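The FracZ and Nc1hwc0 helpers above round the channel and batch dimensions up to the cube size and rearrange NCHW into the device layouts. Here is a short Python sketch of the two shape computations, assuming the fp16 cube size of 16 (CubeSizeByType can return a different value for other dtypes); frac_z_shape and nc1hwc0_shape are hypothetical names.

def frac_z_shape(n, c, h, w, cube=16):
    # [H*W*C1, N1, 16, 16] with C and N rounded up to the cube size.
    cout16 = ((n + cube - 1) // cube) * cube
    cin16 = ((c + cube - 1) // cube) * cube
    return [h * w * cin16 // cube, cout16 // cube, cube, cube]

def nc1hwc0_shape(n, c, h, w, cube=16):
    # [N, C1, H, W, C0] with C split into C1 blocks of C0 = cube channels.
    return [n, (c + cube - 1) // cube, h, w, cube]

print(frac_z_shape(32, 3, 5, 5))   # [25, 2, 16, 16]
print(nc1hwc0_shape(32, 3, 5, 5))  # [32, 1, 5, 5, 16]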
|
||||
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
|
||||
device_shape.push_back(shape[kH]);
|
||||
device_shape.push_back(shape[kW]);
|
||||
device_shape.push_back(shape[kN]);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
const size_t c0 = 4;
|
||||
auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize);
|
||||
auto no = DivCeil(shape.at(kN), kCubeSize);
|
||||
device_shape.push_back(first_dim);
|
||||
device_shape.push_back(no);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
const size_t C1 = 1;
|
||||
const size_t C0 = 4;
|
||||
device_shape.push_back(shape[kN]);
|
||||
device_shape.push_back(C1);
|
||||
device_shape.push_back(shape[kH]);
|
||||
device_shape.push_back(shape[kW]);
|
||||
device_shape.push_back(C0);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (shape.size() < kNdhwc) {
|
||||
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
|
||||
std::vector<size_t> shape_4d(kNchwDims, 1);
|
||||
switch (shape.size()) {
|
||||
case 0:
|
||||
return shape_4d;
|
||||
case 1:
|
||||
shape_4d[kC] = shape[kN];
|
||||
break;
|
||||
case 2:
|
||||
shape_4d[kC] = shape[kN];
|
||||
shape_4d[kH] = shape[kC];
|
||||
break;
|
||||
case 3:
|
||||
shape_4d[kC] = shape[kN];
|
||||
shape_4d[kH] = shape[kC];
|
||||
shape_4d[kW] = shape[kH];
|
||||
break;
|
||||
case 4:
|
||||
std::copy(shape.begin(), shape.end(), shape_4d.begin());
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
|
||||
}
|
||||
return shape_4d;
|
||||
}
|
||||
} // namespace
|
||||
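PaddingShapeTo4dByDefault keeps N at 1 and slots the existing dims into C/H/W, filling the rest with 1. A compact Python sketch of the same rule; pad_shape_to_4d is a hypothetical helper.

def pad_shape_to_4d(shape):
    # 0-3 dims: N stays 1, existing dims fill C/H/W, remaining dims become 1.
    if len(shape) > 4:
        raise ValueError("Unexpected shape size = %d" % len(shape))
    if len(shape) == 4:
        return list(shape)
    return [1] + list(shape) + [1] * (3 - len(shape))

print(pad_shape_to_4d([8]))        # [1, 8, 1, 1]
print(pad_shape_to_4d([8, 224]))   # [1, 8, 224, 1]
print(pad_shape_to_4d([3, 4, 5]))  # [1, 3, 4, 5]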
|
||||
bool IsNeedPadding(const std::string &format, const size_t shape_size) {
|
||||
if (shape_size == 0) {
|
||||
return false;
|
||||
}
|
||||
if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
|
||||
return false;
|
||||
} else if (shape_size < 4) {
|
||||
} else if (shape_size < kNchwDims) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> host_shape;
|
||||
if (node->isa<ValueNode>()) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto node_value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(node_value);
|
||||
auto tensor = node_value->cast<tensor::TensorPtr>();
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(EXCEPTION) << " the node[ " << node->DebugString() << "]'s cannot convert ";
|
||||
MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert ";
|
||||
}
|
||||
auto shape_temp = tensor->shape();
|
||||
(void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), IntToSize);
|
||||
|
@ -280,134 +403,13 @@ std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std
|
|||
if (padding_axis.empty() || shape.size() != padding_axis.size()) {
|
||||
return PaddingShapeTo4dByDefault(shape);
|
||||
}
|
||||
std::vector<size_t> shape_4d(4, 1);
|
||||
std::vector<size_t> shape_4d(kNchwDims, 1);
|
||||
for (size_t index = 0; index < padding_axis.size(); index++) {
|
||||
shape_4d[padding_axis[index]] = shape[index];
|
||||
}
|
||||
return shape_4d;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool CheckDims(const std::vector<size_t> &shape) {
|
||||
if (shape.size() != 4) {
|
||||
MS_LOG(ERROR) << "Host shape dims shoud be 4";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Ccheck dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
device_shape.push_back(shape[0]);
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(shape[1]);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(shape[1]);
|
||||
device_shape.push_back(shape[0]);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
|
||||
size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
|
||||
device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
|
||||
device_shape.push_back(cout16 / kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
|
||||
size_t C0 = kCubeSize;
|
||||
device_shape.push_back(shape[0]);
|
||||
device_shape.push_back(C1);
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(C0);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
device_shape.push_back((shape[1] - 1) / kCubeSize + 1);
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(shape[0]);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
size_t c0 = 4;
|
||||
auto first_dim = DivCeil(c0 * shape.at(2) * shape.at(3), kCubeSize);
|
||||
auto no = DivCeil(shape.at(0), kCubeSize);
|
||||
device_shape.push_back(first_dim);
|
||||
device_shape.push_back(no);
|
||||
device_shape.push_back(kCubeSize);
|
||||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
|
||||
if (!CheckDims(shape)) {
|
||||
MS_LOG(EXCEPTION) << "Check dims failed.";
|
||||
}
|
||||
std::vector<size_t> device_shape;
|
||||
size_t C1 = 1;
|
||||
size_t C0 = 4;
|
||||
device_shape.push_back(shape[0]);
|
||||
device_shape.push_back(C1);
|
||||
device_shape.push_back(shape[2]);
|
||||
device_shape.push_back(shape[3]);
|
||||
device_shape.push_back(C0);
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
|
||||
if (shape.size() < 5) {
|
||||
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
|
||||
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
|
||||
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
|
||||
|
@ -426,6 +428,10 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|||
auto temp_shape = shape;
|
||||
std::vector<size_t> device_shape;
|
||||
if (format == kOpFormat_FRAC_NZ) {
|
||||
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
|
||||
// For [1] and [1024] shape we can trait it as NZ shape
|
||||
return shape;
|
||||
}
|
||||
if (shape.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
|
||||
} else {
|
||||
|
@ -439,7 +445,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|||
device_shape.push_back(kCubeSize);
|
||||
return device_shape;
|
||||
}
|
||||
if (shape.size() != 4) {
|
||||
if (shape.size() != kNchwDims) {
|
||||
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
|
||||
temp_shape = PaddingShapeTo4dByDefault(shape);
|
||||
}
|
||||
|
@ -455,6 +461,8 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
|
|||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(size);
|
||||
MS_EXCEPTION_IF_NULL(total_size);
|
||||
*size = TypeIdSize(args.src_data_type);
|
||||
if (*size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
|
@ -540,10 +548,10 @@ bool NchwTo4D(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Check args failed.";
|
||||
return false;
|
||||
}
|
||||
size_t n = args.host_shape[0];
|
||||
size_t c = args.host_shape[1];
|
||||
size_t h = args.host_shape[2];
|
||||
size_t w = args.host_shape[3];
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
for (size_t ni = 0; ni < n; ni++) {
|
||||
for (size_t ci = 0; ci < c; ci++) {
|
||||
for (size_t hi = 0; hi < h; hi++) {
|
||||
|
@ -572,10 +580,10 @@ bool ToNchw(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Check args failed.";
|
||||
return false;
|
||||
}
|
||||
size_t n = args.host_shape[0];
|
||||
size_t c = args.host_shape[1];
|
||||
size_t h = args.host_shape[2];
|
||||
size_t w = args.host_shape[3];
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
for (size_t ni = 0; ni < n; ni++) {
|
||||
for (size_t ci = 0; ci < c; ci++) {
|
||||
for (size_t hi = 0; hi < h; hi++) {
|
||||
|
@ -602,32 +610,32 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
size_t size = TypeIdSize(args.src_data_type);
|
||||
auto size = TypeIdSize(args.src_data_type);
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[0];
|
||||
auto c = args.host_shape[1];
|
||||
auto h = args.host_shape[2];
|
||||
auto w = args.host_shape[3];
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
|
||||
size_t c0 = CubeSizeByType(args.src_data_type);
|
||||
auto c0 = CubeSizeByType(args.src_data_type);
|
||||
if (c0 < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
size_t c1 = DivCeil(c, c0);
|
||||
size_t hw = h * w;
|
||||
size_t chw = c * hw;
|
||||
size_t hwc0 = hw * c0;
|
||||
size_t nchw = n * chw;
|
||||
auto c1 = DivCeil(c, c0);
|
||||
auto hw = h * w;
|
||||
auto chw = c * hw;
|
||||
auto hwc0 = hw * c0;
|
||||
auto nchw = n * chw;
|
||||
|
||||
size_t hf_cnt = DivCeil(n, kCubeSize);
|
||||
size_t vf_cnt = c1 * hw;
|
||||
size_t fractal_ele_cnt = c0 * kCubeSize;
|
||||
size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
|
||||
size_t dst_size = total_ele_cnt * size;
|
||||
auto hf_cnt = DivCeil(n, kCubeSize);
|
||||
auto vf_cnt = c1 * hw;
|
||||
auto fractal_ele_cnt = c0 * kCubeSize;
|
||||
auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
|
||||
auto dst_size = total_ele_cnt * size;
|
||||
if (dst_size != args.device_size) {
|
||||
MS_LOG(ERROR) << "Illegal total data size."
|
||||
<< "dst size is :" << dst_size << "device size is :" << args.device_size;
|
||||
|
@ -647,7 +655,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
|
|||
auto src_ni = hfi * kCubeSize + col;
|
||||
auto src_idx = src_row_offset + chw * col;
|
||||
auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
|
||||
auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? true : false;
|
||||
auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
|
||||
SetData(size, pad_zero, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
|
@ -663,12 +671,12 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
size_t size = TypeIdSize(args.src_data_type);
|
||||
auto size = TypeIdSize(args.src_data_type);
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
size_t total_size = ShapeSize(args.device_shape) * size;
|
||||
auto total_size = ShapeSize(args.device_shape) * size;
|
||||
if (total_size != args.device_size) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
|
@ -677,18 +685,16 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
|
|||
auto n0 = args.device_shape.at(1);
|
||||
auto ni = args.device_shape.at(2);
|
||||
auto c0 = args.device_shape.at(3);
|
||||
|
||||
auto n = args.host_shape[0];
|
||||
auto c = args.host_shape[1];
|
||||
auto h = args.host_shape[2];
|
||||
auto w = args.host_shape[3];
|
||||
|
||||
size_t nc = ni * n0;
|
||||
size_t ncc0 = nc * c0;
|
||||
size_t wncc0 = w * ncc0;
|
||||
size_t hwncc0 = h * wncc0;
|
||||
size_t hw = h * w;
|
||||
size_t chw = c * hw;
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
auto nc = ni * n0;
|
||||
auto ncc0 = nc * c0;
|
||||
auto wncc0 = w * ncc0;
|
||||
auto hwncc0 = h * wncc0;
|
||||
auto hw = h * w;
|
||||
auto chw = c * hw;
|
||||
|
||||
for (size_t n_idx = 0; n_idx < n; n_idx++) {
|
||||
size_t n_head_addr = n_idx * chw;
|
||||
|
@ -720,20 +726,18 @@ bool NchwToFracZc04(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Check args failed.";
|
||||
return false;
|
||||
}
|
||||
size_t cube = kCubeSize;
|
||||
size_t n = args.host_shape[0];
|
||||
size_t c = args.host_shape[1];
|
||||
size_t h = args.host_shape[2];
|
||||
size_t w = args.host_shape[3];
|
||||
|
||||
size_t c0 = 4;
|
||||
size_t c1 = DivCeil(c, c0);
|
||||
size_t hwc0 = h * w * c0;
|
||||
size_t hwc = h * w * c;
|
||||
size_t nhwc = n * h * w * c;
|
||||
|
||||
size_t n_cnt = DivCeil(n, cube);
|
||||
size_t v_cnt = DivCeil(h * w * c0 * c1, cube);
|
||||
auto cube = kCubeSize;
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
const size_t c0 = 4;
|
||||
auto c1 = DivCeil(c, c0);
|
||||
auto hwc0 = h * w * c0;
|
||||
auto hwc = h * w * c;
|
||||
auto nhwc = n * h * w * c;
|
||||
auto n_cnt = DivCeil(n, cube);
|
||||
auto v_cnt = DivCeil(h * w * c0 * c1, cube);
|
||||
size_t dst_idx = 0;
|
||||
|
||||
for (size_t vi = 0; vi < v_cnt; vi++) {
|
||||
|
@ -929,7 +933,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
size_t size = TypeIdSize(args.src_data_type);
|
||||
auto size = TypeIdSize(args.src_data_type);
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
|
@ -940,20 +944,23 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
|
|||
return false;
|
||||
}
|
||||
|
||||
auto n = args.host_shape[0];
|
||||
auto c = args.host_shape[1];
|
||||
auto h = args.host_shape[2];
|
||||
auto w = args.host_shape[3];
|
||||
size_t c0 = CubeSizeByType(args.src_data_type);
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
auto c0 = CubeSizeByType(args.src_data_type);
|
||||
if (c0 < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
size_t c1 = DivCeil(c, c0);
|
||||
size_t hw = h * w;
|
||||
size_t chw = c * hw;
|
||||
size_t c1hwc0 = c1 * hw * c0;
|
||||
size_t wc0 = w * c0;
|
||||
if (args.device_format == kOpFormat_NC1HWC0_C04) {
|
||||
c0 = 4;
|
||||
}
|
||||
auto c1 = DivCeil(c, c0);
|
||||
auto hw = h * w;
|
||||
auto chw = c * hw;
|
||||
auto c1hwc0 = c1 * hw * c0;
|
||||
auto wc0 = w * c0;
|
||||
|
||||
for (size_t n_idx = 0; n_idx < n; n_idx++) {
|
||||
size_t n_head_addr = n_idx * c1hwc0;
|
||||
|
@ -967,7 +974,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
|
|||
size_t dst_idx = c0_idx + w_head_addr;
|
||||
size_t c_idx = c0_idx + c1_idx * c0;
|
||||
size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx;
|
||||
auto pad_zero = (c_idx < c) ? false : true;
|
||||
auto pad_zero = c_idx >= c;
|
||||
SetData(size, pad_zero, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
|
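Editor's note: the NchwToNc1hwc0 hunk above tiles the C axis into C1 x C0 blocks and zero-fills any channel index past the real channel count. A minimal standalone sketch of that index mapping, assuming the same helpers (DivCeil, SetData) and variables (n, c, h, w, c0, size, args, result) used in the function; this is an illustration, not the verbatim loop:

// Sketch: map each NC1HWC0 (device) element back to its NCHW (host) source offset.
size_t c1 = DivCeil(c, c0);
for (size_t n_idx = 0; n_idx < n; ++n_idx)
  for (size_t c1_idx = 0; c1_idx < c1; ++c1_idx)
    for (size_t h_idx = 0; h_idx < h; ++h_idx)
      for (size_t w_idx = 0; w_idx < w; ++w_idx)
        for (size_t c0_idx = 0; c0_idx < c0; ++c0_idx) {
          size_t c_idx = c1_idx * c0 + c0_idx;  // original channel index
          size_t dst_idx = (((n_idx * c1 + c1_idx) * h + h_idx) * w + w_idx) * c0 + c0_idx;
          size_t src_idx = ((n_idx * c + c_idx) * h + h_idx) * w + w_idx;
          bool pad_zero = c_idx >= c;  // channels beyond c carry zero padding
          SetData(size, pad_zero, src_idx, dst_idx, args, result);
        }
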
@ -984,29 +991,29 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
|
||||
return false;
|
||||
}
|
||||
size_t size = TypeIdSize(args.src_data_type);
|
||||
auto size = TypeIdSize(args.src_data_type);
|
||||
if (size < 1) {
|
||||
MS_LOG(ERROR) << "Illegal dtype.";
|
||||
return false;
|
||||
}
|
||||
size_t total_size = ShapeSize(args.device_shape) * size;
|
||||
auto total_size = ShapeSize(args.device_shape) * size;
|
||||
if (total_size != args.device_size) {
|
||||
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto n = args.host_shape[0];
|
||||
auto c = args.host_shape[1];
|
||||
auto h = args.host_shape[2];
|
||||
auto w = args.host_shape[3];
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
auto c1 = args.device_shape[1];
|
||||
auto c0 = args.device_shape[4];
|
||||
|
||||
size_t hw = h * w;
|
||||
size_t chw = c * hw;
|
||||
size_t wc0 = w * c0;
|
||||
size_t hwc0 = h * wc0;
|
||||
size_t c1hwc0 = c1 * hwc0;
|
||||
auto hw = h * w;
|
||||
auto chw = c * hw;
|
||||
auto wc0 = w * c0;
|
||||
auto hwc0 = h * wc0;
|
||||
auto c1hwc0 = c1 * hwc0;
|
||||
|
||||
for (size_t n_idx = 0; n_idx < n; n_idx++) {
|
||||
size_t n_head_addr = n_idx * chw;
|
||||
|
@ -1037,13 +1044,15 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Check args failed.";
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[0];
|
||||
auto c = args.host_shape[1];
|
||||
auto h = args.host_shape[2];
|
||||
auto w = args.host_shape[3];
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
const int co_idx = 4;
|
||||
const int c0_idx = 5;
|
||||
auto c1 = args.device_shape[0];
|
||||
auto co = args.device_shape[4];
|
||||
auto c0 = args.device_shape[5];
|
||||
auto co = args.device_shape[co_idx];
|
||||
auto c0 = args.device_shape[c0_idx];
|
||||
|
||||
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
|
||||
for (size_t h_i = 0; h_i < h; h_i++) {
|
||||
|
@ -1055,7 +1064,7 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) {
|
|||
co_i * c0 + c0_i;
|
||||
size_t c_i = c0_i + c1_i * c0;
|
||||
size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i;
|
||||
auto pad_zero = (c_i < c && c0_i == co_i) ? false : true;
|
||||
auto pad_zero = !(c_i < c && c0_i == co_i);
|
||||
SetData(size, pad_zero, src_idx, dst_idx, args, result);
|
||||
}
|
||||
}
|
||||
|
@ -1076,12 +1085,14 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
|
|||
MS_LOG(ERROR) << "Check args failed.";
|
||||
return false;
|
||||
}
|
||||
auto n = args.host_shape[0];
|
||||
auto c = args.host_shape[1];
|
||||
auto h = args.host_shape[2];
|
||||
auto w = args.host_shape[3];
|
||||
auto co = args.device_shape[4];
|
||||
auto c0 = args.device_shape[5];
|
||||
auto n = args.host_shape[kN];
|
||||
auto c = args.host_shape[kC];
|
||||
auto h = args.host_shape[kH];
|
||||
auto w = args.host_shape[kW];
|
||||
const int co_idx = 4;
|
||||
const int c0_idx = 5;
|
||||
auto co = args.device_shape[co_idx];
|
||||
auto c0 = args.device_shape[c0_idx];
|
||||
for (size_t n_i = 0; n_i < n; n_i++) {
|
||||
for (size_t c_i = 0; c_i < c; c_i++) {
|
||||
for (size_t h_i = 0; h_i < h; h_i++) {
|
||||
|
|
|
@@ -62,6 +62,7 @@ add_dependencies(engine-datasetops-source core)
add_dependencies(engine-datasetops-source-sampler core)
add_dependencies(engine-datasetops core)
add_dependencies(engine-opt core)
add_dependencies(engine-perf core)
add_dependencies(engine-gnn core)
add_dependencies(engine core)
add_dependencies(text core)

@@ -81,6 +82,7 @@ set(submodules
$<TARGET_OBJECTS:engine-datasetops-source>
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
$<TARGET_OBJECTS:engine-gnn>
$<TARGET_OBJECTS:engine-perf>
$<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine>

@@ -106,10 +108,11 @@ target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar)
if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY})
else()
set(ICU_LIB mindspore::icuuc mindspore::icudata mindspore::icui18n)
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY})
endif()
target_link_libraries(_c_dataengine PUBLIC mindspore::jpeg_turbo mindspore::opencv_core mindspore::opencv_imgcodecs
mindspore::opencv_imgproc mindspore::tinyxml2)
mindspore::opencv_imgproc mindspore::tinyxml2 ${ICU_LIB})
if (ENABLE_GPUQUE)
target_link_libraries(_c_dataengine PRIVATE gpu_queue
${CUDNN_PATH}/lib64/libcudnn.so
|
|
|
@ -19,56 +19,64 @@
|
|||
#include <map>
|
||||
|
||||
#include "common/utils.h"
|
||||
#include "dataset/kernels/py_func_op.h"
|
||||
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "dataset/engine/datasetops/source/voc_op.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/engine/dataset_iterator.h"
|
||||
#include "dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "dataset/engine/datasetops/bucket_batch_by_length_op.h"
|
||||
#include "dataset/engine/datasetops/filter_op.h"
|
||||
#include "dataset/engine/datasetops/source/celeba_op.h"
|
||||
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "dataset/engine/datasetops/source/clue_op.h"
|
||||
#include "dataset/engine/datasetops/source/coco_op.h"
|
||||
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "dataset/engine/datasetops/filter_op.h"
|
||||
#include "mindrecord/include/shard_category.h"
|
||||
#include "mindrecord/include/shard_sample.h"
|
||||
#include "mindrecord/include/shard_shuffle.h"
|
||||
#include "dataset/engine/datasetops/source/voc_op.h"
|
||||
#include "dataset/kernels/py_func_op.h"
|
||||
#include "dataset/util/random.h"
|
||||
#include "dataset/util/status.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "mindrecord/include/shard_category.h"
|
||||
#include "mindrecord/include/shard_distributed_sample.h"
|
||||
#include "mindrecord/include/shard_sample.h"
|
||||
#include "mindrecord/include/shard_shuffle.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *);
|
||||
|
||||
static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &DEPipeline::ParseStorageOp},
|
||||
{kShuffle, &DEPipeline::ParseShuffleOp},
|
||||
{kMindrecord, &DEPipeline::ParseMindRecordOp},
|
||||
{kMap, &DEPipeline::ParseMapOp},
|
||||
{kFilter, &DEPipeline::ParseFilterOp},
|
||||
{kBatch, &DEPipeline::ParseBatchOp},
|
||||
{kBarrier, &DEPipeline::ParseBarrierOp},
|
||||
{kRepeat, &DEPipeline::ParseRepeatOp},
|
||||
{kSkip, &DEPipeline::ParseSkipOp},
|
||||
{kZip, &DEPipeline::ParseZipOp},
|
||||
{kConcat, &DEPipeline::ParseConcatOp},
|
||||
{kRename, &DEPipeline::ParseRenameOp},
|
||||
{kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
|
||||
{kGenerator, &DEPipeline::ParseGeneratorOp},
|
||||
{kTfReader, &DEPipeline::ParseTFReaderOp},
|
||||
{kProject, &DEPipeline::ParseProjectOp},
|
||||
{kTake, &DEPipeline::ParseTakeOp},
|
||||
{kImageFolder, &DEPipeline::ParseImageFolderOp},
|
||||
{kMnist, &DEPipeline::ParseMnistOp},
|
||||
{kManifest, &DEPipeline::ParseManifestOp},
|
||||
{kVoc, &DEPipeline::ParseVOCOp},
|
||||
{kCifar10, &DEPipeline::ParseCifar10Op},
|
||||
{kCifar100, &DEPipeline::ParseCifar100Op},
|
||||
{kCelebA, &DEPipeline::ParseCelebAOp},
|
||||
{kRandomData, &DEPipeline::ParseRandomDataOp},
|
||||
{kTextFile, &DEPipeline::ParseTextFileOp}};
|
||||
static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
|
||||
{kShuffle, &DEPipeline::ParseShuffleOp},
|
||||
{kMindrecord, &DEPipeline::ParseMindRecordOp},
|
||||
{kMap, &DEPipeline::ParseMapOp},
|
||||
{kFilter, &DEPipeline::ParseFilterOp},
|
||||
{kBatch, &DEPipeline::ParseBatchOp},
|
||||
{kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp},
|
||||
{kBarrier, &DEPipeline::ParseBarrierOp},
|
||||
{kRepeat, &DEPipeline::ParseRepeatOp},
|
||||
{kSkip, &DEPipeline::ParseSkipOp},
|
||||
{kZip, &DEPipeline::ParseZipOp},
|
||||
{kConcat, &DEPipeline::ParseConcatOp},
|
||||
{kRename, &DEPipeline::ParseRenameOp},
|
||||
{kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
|
||||
{kGenerator, &DEPipeline::ParseGeneratorOp},
|
||||
{kTfReader, &DEPipeline::ParseTFReaderOp},
|
||||
{kProject, &DEPipeline::ParseProjectOp},
|
||||
{kTake, &DEPipeline::ParseTakeOp},
|
||||
{kImageFolder, &DEPipeline::ParseImageFolderOp},
|
||||
{kMnist, &DEPipeline::ParseMnistOp},
|
||||
{kManifest, &DEPipeline::ParseManifestOp},
|
||||
{kVoc, &DEPipeline::ParseVOCOp},
|
||||
{kCoco, &DEPipeline::ParseCocoOp},
|
||||
{kCifar10, &DEPipeline::ParseCifar10Op},
|
||||
{kCifar100, &DEPipeline::ParseCifar100Op},
|
||||
{kCelebA, &DEPipeline::ParseCelebAOp},
|
||||
{kRandomData, &DEPipeline::ParseRandomDataOp},
|
||||
{kTextFile, &DEPipeline::ParseTextFileOp},
|
||||
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
|
||||
{kClue, &DEPipeline::ParseClueOp}};
|
||||
|
||||
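Editor's note: g_parse_op_func_ maps each OpName value to a DEPipeline member-function pointer. A minimal sketch of how such a table is dispatched (hypothetical helper; the real call site is outside this hunk):

// Hypothetical dispatch helper: look up the parser registered for op_name and invoke it.
Status DispatchParseOp(DEPipeline *pipeline, uint32_t op_name, const py::dict &args,
                       std::shared_ptr<DatasetOp> *op) {
  auto iter = g_parse_op_func_.find(op_name);
  if (iter == g_parse_op_func_.end()) {
    RETURN_STATUS_UNEXPECTED("Error: unsupported dataset op name");
  }
  return (pipeline->*(iter->second))(args, op);  // e.g. calls DEPipeline::ParseBatchOp
}
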
DEPipeline::DEPipeline() : iterator_(nullptr) {
|
||||
try {
|
||||
|
@ -292,70 +300,6 @@ Status DEPipeline::SetBatchParameters(const py::dict &args) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ValidateArgStorageOp(const py::dict &args) {
|
||||
// Required arguments
|
||||
if (((args.contains("dataset_files") && args["dataset_files"].is_none()) || args["schema"].is_none()) &&
|
||||
((args.contains("dataset_dir") && args["dataset_dir"].is_none()) ||
|
||||
(args["schema"].is_none() && args["schema_json_string"].is_none()))) {
|
||||
std::string err_msg = "Error: at least one of dataset_files or schema_file is missing";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseStorageOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(ValidateArgStorageOp(args));
|
||||
std::shared_ptr<StorageOp::Builder> builder;
|
||||
if (args.contains("dataset_files") && !args["dataset_files"].is_none()) {
|
||||
builder = std::make_shared<StorageOp::Builder>();
|
||||
(void)builder->SetDatasetFileList(ToStringVector(args["dataset_files"]));
|
||||
(void)builder->SetSchemaFile(ToString(args["schema"]));
|
||||
} else if (args.contains("dataset_dir") && !args["dataset_dir"].is_none()) {
|
||||
builder = std::make_shared<StorageOp::Builder>();
|
||||
(void)builder->SetDatasetFilesDir(ToString(args["dataset_dir"]));
|
||||
if (!args["schema"].is_none()) {
|
||||
(void)builder->SetSchemaFile(ToString(args["schema"]));
|
||||
} else if (!args["schema_json_string"].is_none()) {
|
||||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
std::string s = ToString(args["schema_json_string"]);
|
||||
RETURN_IF_NOT_OK(schema->LoadSchemaString(s, std::vector<std::string>()));
|
||||
(void)builder->SetNumRows(schema->num_rows());
|
||||
(void)builder->SetSchema(std::move(schema));
|
||||
}
|
||||
}
|
||||
|
||||
// Optional arguments
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "prefetch_size") {
|
||||
(void)builder->SetOpConnectorSize(ToInt(value));
|
||||
} else if (key == "columns_list") {
|
||||
(void)builder->SetColumnsToLoad(ToStringVector(value));
|
||||
} else if (key == "distribution") {
|
||||
(void)builder->SetDataDistributionFile(ToString(value));
|
||||
} else if (key == "labels_filename") {
|
||||
(void)builder->setLabelsFileName(ToString(value));
|
||||
} else if (key == "dataset_usage") {
|
||||
(void)builder->SetDatasetUsage(ToString(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
(void)builder->SetBatchSize(temp_batch_size_);
|
||||
(void)builder->SetDropRemainder(temp_drop_remainder_);
|
||||
|
||||
std::shared_ptr<StorageOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
num_rows_ = op->num_rows();
|
||||
num_classes_ = op->num_classes();
|
||||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
|
||||
if (!args["buffer_size"].is_none()) {
|
||||
|
@ -382,35 +326,27 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetO
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *in_partitions) {
|
||||
if (args["partitions"].is_none()) {
|
||||
std::string err_msg = "Error: partitions is not set (None)";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
py::list list = py::reinterpret_borrow<py::list>(args["partitions"]);
|
||||
for (auto l : list) {
|
||||
if (!l.is_none()) {
|
||||
in_partitions->push_back(ToInt(l));
|
||||
Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle,
|
||||
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
|
||||
int num_padded) {
|
||||
auto sampler = py::reinterpret_borrow<py::object>(handle);
|
||||
auto create = sampler.attr("create_for_minddataset");
|
||||
auto op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||
std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
|
||||
while (op != nullptr) {
|
||||
auto sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
|
||||
if (sampler_op && num_padded > 0) {
|
||||
sampler_op->SetNumPaddedSamples(num_padded);
|
||||
stack_ops.push(sampler_op);
|
||||
} else {
|
||||
stack_ops.push(op);
|
||||
}
|
||||
op = op->GetChildOp();
|
||||
}
|
||||
|
||||
if (in_partitions->size() != 2) {
|
||||
std::string err_msg = "Error: partitions is invalid or not set.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
while (!stack_ops.empty()) {
|
||||
operators->push_back(stack_ops.top());
|
||||
stack_ops.pop();
|
||||
}
|
||||
|
||||
constexpr int kMaxPartitions = 64;
|
||||
if (in_partitions->at(0) <= 0 || in_partitions->at(0) > kMaxPartitions) {
|
||||
std::string err_msg = "Error: partitions is invalid or not set.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
if (in_partitions->at(1) < 0 || in_partitions->at(1) >= in_partitions->at(0)) {
|
||||
std::string err_msg = "Error: partitions is invalid or not set.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
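Editor's note: BuildMindrecordSamplerChain above walks the sampler chain from the root via GetChildOp() and pushes each ShardOperator onto a stack, so the operators vector ends up ordered deepest child first. A small illustration (hypothetical three-level chain):

// Hypothetical chain: root -> child -> grandchild
// push order while walking  : root, child, grandchild
// pop order into *operators : grandchild, child, root
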
|
@ -438,6 +374,10 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
|
|||
(void)builder->SetColumnsToLoad(in_col_names);
|
||||
}
|
||||
|
||||
if (!args["padded_sample"].is_none()) {
|
||||
(void)builder->SetPaddedSample(args["padded_sample"]);
|
||||
(void)builder->SetNumToPadSamples(ToInt(args["num_padded"]));
|
||||
}
|
||||
std::vector<std::shared_ptr<mindrecord::ShardOperator>> operators;
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
|
@ -447,27 +387,16 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
|
|||
(void)builder->SetNumMindRecordWorkers(ToInt(value));
|
||||
} else if (key == "block_reader" && ToBool(value) == true) {
|
||||
(void)builder->SetBlockReader();
|
||||
} else if (key == "global_shuffle" && ToBool(value) == true) {
|
||||
uint32_t seed = args["partitions"].is_none() ? GetSeed() : 0;
|
||||
operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("_create_for_minddataset");
|
||||
std::shared_ptr<mindrecord::ShardOperator> sample_op =
|
||||
create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||
operators.push_back(sample_op);
|
||||
int num_padded = 0;
|
||||
if (!args["num_padded"].is_none()) {
|
||||
num_padded = ToInt(args["num_padded"]);
|
||||
}
|
||||
RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> in_partitions;
|
||||
if (!args["partitions"].is_none()) {
|
||||
auto ret = CheckMindRecordPartitionInfo(args, &in_partitions);
|
||||
if (Status::OK() != ret) {
|
||||
return ret;
|
||||
}
|
||||
operators.push_back(std::make_shared<mindrecord::ShardSample>(1, in_partitions[0], in_partitions[1]));
|
||||
}
|
||||
|
||||
if (!operators.empty()) {
|
||||
(void)builder->SetOperators(operators);
|
||||
}
|
||||
|
@ -493,6 +422,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
(void)builder->SetInColNames(in_col_names);
|
||||
} else if (key == "output_columns") {
|
||||
(void)builder->SetOutColNames(ToStringVector(value));
|
||||
} else if (key == "columns_order") {
|
||||
(void)builder->SetColOrder(ToStringVector(value));
|
||||
} else if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "prefetch_size") {
|
||||
|
@ -642,18 +573,8 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
(void)builder->SetColumnsToMap(ToStringVector(value));
|
||||
}
|
||||
if (key == "pad_info") {
|
||||
std::map<std::string, std::pair<TensorShape, float>> pad_info;
|
||||
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
|
||||
if (!p.second.is_none()) {
|
||||
py::tuple tp = py::reinterpret_borrow<py::tuple>(p.second);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)");
|
||||
TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]);
|
||||
float pad_val = tp[1].is_none() ? 0 : ToFloat(tp[1]);
|
||||
(void)pad_info.insert({ToString(p.first), {shape, pad_val}});
|
||||
} else { // tuple is None
|
||||
(void)pad_info.insert({ToString(p.first), {TensorShape({}), 0}});
|
||||
}
|
||||
}
|
||||
PadInfo pad_info;
|
||||
RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info));
|
||||
(void)builder->SetPaddingMap(pad_info, true);
|
||||
}
|
||||
}
|
||||
|
@ -665,6 +586,56 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
std::vector<std::string> mandatory_arguments = {"length_dependent_columns", "bucket_boundaries",
|
||||
"bucket_batch_sizes"};
|
||||
for (auto name : mandatory_arguments) {
|
||||
if (args[name.c_str()].is_none()) {
|
||||
std::string err_msg = "Error: " + name + " is not set.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<BucketBatchByLengthOp::Builder> builder = std::make_shared<BucketBatchByLengthOp::Builder>(
|
||||
ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]),
|
||||
ToIntVector(args[mandatory_arguments[2].c_str()]));
|
||||
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "length_dependent_columns") {
|
||||
(void)builder->SetLengthDependentColumns(ToStringVector(value));
|
||||
}
|
||||
if (key == "bucket_boundaries") {
|
||||
(void)builder->SetBucketBoundaries(ToIntVector(value));
|
||||
}
|
||||
if (key == "bucket_batch_sizes") {
|
||||
(void)builder->SetBucketBatchSizes(ToIntVector(value));
|
||||
}
|
||||
if (key == "element_length_function") {
|
||||
(void)builder->SetElementLengthFunction(value.cast<py::function>());
|
||||
}
|
||||
if (key == "pad_info") {
|
||||
PadInfo pad_info;
|
||||
RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info));
|
||||
(void)builder->SetPadInfo(pad_info);
|
||||
}
|
||||
if (key == "pad_to_bucket_boundary") {
|
||||
(void)builder->SetPadToBucketBoundary(ToBool(value));
|
||||
}
|
||||
if (key == "drop_remainder") {
|
||||
(void)builder->SetDropRemainder(ToBool(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<BucketBatchByLengthOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
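Editor's note: ParseBucketBatchByLengthOp requires length_dependent_columns, bucket_boundaries and bucket_batch_sizes; since k boundaries split lengths into k + 1 buckets, bucket_batch_sizes is expected to hold one more entry than bucket_boundaries. A minimal builder sketch with hypothetical values (element types shown as if produced by ToStringVector/ToIntVector; illustration only, not the verbatim parser):

// Hypothetical: boundaries {10, 20} define three buckets (<10, [10, 20), >=20),
// so three batch sizes are supplied.
auto bucket_builder = std::make_shared<BucketBatchByLengthOp::Builder>(
    std::vector<std::string>{"text"},     // length_dependent_columns
    std::vector<int32_t>{10, 20},         // bucket_boundaries
    std::vector<int32_t>{32, 16, 8});     // bucket_batch_sizes
(void)bucket_builder->SetDropRemainder(true);
std::shared_ptr<BucketBatchByLengthOp> op;
RETURN_IF_NOT_OK(bucket_builder->Build(&op));
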
Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>();
|
||||
// Right now barrier should only take num_rows_per_buffer = 1
|
||||
|
@ -801,6 +772,8 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
(void)builder->SetColumnsToLoad(columns_to_load);
|
||||
} else if (key == "shuffle_files") {
|
||||
(void)builder->SetShuffleFiles(ToBool(value));
|
||||
} else if (key == "shuffle_global") {
|
||||
(void)builder->SetShuffleGlobal(ToBool(value));
|
||||
} else if (key == "schema_file_path" || key == "schema_json_string") {
|
||||
schema_exists = true;
|
||||
} else if (key == "num_samples") {
|
||||
|
@@ -856,9 +829,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_samples") {
(void)builder->SetNumSamples(ToInt(value));
} else if (key == "num_parallel_workers") {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
|
@ -893,9 +864,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_parallel_workers") {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
|
@ -922,6 +891,16 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
if (args["task"].is_none()) {
|
||||
std::string err_msg = "Error: No task specified";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
if (args["mode"].is_none()) {
|
||||
std::string err_msg = "Error: No mode specified";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
|
||||
(void)builder->SetDir(ToString(args["dataset_dir"]));
|
||||
(void)builder->SetTask(ToString(args["task"]));
|
||||
|
@ -930,9 +909,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_parallel_workers") {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
|
@ -951,6 +928,47 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
if (args["dataset_dir"].is_none()) {
|
||||
std::string err_msg = "Error: No dataset path specified";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
if (args["annotation_file"].is_none()) {
|
||||
std::string err_msg = "Error: No annotation_file specified";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
if (args["task"].is_none()) {
|
||||
std::string err_msg = "Error: No task specified";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
std::shared_ptr<CocoOp::Builder> builder = std::make_shared<CocoOp::Builder>();
|
||||
(void)builder->SetDir(ToString(args["dataset_dir"]));
|
||||
(void)builder->SetFile(ToString(args["annotation_file"]));
|
||||
(void)builder->SetTask(ToString(args["task"]));
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
|
||||
(void)builder->SetSampler(std::move(sampler));
|
||||
} else if (key == "decode") {
|
||||
(void)builder->SetDecode(ToBool(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
std::shared_ptr<CocoOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
// Required arguments
|
||||
if (args["dataset_dir"].is_none()) {
|
||||
|
@ -966,9 +984,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
|
|||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_parallel_workers") {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
|
@ -1001,9 +1017,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
|
|||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_parallel_workers") {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
|
@ -1039,10 +1053,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
|
|||
(void)builder.SetNumWorkers(ToInt(value));
|
||||
} else if (key == "schema_file_path" || key == "schema_json_string") {
|
||||
schema_exists = true;
|
||||
} else if (key == "num_samples") {
|
||||
(void)builder.SetTotalRows(ToInt(value));
|
||||
} else if (key == "columns_list") {
|
||||
columns_to_load = ToStringVector(value);
|
||||
} else if (key == "num_samples") {
|
||||
// This is not sampling here. The random data op needs to know how much data to
|
||||
// generate. It does not currently support sampling.
|
||||
(void)builder.SetTotalRows(ToInt(value));
|
||||
}
|
||||
}
|
||||
if (schema_exists) {
|
||||
|
@ -1077,9 +1093,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
|
|||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_parallel_workers") {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "sampler") {
|
||||
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
|
||||
|
@ -1121,8 +1135,6 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
|
|||
(void)builder->SetDecode(ToBool(value));
|
||||
} else if (key == "extensions") {
|
||||
(void)builder->SetExtensions(ToStringSet(value));
|
||||
} else if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "dataset_type") {
|
||||
(void)builder->SetDatasetType(ToString(value));
|
||||
}
|
||||
|
@ -1152,8 +1164,10 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "shuffle_files") {
|
||||
(void)builder->SetShuffleFiles(ToBool(value));
|
||||
} else if (key == "shuffle_global") {
|
||||
(void)builder->SetShuffleGlobal(ToBool(value));
|
||||
} else if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
(void)builder->SetTotalRows(ToInt(value));
|
||||
} else if (key == "num_shards") {
|
||||
(void)builder->SetNumDevices(ToInt(value));
|
||||
} else if (key == "shard_id") {
|
||||
|
@ -1166,5 +1180,106 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
|
|||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
|
||||
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
|
||||
if (!p.second.is_none()) {
|
||||
auto tp = py::reinterpret_borrow<py::tuple>(p.second);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)");
|
||||
TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]);
|
||||
std::shared_ptr<Tensor> pad_val = nullptr;
|
||||
if (py::isinstance<py::str>(tp[1])) {
|
||||
std::string pad_val_string = tp[1].is_none() ? "" : ToString(tp[1]);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
Tensor::CreateTensor(&pad_val, std::vector<std::string>{pad_val_string}, TensorShape::CreateScalar()),
|
||||
"Cannot create pad_value Tensor");
|
||||
} else {
|
||||
float pad_val_float = tp[1].is_none() ? 0 : ToFloat(tp[1]);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(Tensor::CreateTensor(&pad_val, TensorImpl::kFlexible, TensorShape::CreateScalar(),
|
||||
DataType(DataType::DE_FLOAT32)),
|
||||
"Cannot create pad_value Tensor");
|
||||
pad_val->SetItemAt<float>({}, pad_val_float);
|
||||
}
|
||||
(void)pad_info->insert({ToString(p.first), {shape, pad_val}});
|
||||
} else { // tuple is None
|
||||
(void)pad_info->insert({ToString(p.first), {TensorShape({}), nullptr}});
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
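Editor's note: ParsePadInfo above converts the Python-side pad_info dict into a PadInfo map of column name -> (target shape, pad-value tensor). A minimal sketch of the entries it would produce for a hypothetical dict {"image": ([224, 224, 3], 0.0), "label": None}; shapes and values are illustrative, and tensor construction mirrors the calls used in the function:

// Hypothetical PadInfo contents for the dict above.
PadInfo pad_info;
std::shared_ptr<Tensor> zero;
(void)Tensor::CreateTensor(&zero, TensorImpl::kFlexible, TensorShape::CreateScalar(),
                           DataType(DataType::DE_FLOAT32));
zero->SetItemAt<float>({}, 0.0f);
(void)pad_info.insert({"image", {TensorShape({224, 224, 3}), zero}});  // pad to 224x224x3 with 0.0
(void)pad_info.insert({"label", {TensorShape({}), nullptr}});          // None: pad to max shape, default value
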
Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
std::shared_ptr<BuildVocabOp::Builder> builder = std::make_shared<BuildVocabOp::Builder>();
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "freq_range") {
|
||||
py::tuple tp = py::reinterpret_borrow<py::tuple>(value);
|
||||
if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow<py::int_>(tp[0]));
|
||||
if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow<py::int_>(tp[1]));
|
||||
} else if (key == "top_k") {
|
||||
builder->SetTopK(py::reinterpret_borrow<py::int_>(value));
|
||||
} else if (key == "columns") {
|
||||
(void)builder->SetColumnNames(ToStringVector(value));
|
||||
} else if (key == "vocab") {
|
||||
(void)builder->SetVocab(value.cast<std::shared_ptr<Vocab>>());
|
||||
} else if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "special_first") {
|
||||
(void)builder->SetSpecialFirst(ToBool(value));
|
||||
} else if (key == "special_tokens") {
|
||||
(void)builder->SetSpecialTokens(ToStringVector(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
std::shared_ptr<BuildVocabOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
|
||||
std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
|
||||
if (!args["dataset_files"].is_none()) {
|
||||
(void)builder->SetClueFilesList(ToStringVector(args["dataset_files"]));
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
|
||||
}
|
||||
// Optional arguments
|
||||
for (auto arg : args) {
|
||||
std::string key = py::str(arg.first);
|
||||
py::handle value = arg.second;
|
||||
if (!value.is_none()) {
|
||||
if (key == "num_parallel_workers") {
|
||||
(void)builder->SetNumWorkers(ToInt(value));
|
||||
} else if (key == "shuffle_files") {
|
||||
(void)builder->SetShuffleFiles(ToBool(value));
|
||||
} else if (key == "shuffle_global") {
|
||||
(void)builder->SetShuffleGlobal(ToBool(value));
|
||||
} else if (key == "num_samples") {
|
||||
(void)builder->SetNumSamples(ToInt(value));
|
||||
} else if (key == "num_shards") {
|
||||
(void)builder->SetNumDevices(ToInt(value));
|
||||
} else if (key == "shard_id") {
|
||||
(void)builder->SetDeviceId(ToInt(value));
|
||||
} else if (key == "cols_to_keyword") {
|
||||
std::map<std::string, std::string> map_dict;
|
||||
for (auto p : py::reinterpret_borrow<py::dict>(value)) {
|
||||
if (!p.second.is_none()) {
|
||||
map_dict.insert({ToString(p.first), ToString(p.second)});
|
||||
} else {
|
||||
map_dict.insert({ToString(p.first), ToString(p.first)});
|
||||
}
|
||||
}
|
||||
(void)builder->SetColsKeyMap(map_dict);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::shared_ptr<ClueOp> op;
|
||||
RETURN_IF_NOT_OK(builder->Build(&op));
|
||||
*ptr = op;
|
||||
return Status::OK();
|
||||
}
|
||||
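Editor's note: in ParseClueOp above, cols_to_keyword maps output column names to the keys looked up in each CLUE record; a None value falls back to the column name itself, as the loop shows. A hypothetical mapping:

// Hypothetical cols_to_keyword = {"sentence1": "text_a", "label": None}
std::map<std::string, std::string> map_dict;
map_dict.insert({"sentence1", "text_a"});  // explicit key in the record
map_dict.insert({"label", "label"});       // None -> reuse the column name as the key
(void)builder->SetColsKeyMap(map_dict);
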
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
@@ -36,10 +37,10 @@ using DsOpPtr = std::shared_ptr<DatasetOp>;

// enum for the dataset operator names
enum OpName {
kStorage = 0,
kShuffle,
kMindrecord,
kBatch,
kBucketBatch,
kBarrier,
kCache,
kRepeat,

@@ -58,11 +59,14 @@ enum OpName {
kMnist,
kManifest,
kVoc,
kCoco,
kCifar10,
kCifar100,
kCelebA,
kRandomData,
kTextFile
kTextFile,
kBuildVocab,
kClue
};
|
||||
|
||||
// The C++ binder class that we expose to the python script.
|
||||
|
@ -100,14 +104,14 @@ class DEPipeline {
|
|||
|
||||
int GetRepeatCount() const;
|
||||
|
||||
Status ParseStorageOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *ptr);
|
||||
|
||||
Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status BuildMindrecordSamplerChain(const py::handle &handle,
|
||||
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators,
|
||||
int num_padded);
|
||||
|
||||
Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
@ -118,6 +122,8 @@ class DEPipeline {
|
|||
|
||||
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
@ -142,6 +148,8 @@ class DEPipeline {
|
|||
|
||||
Status ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
@ -160,14 +168,17 @@ class DEPipeline {
|
|||
|
||||
Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
|
||||
|
||||
private:
|
||||
// Execution tree that links the dataset operators.
|
||||
std::shared_ptr<ExecutionTree> tree_;
|
||||
|
||||
std::unique_ptr<DatasetIterator> iterator_;
|
||||
|
||||
// Validate required args passed to storage op.
|
||||
Status ValidateArgStorageOp(const py::dict &args);
|
||||
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
|
||||
|
||||
int batch_size_;
|
||||
int repeat_num_;
|
||||
|
|
|
@ -16,8 +16,37 @@
|
|||
#include <exception>
|
||||
|
||||
#include "dataset/api/de_pipeline.h"
|
||||
#include "dataset/kernels/no_op.h"
|
||||
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "dataset/engine/datasetops/source/clue_op.h"
|
||||
#include "dataset/engine/datasetops/source/coco_op.h"
|
||||
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "dataset/engine/datasetops/source/io_block.h"
|
||||
#include "dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#include "dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "dataset/engine/datasetops/source/voc_op.h"
|
||||
#include "dataset/engine/gnn/graph.h"
|
||||
#include "dataset/engine/jagged_connector.h"
|
||||
#include "dataset/kernels/data/concatenate_op.h"
|
||||
#include "dataset/kernels/data/duplicate_op.h"
|
||||
#include "dataset/kernels/data/fill_op.h"
|
||||
#include "dataset/kernels/data/mask_op.h"
|
||||
#include "dataset/kernels/data/one_hot_op.h"
|
||||
#include "dataset/kernels/data/pad_end_op.h"
|
||||
#include "dataset/kernels/data/slice_op.h"
|
||||
#include "dataset/kernels/data/to_float16_op.h"
|
||||
#include "dataset/kernels/data/type_cast_op.h"
|
||||
#include "dataset/kernels/image/bounding_box_augment_op.h"
|
||||
#include "dataset/kernels/image/center_crop_op.h"
|
||||
#include "dataset/kernels/image/cut_out_op.h"
|
||||
#include "dataset/kernels/image/decode_op.h"
|
||||
|
@ -26,51 +55,51 @@
|
|||
#include "dataset/kernels/image/normalize_op.h"
|
||||
#include "dataset/kernels/image/pad_op.h"
|
||||
#include "dataset/kernels/image/random_color_adjust_op.h"
|
||||
#include "dataset/kernels/image/random_crop_decode_resize_op.h"
|
||||
#include "dataset/kernels/image/random_crop_and_resize_op.h"
|
||||
#include "dataset/kernels/image/random_crop_and_resize_with_bbox_op.h"
|
||||
#include "dataset/kernels/image/random_crop_decode_resize_op.h"
|
||||
#include "dataset/kernels/image/random_crop_op.h"
|
||||
#include "dataset/kernels/image/random_crop_with_bbox_op.h"
|
||||
#include "dataset/kernels/image/random_horizontal_flip_bbox_op.h"
|
||||
#include "dataset/kernels/image/random_horizontal_flip_op.h"
|
||||
#include "dataset/kernels/image/random_resize_op.h"
|
||||
#include "dataset/kernels/image/random_rotation_op.h"
|
||||
#include "dataset/kernels/image/random_vertical_flip_op.h"
|
||||
#include "dataset/kernels/image/random_vertical_flip_with_bbox_op.h"
|
||||
#include "dataset/kernels/image/rescale_op.h"
|
||||
#include "dataset/kernels/image/resize_bilinear_op.h"
|
||||
#include "dataset/kernels/image/resize_op.h"
|
||||
#include "dataset/kernels/image/uniform_aug_op.h"
|
||||
#include "dataset/kernels/data/type_cast_op.h"
|
||||
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "dataset/engine/datasetops/source/io_block.h"
|
||||
#include "dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#include "dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "dataset/engine/jagged_connector.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "dataset/engine/datasetops/source/voc_op.h"
|
||||
#include "dataset/engine/gnn/graph.h"
|
||||
#include "dataset/kernels/data/to_float16_op.h"
|
||||
#include "dataset/kernels/no_op.h"
|
||||
#include "dataset/text/kernels/jieba_tokenizer_op.h"
|
||||
#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
|
||||
#include "dataset/text/vocab.h"
|
||||
#include "dataset/text/kernels/lookup_op.h"
|
||||
#include "dataset/text/kernels/ngram_op.h"
|
||||
#include "dataset/text/kernels/to_number_op.h"
|
||||
#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
|
||||
#include "dataset/text/kernels/wordpiece_tokenizer_op.h"
|
||||
#include "dataset/text/vocab.h"
|
||||
#include "dataset/util/random.h"
|
||||
#include "mindrecord/include/shard_distributed_sample.h"
|
||||
#include "mindrecord/include/shard_operator.h"
|
||||
#include "mindrecord/include/shard_pk_sample.h"
|
||||
#include "mindrecord/include/shard_sample.h"
|
||||
#include "mindrecord/include/shard_sequential_sample.h"
|
||||
#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#ifdef ENABLE_ICU4C
|
||||
#include "dataset/text/kernels/basic_tokenizer_op.h"
|
||||
#include "dataset/text/kernels/bert_tokenizer_op.h"
|
||||
#include "dataset/text/kernels/case_fold_op.h"
|
||||
#include "dataset/text/kernels/normalize_utf8_op.h"
|
||||
#include "dataset/text/kernels/regex_replace_op.h"
|
||||
#include "dataset/text/kernels/regex_tokenizer_op.h"
|
||||
#include "dataset/text/kernels/unicode_script_tokenizer_op.h"
|
||||
#include "dataset/text/kernels/whitespace_tokenizer_op.h"
|
||||
#endif
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -143,51 +172,49 @@ void bindDatasetOps(py::module *m) {
|
|||
});
|
||||
|
||||
(void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
|
||||
.def_static("get_num_rows", [](const std::string &dir, int64_t numSamples, bool isCifar10) {
|
||||
.def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, numSamples, isCifar10, &count));
|
||||
THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
|
||||
return count;
|
||||
});
|
||||
|
||||
(void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
|
||||
.def_static("get_num_rows_and_classes", [](const std::string &path, int64_t numSamples) {
|
||||
.def_static("get_num_rows_and_classes", [](const std::string &path) {
|
||||
int64_t count = 0, num_classes = 0;
|
||||
THROW_IF_ERROR(
|
||||
ImageFolderOp::CountRowsAndClasses(path, numSamples, std::set<std::string>{}, &count, &num_classes));
|
||||
THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
|
||||
return py::make_tuple(count, num_classes);
|
||||
});
|
||||
|
||||
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
|
||||
.def_static("get_num_rows",
|
||||
[](const std::vector<std::string> &paths, bool load_dataset, const py::object &sampler) {
|
||||
int64_t count = 0;
|
||||
std::shared_ptr<mindrecord::ShardOperator> op;
|
||||
if (py::hasattr(sampler, "_create_for_minddataset")) {
|
||||
auto create = sampler.attr("_create_for_minddataset");
|
||||
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||
}
|
||||
THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count));
|
||||
return count;
|
||||
});
|
||||
.def_static("get_num_rows", [](const std::vector<std::string> &paths, bool load_dataset, const py::object &sampler,
|
||||
const int64_t num_padded) {
|
||||
int64_t count = 0;
|
||||
std::shared_ptr<mindrecord::ShardOperator> op;
|
||||
if (py::hasattr(sampler, "create_for_minddataset")) {
|
||||
auto create = sampler.attr("create_for_minddataset");
|
||||
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||
}
|
||||
THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded));
|
||||
return count;
|
||||
});
|
||||
|
||||
(void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
|
||||
.def_static("get_num_rows_and_classes",
|
||||
[](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
|
||||
[](const std::string &file, const py::dict &dict, const std::string &usage) {
|
||||
int64_t count = 0, num_classes = 0;
|
||||
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, numSamples, dict, usage, &count, &num_classes));
|
||||
THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
|
||||
return py::make_tuple(count, num_classes);
|
||||
})
|
||||
.def_static("get_class_indexing",
|
||||
[](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
|
||||
std::map<std::string, int32_t> output_class_indexing;
|
||||
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, numSamples, dict, usage, &output_class_indexing));
|
||||
return output_class_indexing;
|
||||
});
|
||||
.def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) {
|
||||
std::map<std::string, int32_t> output_class_indexing;
|
||||
THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
|
||||
return output_class_indexing;
|
||||
});
|
||||
|
||||
(void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
|
||||
.def_static("get_num_rows", [](const std::string &dir, int64_t numSamples) {
|
||||
.def_static("get_num_rows", [](const std::string &dir) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
|
||||
THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
|
||||
return count;
|
||||
});
|
||||
|
||||
|
@ -201,20 +228,44 @@ void bindDatasetOps(py::module *m) {
|
|||
THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
|
||||
return count;
|
||||
});
|
||||
|
||||
(void)py::class_<ClueOp, DatasetOp, std::shared_ptr<ClueOp>>(*m, "ClueOp")
|
||||
.def_static("get_num_rows", [](const py::list &files) {
|
||||
int64_t count = 0;
|
||||
std::vector<std::string> filenames;
|
||||
for (auto file : files) {
|
||||
file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
|
||||
}
|
||||
THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count));
|
||||
return count;
|
||||
});
|
||||
|
||||
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
|
||||
.def_static("get_num_rows",
|
||||
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||
const py::dict &dict, int64_t numSamples) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
|
||||
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
|
||||
return count;
|
||||
})
|
||||
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
|
||||
const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
|
||||
const std::string &task_mode, const py::dict &dict) {
|
||||
std::map<std::string, int32_t> output_class_indexing;
|
||||
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, numSamples, &output_class_indexing));
|
||||
THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
|
||||
return output_class_indexing;
|
||||
});
|
||||
(void)py::class_<CocoOp, DatasetOp, std::shared_ptr<CocoOp>>(*m, "CocoOp")
|
||||
.def_static("get_class_indexing",
|
||||
[](const std::string &dir, const std::string &file, const std::string &task) {
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
|
||||
THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing));
|
||||
return output_class_indexing;
|
||||
})
|
||||
.def_static("get_num_rows", [](const std::string &dir, const std::string &file, const std::string &task) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count));
|
||||
return count;
|
||||
});
|
||||
}
|
||||
void bindTensor(py::module *m) {
|
||||
(void)py::class_<GlobalContext>(*m, "GlobalContext")
|
||||
|
@ -227,12 +278,14 @@ void bindTensor(py::module *m) {
|
|||
.def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
|
||||
.def("set_op_connector_size", &ConfigManager::set_op_connector_size)
|
||||
.def("set_seed", &ConfigManager::set_seed)
|
||||
.def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval)
|
||||
.def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
|
||||
.def("get_num_parallel_workers", &ConfigManager::num_parallel_workers)
|
||||
.def("get_worker_connector_size", &ConfigManager::worker_connector_size)
|
||||
.def("get_op_connector_size", &ConfigManager::op_connector_size)
|
||||
.def("get_seed", &ConfigManager::seed)
|
||||
.def("load", [](ConfigManager &c, std::string s) { (void)c.LoadFile(s); });
|
||||
.def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
|
||||
.def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });
|
||||
|
||||
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
|
||||
.def(py::init([](py::array arr) {
|
||||
|
@ -300,6 +353,11 @@ void bindTensorOps1(py::module *m) {
|
|||
.def(py::init<std::vector<std::shared_ptr<TensorOp>>, int32_t>(), py::arg("operations"),
|
||||
py::arg("NumOps") = UniformAugOp::kDefNumOps);
|
||||
|
||||
(void)py::class_<BoundingBoxAugmentOp, TensorOp, std::shared_ptr<BoundingBoxAugmentOp>>(
|
||||
*m, "BoundingBoxAugmentOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.")
|
||||
.def(py::init<std::shared_ptr<TensorOp>, float>(), py::arg("transform"),
|
||||
py::arg("ratio") = BoundingBoxAugmentOp::kDefRatio);
|
||||
|
||||
(void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
|
||||
*m, "ResizeBilinearOp",
|
||||
"Tensor operation to resize an image using "
|
||||
|
@ -314,6 +372,11 @@ void bindTensorOps1(py::module *m) {
|
|||
(void)py::class_<RandomHorizontalFlipOp, TensorOp, std::shared_ptr<RandomHorizontalFlipOp>>(
|
||||
*m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.")
|
||||
.def(py::init<float>(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability);
|
||||
|
||||
(void)py::class_<RandomHorizontalFlipWithBBoxOp, TensorOp, std::shared_ptr<RandomHorizontalFlipWithBBoxOp>>(
|
||||
*m, "RandomHorizontalFlipWithBBoxOp",
|
||||
"Tensor operation to randomly flip an image horizontally, while flipping bounding boxes.")
|
||||
.def(py::init<float>(), py::arg("probability") = RandomHorizontalFlipWithBBoxOp::kDefProbability);
|
||||
}
|
||||
|
||||
void bindTensorOps2(py::module *m) {
|
||||
|
@ -321,6 +384,12 @@ void bindTensorOps2(py::module *m) {
|
|||
*m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.")
|
||||
.def(py::init<float>(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability);
|
||||
|
||||
(void)py::class_<RandomVerticalFlipWithBBoxOp, TensorOp, std::shared_ptr<RandomVerticalFlipWithBBoxOp>>(
|
||||
*m, "RandomVerticalFlipWithBBoxOp",
|
||||
"Tensor operation to randomly flip an image vertically"
|
||||
" and adjust bounding boxes.")
|
||||
.def(py::init<float>(), py::arg("probability") = RandomVerticalFlipWithBBoxOp::kDefProbability);
|
||||
|
||||
(void)py::class_<RandomCropOp, TensorOp, std::shared_ptr<RandomCropOp>>(*m, "RandomCropOp",
|
||||
"Gives random crop of specified size "
|
||||
"Takes crop size")
|
||||
|
@ -332,10 +401,84 @@ void bindTensorOps2(py::module *m) {
|
|||
py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB);
|
||||
(void)py::class_<HwcToChwOp, TensorOp, std::shared_ptr<HwcToChwOp>>(*m, "ChannelSwapOp").def(py::init<>());
|
||||
|
||||
(void)py::class_<RandomCropWithBBoxOp, TensorOp, std::shared_ptr<RandomCropWithBBoxOp>>(*m, "RandomCropWithBBoxOp",
|
||||
"Gives random crop of given "
|
||||
"size + adjusts bboxes "
|
||||
"Takes crop size")
|
||||
.def(py::init<int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, BorderType, bool, uint8_t, uint8_t, uint8_t>(),
|
||||
py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropWithBBoxOp::kDefPadTop,
|
||||
py::arg("padBottom") = RandomCropWithBBoxOp::kDefPadBottom,
|
||||
py::arg("padLeft") = RandomCropWithBBoxOp::kDefPadLeft,
|
||||
py::arg("padRight") = RandomCropWithBBoxOp::kDefPadRight,
|
||||
py::arg("borderType") = RandomCropWithBBoxOp::kDefBorderType,
|
||||
py::arg("padIfNeeded") = RandomCropWithBBoxOp::kDefPadIfNeeded,
|
||||
py::arg("fillR") = RandomCropWithBBoxOp::kDefFillR, py::arg("fillG") = RandomCropWithBBoxOp::kDefFillG,
|
||||
py::arg("fillB") = RandomCropWithBBoxOp::kDefFillB);
|
||||
|
||||
(void)py::class_<OneHotOp, TensorOp, std::shared_ptr<OneHotOp>>(
|
||||
*m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
|
||||
.def(py::init<int32_t>());
|
||||
|
||||
(void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(
|
||||
*m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.")
|
||||
.def(py::init<std::shared_ptr<Tensor>>());
|
||||
|
||||
(void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "Tensor slice operation.")
|
||||
.def(py::init<bool>())
|
||||
.def(py::init([](const py::list &py_list) {
|
||||
std::vector<dsize_t> c_list;
|
||||
for (auto l : py_list) {
|
||||
if (!l.is_none()) {
|
||||
c_list.push_back(py::reinterpret_borrow<py::int_>(l));
|
||||
}
|
||||
}
|
||||
return std::make_shared<SliceOp>(c_list);
|
||||
}))
|
||||
.def(py::init([](const py::tuple &py_slice) {
|
||||
if (py_slice.size() != 3) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
|
||||
}
|
||||
Slice c_slice;
|
||||
if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) {
|
||||
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]), py::reinterpret_borrow<py::int_>(py_slice[1]),
|
||||
py::reinterpret_borrow<py::int_>(py_slice[2]));
|
||||
} else if (py_slice[0].is_none() && py_slice[2].is_none()) {
|
||||
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[1]));
|
||||
} else if (!py_slice[0].is_none() && !py_slice[1].is_none()) {
|
||||
c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]), py::reinterpret_borrow<py::int_>(py_slice[1]));
|
||||
}
|
||||
|
||||
if (!c_slice.valid()) {
|
||||
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
|
||||
}
|
||||
return std::make_shared<SliceOp>(c_slice);
|
||||
}));
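// A minimal sketch of how the tuple-based init above is expected to behave, assuming the Python
// layer converts a slice object into the tuple (start, stop, step):
//   (1, 10, 2)      -> Slice(1, 10, 2)   // start, stop and step all given
//   (None, 5, None) -> Slice(5)          // stop only
//   (0, 5, None)    -> Slice(0, 5)       // start and stop
// Any other combination leaves the Slice invalid and "Wrong slice object" is raised.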
|
||||
|
||||
(void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic())
|
||||
.value("EQ", RelationalOp::kEqual)
|
||||
.value("NE", RelationalOp::kNotEqual)
|
||||
.value("LT", RelationalOp::kLess)
|
||||
.value("LE", RelationalOp::kLessEqual)
|
||||
.value("GT", RelationalOp::kGreater)
|
||||
.value("GE", RelationalOp::kGreaterEqual)
|
||||
.export_values();
|
||||
|
||||
(void)py::class_<MaskOp, TensorOp, std::shared_ptr<MaskOp>>(*m, "MaskOp",
|
||||
"Tensor mask operation using relational comparator")
|
||||
.def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>());
|
||||
|
||||
(void)py::class_<DuplicateOp, TensorOp, std::shared_ptr<DuplicateOp>>(*m, "DuplicateOp", "Duplicate tensor.")
|
||||
.def(py::init<>());
|
||||
|
||||
(void)py::class_<TruncateSequencePairOp, TensorOp, std::shared_ptr<TruncateSequencePairOp>>(
|
||||
*m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length")
|
||||
.def(py::init<int64_t>());
|
||||
|
||||
(void)py::class_<ConcatenateOp, TensorOp, std::shared_ptr<ConcatenateOp>>(*m, "ConcatenateOp",
|
||||
"Tensor operation concatenate tensors.")
|
||||
.def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>(), py::arg("axis"),
|
||||
py::arg("prepend").none(true), py::arg("append").none(true));
|
||||
|
||||
(void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
|
||||
*m, "RandomRotationOp",
|
||||
"Tensor operation to apply RandomRotation."
|
||||
|
@ -347,6 +490,10 @@ void bindTensorOps2(py::module *m) {
|
|||
py::arg("interpolation") = RandomRotationOp::kDefInterpolation,
|
||||
py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR,
|
||||
py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
|
||||
|
||||
(void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(
|
||||
*m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.")
|
||||
.def(py::init<TensorShape, std::shared_ptr<Tensor>>());
|
||||
}
|
||||
|
||||
void bindTensorOps3(py::module *m) {
|
||||
|
@ -364,6 +511,20 @@ void bindTensorOps3(py::module *m) {
|
|||
py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation,
|
||||
py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter);
|
||||
|
||||
(void)py::class_<RandomCropAndResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomCropAndResizeWithBBoxOp>>(
|
||||
*m, "RandomCropAndResizeWithBBoxOp",
|
||||
"Tensor operation to randomly crop an image (with BBoxes) and resize to a given size."
|
||||
"Takes output height and width and"
|
||||
"optional parameters for lower and upper bound for aspect ratio (h/w) and scale,"
|
||||
"interpolation mode, and max attempts to crop")
|
||||
.def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
|
||||
py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeWithBBoxOp::kDefScaleLb,
|
||||
py::arg("scaleUb") = RandomCropAndResizeWithBBoxOp::kDefScaleUb,
|
||||
py::arg("aspectLb") = RandomCropAndResizeWithBBoxOp::kDefAspectLb,
|
||||
py::arg("aspectUb") = RandomCropAndResizeWithBBoxOp::kDefAspectUb,
|
||||
py::arg("interpolation") = RandomCropAndResizeWithBBoxOp::kDefInterpolation,
|
||||
py::arg("maxIter") = RandomCropAndResizeWithBBoxOp::kDefMaxIter);
|
||||
|
||||
(void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>(
|
||||
*m, "RandomColorAdjustOp",
|
||||
"Tensor operation to adjust an image's color randomly."
|
||||
|
@ -418,9 +579,13 @@ void bindTensorOps4(py::module *m) {
|
|||
.def(py::init<int32_t, int32_t, int32_t, int32_t, BorderType, uint8_t, uint8_t, uint8_t>(), py::arg("padTop"),
|
||||
py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType,
|
||||
py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
|
||||
(void)py::class_<ToNumberOp, TensorOp, std::shared_ptr<ToNumberOp>>(*m, "ToNumberOp",
|
||||
"TensorOp to convert strings to numbers.")
|
||||
.def(py::init<DataType>(), py::arg("data_type"))
|
||||
.def(py::init<std::string>(), py::arg("data_type"));
|
||||
}
|
||||
|
||||
void bindTensorOps5(py::module *m) {
|
||||
void bindTokenizerOps(py::module *m) {
|
||||
(void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(*m, "JiebaTokenizerOp", "")
|
||||
.def(py::init<const std::string, std::string, JiebaMode>(), py::arg("hmm_path"), py::arg("mp_path"),
|
||||
py::arg("mode") = JiebaMode::kMix)
|
||||
|
@ -433,6 +598,60 @@ void bindTensorOps5(py::module *m) {
|
|||
"Tensor operation to LookUp each word")
|
||||
.def(py::init<std::shared_ptr<Vocab>, WordIdType>(), py::arg("vocab"), py::arg("unknown"))
|
||||
.def(py::init<std::shared_ptr<Vocab>>(), py::arg("vocab"));
|
||||
(void)py::class_<NgramOp, TensorOp, std::shared_ptr<NgramOp>>(*m, "NgramOp", "TensorOp performs ngram mapping")
|
||||
.def(py::init<const std::vector<int32_t> &, int32_t, int32_t, const std::string &, const std::string &,
|
||||
const std::string &>(),
|
||||
py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"), py::arg("r_pad_token"),
|
||||
py::arg("separator"));
|
||||
(void)py::class_<WordpieceTokenizerOp, TensorOp, std::shared_ptr<WordpieceTokenizerOp>>(
|
||||
*m, "WordpieceTokenizerOp", "Tokenize scalar token or 1-D tokens to subword tokens.")
|
||||
.def(py::init<const std::shared_ptr<Vocab> &, const std::string &, const int &, const std::string &>(),
|
||||
py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator),
|
||||
py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
|
||||
py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken));
|
||||
}
|
||||
|
||||
void bindDependIcuTokenizerOps(py::module *m) {
|
||||
#ifdef ENABLE_ICU4C
|
||||
(void)py::class_<WhitespaceTokenizerOp, TensorOp, std::shared_ptr<WhitespaceTokenizerOp>>(
|
||||
*m, "WhitespaceTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces.")
|
||||
.def(py::init<>());
|
||||
(void)py::class_<UnicodeScriptTokenizerOp, TensorOp, std::shared_ptr<UnicodeScriptTokenizerOp>>(
|
||||
*m, "UnicodeScriptTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.")
|
||||
.def(py::init<>())
|
||||
.def(py::init<bool>(), py::arg("keep_whitespace") = UnicodeScriptTokenizerOp::kDefKeepWhitespace);
|
||||
(void)py::class_<CaseFoldOp, TensorOp, std::shared_ptr<CaseFoldOp>>(
|
||||
*m, "CaseFoldOp", "Apply case fold operation on utf-8 string tensor")
|
||||
.def(py::init<>());
|
||||
(void)py::class_<NormalizeUTF8Op, TensorOp, std::shared_ptr<NormalizeUTF8Op>>(
|
||||
*m, "NormalizeUTF8Op", "Apply normalize operation on utf-8 string tensor.")
|
||||
.def(py::init<>())
|
||||
.def(py::init<NormalizeForm>(), py::arg("normalize_form") = NormalizeUTF8Op::kDefNormalizeForm);
|
||||
(void)py::class_<RegexReplaceOp, TensorOp, std::shared_ptr<RegexReplaceOp>>(
|
||||
*m, "RegexReplaceOp", "Replace utf-8 string tensor with 'replace' according to regular expression 'pattern'.")
|
||||
.def(py::init<const std::string &, const std::string &, bool>(), py::arg("pattern"), py::arg("replace"),
|
||||
py::arg("replace_all"));
|
||||
(void)py::class_<RegexTokenizerOp, TensorOp, std::shared_ptr<RegexTokenizerOp>>(
|
||||
*m, "RegexTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by regex expression pattern.")
|
||||
.def(py::init<const std::string &, const std::string &>(), py::arg("delim_pattern"), py::arg("keep_delim_pattern"));
|
||||
(void)py::class_<BasicTokenizerOp, TensorOp, std::shared_ptr<BasicTokenizerOp>>(
|
||||
*m, "BasicTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by specific rules.")
|
||||
.def(py::init<bool, bool, NormalizeForm, bool>(), py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase,
|
||||
py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace,
|
||||
py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm,
|
||||
py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken);
|
||||
(void)py::class_<BertTokenizerOp, TensorOp, std::shared_ptr<BertTokenizerOp>>(*m, "BertTokenizerOp",
|
||||
"Tokenizer used for Bert text process.")
|
||||
.def(py::init<const std::shared_ptr<Vocab> &, const std::string &, const int &, const std::string &, bool, bool,
|
||||
NormalizeForm, bool>(),
|
||||
py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator),
|
||||
py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
|
||||
py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
|
||||
py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase,
|
||||
py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace,
|
||||
py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm,
|
||||
py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken);
|
||||
#endif
|
||||
}
|
||||
|
||||
void bindSamplerOps(py::module *m) {
|
||||
|
@ -449,32 +668,29 @@ void bindSamplerOps(py::module *m) {
|
|||
.def("add_child",
|
||||
[](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
|
||||
|
||||
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
|
||||
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator")
|
||||
.def("add_child", [](std::shared_ptr<mindrecord::ShardOperator> self,
|
||||
std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
|
||||
|
||||
(void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
|
||||
.def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
|
||||
py::arg("seed"));
|
||||
.def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
|
||||
|
||||
(void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
|
||||
.def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
|
||||
.def(py::init<int64_t, int64_t, bool>());
|
||||
|
||||
(void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
|
||||
.def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
|
||||
py::arg("num_samples"))
|
||||
.def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
|
||||
.def(py::init<int64_t, bool, bool>());
|
||||
|
||||
(void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
|
||||
.def(py::init<>());
|
||||
|
||||
(void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
|
||||
.def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
|
||||
.def(py::init<int64_t, int64_t>());
|
||||
|
||||
(void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
|
||||
.def(py::init<std::vector<int64_t>>(), py::arg("indices"));
|
||||
.def(py::init<int64_t, std::vector<int64_t>>());
|
||||
|
||||
(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
|
||||
*m, "MindrecordSubsetRandomSampler")
|
||||
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
|
||||
|
||||
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
|
||||
*m, "MindrecordPkSampler")
|
||||
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
|
||||
|
@ -486,12 +702,27 @@ void bindSamplerOps(py::module *m) {
|
|||
}
|
||||
}));
|
||||
|
||||
(void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
|
||||
std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
|
||||
.def(py::init<int64_t, int64_t, bool, uint32_t>());
|
||||
|
||||
(void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
|
||||
*m, "MindrecordRandomSampler")
|
||||
.def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
|
||||
return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
|
||||
}));
|
||||
|
||||
(void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
|
||||
std::shared_ptr<mindrecord::ShardSequentialSample>>(*m, "MindrecordSequentialSampler")
|
||||
.def(py::init([](int num_samples, int start_index) {
|
||||
return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
|
||||
}));
|
||||
|
||||
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
|
||||
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
|
||||
py::arg("replacement"));
|
||||
.def(py::init<int64_t, std::vector<double>, bool>());
|
||||
|
||||
(void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
|
||||
.def(py::init<py::object>(), py::arg("pySampler"));
|
||||
.def(py::init<int64_t, py::object>());
|
||||
}
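// A short illustrative sketch of the lambda-based py::init factory used for the mindrecord
// samplers above; "MySampler" and its argument are hypothetical, but the pattern is plain
// pybind11: the lambda constructs the shared_ptr itself, which lets extra C++-side values
// (here a seed from GetSeed()) be injected without exposing them to Python.
//   (void)py::class_<MySampler, std::shared_ptr<MySampler>>(*m, "MySampler")
//     .def(py::init([](int64_t num_samples) {
//       return std::make_shared<MySampler>(GetSeed(), num_samples);
//     }));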
|
||||
|
||||
void bindInfoObjects(py::module *m) {
|
||||
|
@ -503,16 +734,18 @@ void bindInfoObjects(py::module *m) {
|
|||
|
||||
void bindVocabObjects(py::module *m) {
|
||||
(void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
|
||||
.def(py::init<>())
|
||||
.def_static("from_list",
|
||||
[](const py::list &words) {
|
||||
[](const py::list &words, const py::list &special_tokens, bool special_first) {
|
||||
std::shared_ptr<Vocab> v;
|
||||
THROW_IF_ERROR(Vocab::BuildFromPyList(words, &v));
|
||||
THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v));
|
||||
return v;
|
||||
})
|
||||
.def_static("from_file",
|
||||
[](const std::string &path, const std::string &dlm, int32_t vocab_size) {
|
||||
[](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens,
|
||||
bool special_first) {
|
||||
std::shared_ptr<Vocab> v;
|
||||
THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, &v));
|
||||
THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v));
|
||||
return v;
|
||||
})
|
||||
.def_static("from_dict", [](const py::dict &words) {
|
||||
|
@ -529,10 +762,22 @@ void bindGraphData(py::module *m) {
|
|||
THROW_IF_ERROR(g_out->Init());
|
||||
return g_out;
|
||||
}))
|
||||
.def("get_nodes",
|
||||
[](gnn::Graph &g, gnn::NodeType node_type, gnn::NodeIdType node_num) {
|
||||
.def("get_all_nodes",
|
||||
[](gnn::Graph &g, gnn::NodeType node_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetNodes(node_type, node_num, &out));
|
||||
THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_all_edges",
|
||||
[](gnn::Graph &g, gnn::EdgeType edge_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_nodes_from_edges",
|
||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_all_neighbors",
|
||||
|
@ -541,12 +786,38 @@ void bindGraphData(py::module *m) {
|
|||
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_sampled_neighbors",
|
||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
|
||||
std::vector<gnn::NodeType> neighbor_types) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_neg_sampled_neighbors",
|
||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
|
||||
gnn::NodeType neg_neighbor_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_node_feature",
|
||||
[](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
|
||||
TensorRow out;
|
||||
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
|
||||
return out.getRow();
|
||||
})
|
||||
.def("graph_info",
|
||||
[](gnn::Graph &g) {
|
||||
py::dict out;
|
||||
THROW_IF_ERROR(g.GraphInfo(&out));
|
||||
return out;
|
||||
});
|
||||
})
|
||||
.def("random_walk", [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
|
||||
float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
|
||||
return out;
|
||||
});
|
||||
}
|
||||
|
||||
// This is where we externalize the C logic as python modules
|
||||
|
@ -555,9 +826,9 @@ PYBIND11_MODULE(_c_dataengine, m) {
|
|||
(void)py::class_<DatasetOp, std::shared_ptr<DatasetOp>>(m, "DatasetOp");
|
||||
|
||||
(void)py::enum_<OpName>(m, "OpName", py::arithmetic())
|
||||
.value("STORAGE", OpName::kStorage)
|
||||
.value("SHUFFLE", OpName::kShuffle)
|
||||
.value("BATCH", OpName::kBatch)
|
||||
.value("BUCKETBATCH", OpName::kBucketBatch)
|
||||
.value("BARRIER", OpName::kBarrier)
|
||||
.value("MINDRECORD", OpName::kMindrecord)
|
||||
.value("CACHE", OpName::kCache)
|
||||
|
@ -578,11 +849,14 @@ PYBIND11_MODULE(_c_dataengine, m) {
|
|||
.value("MNIST", OpName::kMnist)
|
||||
.value("MANIFEST", OpName::kManifest)
|
||||
.value("VOC", OpName::kVoc)
|
||||
.value("COCO", OpName::kCoco)
|
||||
.value("CIFAR10", OpName::kCifar10)
|
||||
.value("CIFAR100", OpName::kCifar100)
|
||||
.value("RANDOMDATA", OpName::kRandomData)
|
||||
.value("BUILDVOCAB", OpName::kBuildVocab)
|
||||
.value("CELEBA", OpName::kCelebA)
|
||||
.value("TEXTFILE", OpName::kTextFile);
|
||||
.value("TEXTFILE", OpName::kTextFile)
|
||||
.value("CLUE", OpName::kClue);
|
||||
|
||||
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
|
||||
.value("DE_JIEBA_MIX", JiebaMode::kMix)
|
||||
|
@ -590,6 +864,16 @@ PYBIND11_MODULE(_c_dataengine, m) {
|
|||
.value("DE_JIEBA_HMM", JiebaMode::kHmm)
|
||||
.export_values();
|
||||
|
||||
#ifdef ENABLE_ICU4C
|
||||
(void)py::enum_<NormalizeForm>(m, "NormalizeForm", py::arithmetic())
|
||||
.value("DE_NORMALIZE_NONE", NormalizeForm::kNone)
|
||||
.value("DE_NORMALIZE_NFC", NormalizeForm::kNfc)
|
||||
.value("DE_NORMALIZE_NFKC", NormalizeForm::kNfkc)
|
||||
.value("DE_NORMALIZE_NFD", NormalizeForm::kNfd)
|
||||
.value("DE_NORMALIZE_NFKD", NormalizeForm::kNfkd)
|
||||
.export_values();
|
||||
#endif
|
||||
|
||||
(void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())
|
||||
.value("DE_INTER_LINEAR", InterpolationMode::kLinear)
|
||||
.value("DE_INTER_CUBIC", InterpolationMode::kCubic)
|
||||
|
@ -609,12 +893,13 @@ PYBIND11_MODULE(_c_dataengine, m) {
|
|||
bindTensorOps2(&m);
|
||||
bindTensorOps3(&m);
|
||||
bindTensorOps4(&m);
|
||||
bindTensorOps5(&m);
|
||||
bindTokenizerOps(&m);
|
||||
bindSamplerOps(&m);
|
||||
bindDatasetOps(&m);
|
||||
bindInfoObjects(&m);
|
||||
bindVocabObjects(&m);
|
||||
bindGraphData(&m);
|
||||
bindDependIcuTokenizerOps(&m);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -11,6 +11,7 @@ add_library(core OBJECT
|
|||
data_type.cc
|
||||
global_context.cc
|
||||
tensor.cc
|
||||
tensor_row.cc
|
||||
tensor_shape.cc
|
||||
)
|
||||
add_dependencies(core mindspore::protobuf)
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "dataset/engine/dataset_iterator.h"
|
||||
#include "dataset/engine/datasetops/barrier_op.h"
|
||||
#include "dataset/engine/datasetops/batch_op.h"
|
||||
#include "dataset/engine/datasetops/build_vocab_op.h"
|
||||
#include "dataset/engine/datasetops/dataset_op.h"
|
||||
#include "dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "dataset/engine/datasetops/map_op.h"
|
||||
|
@ -38,7 +39,6 @@
|
|||
#include "dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "dataset/engine/datasetops/source/generator_op.h"
|
||||
#include "dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#include "dataset/engine/datasetops/source/storage_op.h"
|
||||
#include "dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "dataset/engine/datasetops/take_op.h"
|
||||
#include "dataset/engine/datasetops/zip_op.h"
|
||||
|
|
|
@ -48,7 +48,7 @@ Status ConfigManager::FromJson(const nlohmann::json &j) {
|
|||
Status ConfigManager::LoadFile(const std::string &settingsFile) {
|
||||
Status rc;
|
||||
if (!Path(settingsFile).Exists()) {
|
||||
RETURN_STATUS_UNEXPECTED("File is not found");
|
||||
RETURN_STATUS_UNEXPECTED("File is not found.");
|
||||
}
|
||||
// Some settings are mandatory, others are not (with default). If a setting
|
||||
// is optional it will set a default value if the config is missing from the file.
|
||||
|
@ -59,14 +59,11 @@ Status ConfigManager::LoadFile(const std::string &settingsFile) {
|
|||
rc = FromJson(js);
|
||||
} catch (const nlohmann::json::type_error &e) {
|
||||
std::ostringstream ss;
|
||||
ss << "Client settings failed to load:\n" << e.what();
|
||||
ss << "Client file failed to load:\n" << e.what();
|
||||
std::string err_msg = ss.str();
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
} catch (const std::exception &err) {
|
||||
std::ostringstream ss;
|
||||
ss << "Client settings failed to load:\n" << err.what();
|
||||
std::string err_msg = ss.str();
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
RETURN_STATUS_UNEXPECTED("Client file failed to load.");
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
@ -88,5 +85,7 @@ void ConfigManager::set_op_connector_size(int32_t connector_size) { op_connector
|
|||
uint32_t ConfigManager::seed() const { return seed_; }
|
||||
|
||||
void ConfigManager::set_seed(uint32_t seed) { seed_ = seed; }
|
||||
|
||||
void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_sampling_interval_ = interval; }
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -111,12 +111,21 @@ class ConfigManager {
|
|||
// @param seed - The default seed to use
|
||||
void set_seed(uint32_t seed);
|
||||
|
||||
// setter function
|
||||
// @param interval - The setting to apply to the config
|
||||
void set_monitor_sampling_interval(uint32_t interval);
|
||||
|
||||
// getter function
// @return The interval of monitor sampling
int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; }
|
||||
|
||||
private:
|
||||
int32_t rows_per_buffer_{kCfgRowsPerBuffer};
|
||||
int32_t num_parallel_workers_{kCfgParallelWorkers};
|
||||
int32_t worker_connector_size_{kCfgWorkerConnectorSize};
|
||||
int32_t op_connector_size_{kCfgOpConnectorSize};
|
||||
uint32_t seed_{kCfgDefaultSeed};
|
||||
uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval};
|
||||
|
||||
// Private helper function that takes an nlohmann json object and populates the settings
|
||||
// @param j - The json nlohmann json info
|
||||
|
|
|
@ -47,9 +47,13 @@ constexpr uint32_t kCfgParallelWorkers = 4;
|
|||
constexpr uint32_t kCfgWorkerConnectorSize = 16;
|
||||
constexpr uint32_t kCfgOpConnectorSize = 16;
|
||||
constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed;
|
||||
constexpr uint32_t kCfgMonitorSamplingInterval = 10;
|
||||
|
||||
// An invalid OpenCV type must not fall in the valid range 0 to 7 (opencv4/opencv2/core/hal/interface.h)
|
||||
constexpr uint8_t kCVInvalidType = 255;
|
||||
|
||||
using connection_id_type = int64_t;
|
||||
using row_id_type = int64_t;
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -138,7 +138,7 @@ DataType DataType::FromNpArray(const py::array &arr) {
|
|||
return DataType(DataType::DE_FLOAT32);
|
||||
} else if (py::isinstance<py::array_t<std::double_t>>(arr)) {
|
||||
return DataType(DataType::DE_FLOAT64);
|
||||
} else if (arr.dtype().kind() == 'S') {
|
||||
} else if (arr.dtype().kind() == 'S' || arr.dtype().kind() == 'U') {
|
||||
return DataType(DataType::DE_STRING);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Cannot convert from numpy type. Unknown data type is returned!";
|
||||
|
|
|
@ -128,7 +128,9 @@ class DataType {
|
|||
// @tparam T
|
||||
// @return true or false
|
||||
template <typename T>
|
||||
bool IsCompatible() const;
|
||||
bool IsCompatible() const {
|
||||
return type_ == FromCType<T>();
|
||||
}
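// A brief usage sketch of the rewritten IsCompatible/FromCType pair (the DataType values are
// taken from this header; the surrounding calls are only illustrative):
//   DataType t(DataType::DE_INT32);
//   t.IsCompatible<int32_t>();  // true:  FromCType<int32_t>() yields DE_INT32
//   t.IsCompatible<float>();    // false: FromCType<float>()   yields DE_FLOAT32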
|
||||
|
||||
// returns true if the template type is the same as the Tensor type_
|
||||
// @tparam T
|
||||
|
@ -146,6 +148,9 @@ class DataType {
|
|||
return out;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static DataType FromCType();
|
||||
|
||||
// Convert from DataType to Pybind type
|
||||
// @return
|
||||
py::dtype AsNumpyType() const;
|
||||
|
@ -191,68 +196,68 @@ class DataType {
|
|||
};
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<bool>() const {
|
||||
return type_ == DataType::DE_BOOL;
|
||||
inline DataType DataType::FromCType<bool>() {
|
||||
return DataType(DataType::DE_BOOL);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<double>() const {
|
||||
return type_ == DataType::DE_FLOAT64;
|
||||
inline DataType DataType::FromCType<double>() {
|
||||
return DataType(DataType::DE_FLOAT64);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<float>() const {
|
||||
return type_ == DataType::DE_FLOAT32;
|
||||
inline DataType DataType::FromCType<float>() {
|
||||
return DataType(DataType::DE_FLOAT32);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<float16>() const {
|
||||
return type_ == DataType::DE_FLOAT16;
|
||||
inline DataType DataType::FromCType<float16>() {
|
||||
return DataType(DataType::DE_FLOAT16);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<int64_t>() const {
|
||||
return type_ == DataType::DE_INT64;
|
||||
inline DataType DataType::FromCType<int64_t>() {
|
||||
return DataType(DataType::DE_INT64);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<uint64_t>() const {
|
||||
return type_ == DataType::DE_UINT64;
|
||||
inline DataType DataType::FromCType<uint64_t>() {
|
||||
return DataType(DataType::DE_UINT64);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<int32_t>() const {
|
||||
return type_ == DataType::DE_INT32;
|
||||
inline DataType DataType::FromCType<int32_t>() {
|
||||
return DataType(DataType::DE_INT32);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<uint32_t>() const {
|
||||
return type_ == DataType::DE_UINT32;
|
||||
inline DataType DataType::FromCType<uint32_t>() {
|
||||
return DataType(DataType::DE_UINT32);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<int16_t>() const {
|
||||
return type_ == DataType::DE_INT16;
|
||||
inline DataType DataType::FromCType<int16_t>() {
|
||||
return DataType(DataType::DE_INT16);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<uint16_t>() const {
|
||||
return type_ == DataType::DE_UINT16;
|
||||
inline DataType DataType::FromCType<uint16_t>() {
|
||||
return DataType(DataType::DE_UINT16);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<int8_t>() const {
|
||||
return type_ == DataType::DE_INT8;
|
||||
inline DataType DataType::FromCType<int8_t>() {
|
||||
return DataType(DataType::DE_INT8);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<uint8_t>() const {
|
||||
return type_ == DataType::DE_UINT8;
|
||||
inline DataType DataType::FromCType<uint8_t>() {
|
||||
return DataType(DataType::DE_UINT8);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool DataType::IsCompatible<std::string_view>() const {
|
||||
return type_ == DataType::DE_STRING;
|
||||
inline DataType DataType::FromCType<std::string_view>() {
|
||||
return DataType(DataType::DE_STRING);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
@ -229,7 +230,12 @@ Status Tensor::CreateTensorFromNumpyString(std::shared_ptr<Tensor> *ptr, py::arr
|
|||
}
|
||||
arr.resize({arr.size()}); // flatten the py::array so we can iterate once
|
||||
std::vector<std::string> strings;
|
||||
std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast<py::bytes>(s)); });
|
||||
|
||||
if (arr.dtype().kind() == 'U') {
|
||||
std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast<py::str>(s)); });
|
||||
} else {
|
||||
std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast<py::bytes>(s)); });
|
||||
}
|
||||
|
||||
arr.resize(shape); // resize arr back to the original shape
|
||||
|
||||
|
@ -306,6 +312,50 @@ Status Tensor::CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::Byte
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Tensor::CreateTensor(std::shared_ptr<Tensor> *ptr, const std::string &file_path) {
|
||||
std::ifstream fs;
|
||||
fs.open(file_path, std::ios::binary | std::ios::in);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Failed to open file: " + file_path);
int64_t num_bytes = fs.seekg(0, std::ios::end).tellg();
CHECK_FAIL_RETURN_UNEXPECTED(fs.seekg(0, std::ios::beg).good(), "Failed to find size of file");
RETURN_IF_NOT_OK(
  Tensor::CreateTensor(ptr, TensorImpl::kFlexible, TensorShape{num_bytes}, DataType(DataType::DE_UINT8)));
int64_t read_bytes = fs.read(reinterpret_cast<char *>((*ptr)->GetMutableBuffer()), num_bytes).gcount();
CHECK_FAIL_RETURN_UNEXPECTED(read_bytes == num_bytes && fs.good(), "Error reading file into tensor");
|
||||
fs.close();
|
||||
return Status::OK();
|
||||
}
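// A hedged usage sketch for the file-reading factory above; the path is made up for
// illustration. The whole file is loaded as a flat DE_UINT8 tensor of shape {file size}:
//   std::shared_ptr<Tensor> raw;
//   RETURN_IF_NOT_OK(Tensor::CreateTensor(&raw, "/path/to/image.jpg"));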
|
||||
|
||||
Status Tensor::CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::BytesList &bytes_list,
|
||||
const TensorShape &shape, const DataType &type, dsize_t pad_size) {
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(ptr, TensorImpl::kFlexible, shape, type));
|
||||
|
||||
unsigned char *current_tensor_addr = (*ptr)->GetMutableBuffer();
|
||||
int64_t tensor_bytes_remaining = bytes_list.value_size() * pad_size;
|
||||
|
||||
for (int i = 0; i < bytes_list.value_size(); i++) {
|
||||
// read string data into tensor
|
||||
const std::string ¤t_element = bytes_list.value(i);
|
||||
int return_code =
|
||||
memcpy_s(current_tensor_addr, tensor_bytes_remaining, common::SafeCStr(current_element), current_element.size());
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed when reading bytesList element into Tensor");
|
||||
|
||||
current_tensor_addr += current_element.size();
|
||||
tensor_bytes_remaining -= current_element.size();
|
||||
|
||||
// pad
|
||||
int64_t chars_to_pad = pad_size - current_element.size();
|
||||
return_code = memset_s(current_tensor_addr, tensor_bytes_remaining, static_cast<int>(' '), chars_to_pad);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed when padding Tensor");
|
||||
|
||||
current_tensor_addr += chars_to_pad;
|
||||
tensor_bytes_remaining -= chars_to_pad;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Memcpy the given strided array's used part to consecutive memory
|
||||
// Consider a 3-d array
|
||||
// A[(i * shape[1] + j) * shape[2] + k] = B[i][j][k] = C[i * strides[0] + j * strides[1] + k * strides[2]]
|
||||
|
@ -539,11 +589,13 @@ Status Tensor::StartAddrOfIndex(std::vector<dsize_t> ind, uchar **start_addr_of_
|
|||
if (type() == DataType::DE_STRING) {
|
||||
RETURN_STATUS_UNEXPECTED("StartAddrOfIndex does not support string tensors yet.");
|
||||
}
|
||||
|
||||
dsize_t flat_ind;
|
||||
std::vector<dsize_t> t_shape = shape().AsVector();
|
||||
std::vector<dsize_t> r(t_shape.begin() + ind.size(), t_shape.end());
|
||||
*remaining = TensorShape(r);
|
||||
ind.resize(this->Rank(), 0); // same as -> while (ind.size() < this->Rank()) ind.push_back(0);
|
||||
|
||||
RETURN_IF_NOT_OK(shape_.ToFlatIndex(ind, &flat_ind));
|
||||
// check if GetBuffer() returns null; we should flag this as an error. This sanity check will only
// be true if the tensor failed to allocate memory.
|
||||
|
@ -584,6 +636,39 @@ Status Tensor::InsertTensor(const std::vector<dsize_t> &ind, const std::shared_p
|
|||
}
|
||||
}
|
||||
|
||||
Status Tensor::Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &tensor) {
|
||||
std::string err_msg;
|
||||
err_msg += (index.size() != 1) ? "[Tensor] only supports 1d concatenation \n" : "";
|
||||
err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot concatenate tensors of type string\n" : "";
|
||||
err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : "";
|
||||
|
||||
err_msg +=
|
||||
(index.at(0) + tensor->shape().NumOfElements() > this->shape().NumOfElements()) ? "[Tensor] incorrect index\n" : "";
|
||||
err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : "";
|
||||
uchar *start_addr_of_ind = nullptr;
|
||||
|
||||
TensorShape remaining_shape = tensor->shape();
|
||||
StartAddrOfIndex(index, &start_addr_of_ind, &remaining_shape);
|
||||
err_msg += (start_addr_of_ind == nullptr) ? "Failed to create memory for Tensor.\n" : "";
|
||||
|
||||
if (!err_msg.empty()) {
|
||||
MS_LOG(DEBUG) << "Insert tensor message: " << err_msg;
|
||||
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
} else {
|
||||
int ret_code =
|
||||
memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes());
|
||||
|
||||
if (ret_code == 0) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
err_msg += "[Tensor] error in memcpy_s when inserting tensor\n";
|
||||
MS_LOG(DEBUG) << "Tensor message: " << err_msg;
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status Tensor::ExpandDim(const dsize_t &axis) {
|
||||
if (axis > Rank()) {
|
||||
std::string err = "Axis is out of bound";
|
||||
|
@ -649,7 +734,7 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const {
|
|||
Status Tensor::GetItemAt(std::string_view *o, const std::vector<dsize_t> &index) const {
|
||||
RETURN_UNEXPECTED_IF_NULL(data_);
|
||||
RETURN_UNEXPECTED_IF_NULL(o);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Type is not DE_STRING");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Tensor type is not a string");
|
||||
|
||||
uchar *start = nullptr;
|
||||
offset_t length = 0;
|
||||
|
@ -699,6 +784,8 @@ Status Tensor::GetDataAsNumpyStrings(py::array *data) {
|
|||
for (; itr != end<std::string_view>(); itr++) {
|
||||
max = std::max((*itr).length(), max);
|
||||
}
|
||||
// if all strings are empty, numpy stores a byte for each string |S1
|
||||
max = (max == 0 ? 1 : max);
|
||||
uint64_t total_size = shape_.NumOfElements() * max;
|
||||
char *tmp_data = reinterpret_cast<char *>(data_allocator_->allocate(total_size));
|
||||
if (tmp_data == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create temp array.");
|
||||
|
@ -708,8 +795,10 @@ Status Tensor::GetDataAsNumpyStrings(py::array *data) {
|
|||
itr = begin<std::string_view>();
|
||||
uint64_t i = 0;
|
||||
for (; itr != end<std::string_view>(); itr++, i++) {
|
||||
ret_code = memcpy_s(tmp_data + i * max, total_size, (*itr).data(), (*itr).length());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy string data.");
|
||||
if (!(*itr).empty()) {
|
||||
ret_code = memcpy_s(tmp_data + i * max, total_size, (*itr).data(), (*itr).length());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy string data.");
|
||||
}
|
||||
}
|
||||
auto strides = shape_.Strides();
|
||||
std::transform(strides.begin(), strides.end(), strides.begin(), [&max](const auto &s) { return s * max; });
|
||||
|
@ -847,6 +936,78 @@ Status Tensor::GetStringAt(dsize_t index, uchar **string_start, offset_t *length
|
|||
*length = offset_ptr[index + 1] - start - 1; // -1 to skip the \0 from the string length
|
||||
return Status::OK();
|
||||
}
|
||||
Status Tensor::CopyLastDimAt(const std::shared_ptr<Tensor> &src, const std::vector<dsize_t> &index) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src->type() == type_, "Source Tensor has a different type");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(index.back() == 0, "Last dim in index should be 0");
|
||||
|
||||
uint8_t type_size = type_.SizeInBytes();
|
||||
size_t len = std::min(src->shape()[-1], shape_[-1]) * type_size;
|
||||
dsize_t src_flat_ind = 0, dst_flat_ind = 0;
|
||||
RETURN_IF_NOT_OK(src->shape().ToFlatIndex(index, &src_flat_ind));
|
||||
RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &dst_flat_ind));
|
||||
|
||||
const unsigned char *src_addr = src->GetBuffer() + src_flat_ind * type_size;
|
||||
unsigned char *dst_addr = GetMutableBuffer() + dst_flat_ind * type_size;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error");
|
||||
return Status::OK();
|
||||
}
|
||||
Status Tensor::Slice(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shape_.Rank() == 1, "Currently Slice works with rank-1 tensors only.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!indices.empty(), "Indices are empty, generated tensor would be empty.");
|
||||
if (type_.IsNumeric()) {
|
||||
return SliceNumeric(out, indices);
|
||||
} else {
|
||||
return SliceString(out, indices);
|
||||
}
|
||||
}
|
||||
Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) {
|
||||
RETURN_IF_NOT_OK(
|
||||
CreateTensor(out, TensorImpl::kFlexible, TensorShape({static_cast<dsize_t>(indices.size())}), type_));
|
||||
(*out)->GetMutableBuffer();
|
||||
dsize_t out_index = 0;
|
||||
dsize_t dim_length = shape_[0];
|
||||
dsize_t type_size = type_.SizeInBytes();
|
||||
dsize_t src_start = HandleNeg(indices[0], dim_length);
|
||||
uchar *dst_addr = (*out)->data_;
|
||||
dsize_t count = 1;
|
||||
|
||||
for (dsize_t i = 0; i < indices.size(); i++) {
|
||||
dsize_t cur_index = HandleNeg(indices[i], dim_length);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
cur_index >= 0 && cur_index < dim_length,
|
||||
"Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")");
|
||||
if (i < indices.size() - 1) {
|
||||
dsize_t next_index = HandleNeg(indices[i + 1], dim_length);
|
||||
if (next_index == cur_index + 1) {
|
||||
count++;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size,
|
||||
count * type_size);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric");
|
||||
out_index += count;
|
||||
if (i < indices.size() - 1) {
|
||||
src_start = HandleNeg(indices[i + 1], dim_length); // next index
|
||||
}
|
||||
count = 1;
|
||||
}
|
||||
return Status::OK();
|
||||
}
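// Worked example of the run-coalescing loop above: for a rank-1 tensor of length 10 and
// indices {1, 2, 3, 7}, the first three indices are consecutive, so one memcpy_s of
// 3 * type_size bytes is issued starting at element 1, followed by a second memcpy_s of a
// single element starting at element 7.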
|
||||
Status Tensor::SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) {
|
||||
dsize_t dim_length = shape_[0];
|
||||
std::vector<std::string> strings;
|
||||
for (dsize_t index : indices) {
|
||||
dsize_t cur_index = HandleNeg(index, dim_length);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
cur_index >= 0 && cur_index < dim_length,
|
||||
"Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")");
|
||||
std::string_view sv;
|
||||
GetItemAt(&sv, {cur_index});
|
||||
strings.emplace_back(sv);
|
||||
}
|
||||
return CreateTensor(out, strings);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,9 +44,6 @@ class Tensor;
|
|||
|
||||
using CharAllocPtr = std::unique_ptr<Allocator<unsigned char>>;
|
||||
using TensorAllocPtr = std::shared_ptr<Allocator<Tensor>>; // An allocator shared_ptr for Tensors
|
||||
using TensorRow = std::vector<std::shared_ptr<Tensor>>; // A row is a set of Tensor pointers
|
||||
using TensorTable = std::vector<TensorRow>; // The table of tensors is a vector of rows
|
||||
using TensorQTable = std::deque<TensorRow>; // A different flavour of tensor table, this one has queue functionality
|
||||
|
||||
class Tensor {
|
||||
public:
|
||||
|
@ -118,6 +115,16 @@ class Tensor {
|
|||
static Status CreateTensor(std::shared_ptr<Tensor> *, TensorImpl tensor_impl, const TensorShape &shape, DataType type,
|
||||
const unsigned char *data = nullptr);
|
||||
|
||||
/// Create a copy of the input tensor
|
||||
/// \param out [out] output tensor to be generated
|
||||
/// \param in [in] original tensor to be copied
|
||||
/// \return Status
|
||||
static Status CreateTensor(std::shared_ptr<Tensor> *out, const std::shared_ptr<Tensor> &in) {
|
||||
const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator();
|
||||
*out = std::allocate_shared<Tensor>(*alloc, in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A static factory method to create a Tensor from a given py::array.
|
||||
// @param ptr output argument to hold the created Tensor
|
||||
// @param arr py::array
|
||||
|
@ -135,9 +142,41 @@ class Tensor {
|
|||
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::vector<std::string> &strings,
|
||||
const TensorShape &shape = TensorShape::CreateUnknownRankShape());
|
||||
|
||||
// create tensor from protobuf bytelist with strings
|
||||
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::BytesList &bytes_list,
|
||||
const TensorShape &shape);
|
||||
|
||||
// A static factory method to create a Tensor from a given list of numbers.
|
||||
// @param ptr output argument to hold the created Tensor
|
||||
// @param items elements of the tensor
|
||||
// @param shape shape of the tensor
|
||||
// @return Status Code
|
||||
template <typename T>
|
||||
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::vector<T> &items,
|
||||
const TensorShape &shape_req = TensorShape::CreateUnknownRankShape()) {
|
||||
DataType type = DataType::FromCType<T>();
|
||||
auto items_ptr = reinterpret_cast<const uchar *>(&items[0]);
|
||||
TensorShape shape = shape_req;
|
||||
if (!shape.known()) {
|
||||
shape = TensorShape({static_cast<dsize_t>(items.size())});
|
||||
}
|
||||
return CreateTensor(ptr, TensorImpl::kFlexible, shape, type, items_ptr);
|
||||
}
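// A minimal usage sketch for the vector-based factory above (values are illustrative only):
//   std::shared_ptr<Tensor> t;
//   RETURN_IF_NOT_OK(Tensor::CreateTensor(&t, std::vector<int32_t>{1, 2, 3}));
//   // with the default unknown shape_req, t becomes a 1-D DE_INT32 tensor of shape {3}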
|
||||
|
||||
// A static factory method to create a Tensor from a given number.
|
||||
// @param ptr output argument to hold the created Tensor
|
||||
// @param item value
|
||||
// @return Status Code
|
||||
template <typename T>
|
||||
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const T &item) {
|
||||
return CreateTensor<T>(ptr, {item}, TensorShape::CreateScalar());
|
||||
}
|
||||
// Create tensor from protobuf bytelist with uint8 or int8 types
|
||||
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::BytesList &bytes_list,
|
||||
const TensorShape &shape, const DataType &type, dsize_t pad_size);
|
||||
|
||||
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::string &path);
|
||||
|
||||
// Copy raw data of a array based on shape and strides to the destination pointer
|
||||
// @param dst Pointer to the destination array where the content is to be copied
|
||||
// @param src Pointer to the source of strided array to be copied
|
||||
|
@ -260,11 +299,6 @@ class Tensor {
|
|||
// @return const unsigned char*
|
||||
const unsigned char *GetBuffer() const;
|
||||
|
||||
// Get the starting memory address for the data of the tensor. This potentially
|
||||
// drives an allocation if the data area.
|
||||
// @return unsigned char*
|
||||
unsigned char *GetMutableBuffer();
|
||||
|
||||
// Getter of the type
|
||||
// @return
|
||||
DataType type() const { return type_; }
|
||||
|
@ -323,6 +357,22 @@ class Tensor {
|
|||
return ss.str();
|
||||
}
|
||||
|
||||
// Handle negative indices.
|
||||
static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; }
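// For example, HandleNeg(-1, 5) == 4 and HandleNeg(2, 5) == 2; callers such as SliceNumeric
// and SliceString are still expected to range-check the result.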
|
||||
|
||||
// Slice tensor based on the given indices. Copy the sliced data into the out tensor. Only rank-1 tensors are supported.
|
||||
// Based on the type of tensor, SliceNumeric or SliceString will be called
|
||||
// @param out Tensor
|
||||
// @param indices vector of indices
|
||||
// @return Status error code
|
||||
Status Slice(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
|
||||
|
||||
// Slice numeric tensors.
|
||||
Status SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
|
||||
|
||||
// Slice string tensors
|
||||
Status SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
|
||||
|
||||
// Constructs numpy array from input tensor
|
||||
// @param data this data is the location of python data
|
||||
// @return Status code
|
||||
|
@ -332,6 +382,9 @@ class Tensor {
|
|||
|
||||
static Status GetBufferInfo(Tensor &t, py::buffer_info *out);
|
||||
|
||||
// Concatenate the given tensor into this one at the given index; unlike InsertTensor, the input may be smaller than the remaining space.
|
||||
Status Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input);
|
||||
|
||||
// TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor
|
||||
// The order of elements follows the memory layout (i.e., row-major): [[1,2,3],[4,5,6]] --> 1,2,3,4,5,6
|
||||
// @tparam T type of values in the Tensor Iterator
|
||||
|
@ -518,6 +571,7 @@ class Tensor {
|
|||
// @return TensorIterator
|
||||
template <typename T>
|
||||
TensorIterator<T> begin() {
|
||||
AllocateBuffer(SizeInBytes());
|
||||
return TensorIterator<T>(data_);
|
||||
}
|
||||
|
||||
|
@ -529,7 +583,18 @@ class Tensor {
|
|||
return TensorIterator<T>(data_end_);
|
||||
}
|
||||
|
||||
// Copies the last dimension at `index` from Tensor `src` to this Tensor.
|
||||
// @param src Tensor
|
||||
// @param index vector to the start of the dimension. The last dim should be 0
|
||||
// @return Status
|
||||
Status CopyLastDimAt(const std::shared_ptr<Tensor> &src, const std::vector<dsize_t> &index);
|
||||
|
||||
protected:
|
||||
// Get the starting memory address for the data of the tensor. This potentially
|
||||
// drives an allocation if the data is null.
|
||||
// @return unsigned char*
|
||||
unsigned char *GetMutableBuffer();
|
||||
|
||||
// A function that prints Tensor recursively, first called by print
|
||||
// @param out
|
||||
// @param cur_dim
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "dataset/core/tensor_row.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
TensorRow::TensorRow() noexcept : id_(kDefaultRowId) {}
|
||||
|
||||
TensorRow::TensorRow(size_type n, TensorRow::value_type t) noexcept : id_(kDefaultRowId), row_(n, t) {}
|
||||
|
||||
TensorRow::TensorRow(const TensorRow::vector_type &v) : id_(kDefaultRowId), row_(v) {}
|
||||
|
||||
TensorRow::TensorRow(row_id_type id, const std::initializer_list<value_type> &lst) : id_(id), row_(lst) {}
|
||||
|
||||
TensorRow::TensorRow(const TensorRow &tr) : id_(tr.id_), row_(tr.row_) {}
|
||||
|
||||
TensorRow &TensorRow::operator=(const TensorRow &tr) {
|
||||
if (this == &tr) {
|
||||
return *this;
|
||||
}
|
||||
row_ = tr.row_;
|
||||
id_ = tr.id_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
TensorRow &TensorRow::operator=(const std::initializer_list<TensorRow::value_type> &lst) {
|
||||
row_ = lst;
|
||||
return *this;
|
||||
}
|
||||
|
||||
TensorRow::TensorRow(TensorRow::vector_type &&v) noexcept : id_(kDefaultRowId), row_(std::move(v)) {}
|
||||
|
||||
TensorRow::TensorRow(row_id_type id, std::initializer_list<value_type> &&lst) noexcept
|
||||
: id_(id), row_(std::move(lst)) {}
|
||||
|
||||
TensorRow::TensorRow(TensorRow &&tr) noexcept {
|
||||
id_ = tr.id_;
|
||||
row_ = std::move(tr.row_);
|
||||
}
|
||||
|
||||
TensorRow &TensorRow::operator=(TensorRow &&tr) noexcept {
|
||||
if (this == &tr) {
|
||||
return *this;
|
||||
}
|
||||
row_ = std::move(tr.row_);
|
||||
id_ = tr.id_;
|
||||
tr.id_ = kDefaultRowId;
|
||||
return *this;
|
||||
}
|
||||
|
||||
TensorRow &TensorRow::operator=(std::initializer_list<TensorRow::value_type> &&lst) noexcept {
|
||||
row_ = std::move(lst);
|
||||
return *this;
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,131 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_CORE_TENSOR_ROW_H_
|
||||
#define DATASET_CORE_TENSOR_ROW_H_
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class TensorRow; // A set of Tensor pointers with an id
|
||||
using TensorTable = std::vector<TensorRow>; // The table of tensors is a vector of rows
|
||||
using TensorQTable = std::deque<TensorRow>;  // A different flavour of tensor table, this one has queue functionality

class TensorRow {
 public:
  static constexpr row_id_type kDefaultRowId = -1;  // Default row id

  // Type definitions
  using size_type = dsize_t;
  using value_type = std::shared_ptr<Tensor>;
  using reference = std::shared_ptr<Tensor> &;
  using const_reference = const std::shared_ptr<Tensor> &;
  using vector_type = std::vector<std::shared_ptr<Tensor>>;
  using iterator = std::vector<std::shared_ptr<Tensor>>::iterator;
  using const_iterator = std::vector<std::shared_ptr<Tensor>>::const_iterator;

  TensorRow() noexcept;

  TensorRow(size_type n, value_type t) noexcept;

  // Copy Constructors
  explicit TensorRow(const vector_type &v);

  TensorRow(row_id_type id, const std::initializer_list<value_type> &lst);

  TensorRow(const TensorRow &tr);

  TensorRow &operator=(const TensorRow &tr);

  TensorRow &operator=(const std::initializer_list<value_type> &lst);

  // Move Constructors
  explicit TensorRow(vector_type &&v) noexcept;

  TensorRow(row_id_type id, std::initializer_list<value_type> &&lst) noexcept;

  TensorRow(TensorRow &&tr) noexcept;

  TensorRow &operator=(TensorRow &&tr) noexcept;

  TensorRow &operator=(std::initializer_list<value_type> &&lst) noexcept;

  // Destructor
  ~TensorRow() = default;

  // Functions to fetch/set id/vector
  row_id_type getId() const { return id_; }

  void setId(row_id_type id) { id_ = id; }

  const vector_type &getRow() const { return row_; }

  // Wrapper functions to support vector operations
  void emplace_back(value_type t) { row_.emplace_back(t); }

  void push_back(value_type t) { row_.push_back(t); }

  void clear() noexcept { row_.clear(); }

  size_type size() const noexcept { return row_.size(); }

  void reserve(size_type size) { row_.reserve(size); }

  void resize(size_type size) { row_.resize(size); }

  bool empty() { return row_.empty(); }

  void insert(iterator position, iterator first, iterator last) { row_.insert(position, first, last); }

  // Wrapper functions to support vector element access
  reference at(size_type index) { return row_.at(index); }

  const_reference at(size_type index) const { return row_.at(index); }

  reference front() { return row_.front(); }

  const_reference front() const { return row_.front(); }

  reference back() { return row_.back(); }

  const_reference back() const { return row_.back(); }

  reference operator[](size_type index) { return row_[index]; }

  const_reference operator[](size_type index) const { return row_[index]; }

  // Wrapper functions to support vector iteration
  iterator begin() { return row_.begin(); }

  const_iterator begin() const { return row_.begin(); }

  iterator end() { return row_.end(); }

  const_iterator end() const { return row_.end(); }

 protected:
  row_id_type id_;
  std::vector<std::shared_ptr<Tensor>> row_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_CORE_TENSOR_ROW_H_
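For reference, a row-of-tensors container like the one above is essentially an id plus a vector of shared_ptr<Tensor>, and TensorQTable queues such rows in FIFO order. Below is a minimal, self-contained sketch of that usage pattern; FakeTensor and FakeRow are stand-ins invented for illustration and are not types from the dataset module.

#include <cstdint>
#include <deque>
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for the real Tensor class, just so the sketch compiles on its own.
struct FakeTensor {
  explicit FakeTensor(int v) : value(v) {}
  int value;
};

// Simplified analogue of TensorRow: a row id plus a vector of shared tensors.
struct FakeRow {
  int64_t id = -1;
  std::vector<std::shared_ptr<FakeTensor>> row;
};

int main() {
  // Build one row holding two "tensors", then queue it the way a TensorQTable would.
  FakeRow r;
  r.id = 0;
  r.row.push_back(std::make_shared<FakeTensor>(10));
  r.row.push_back(std::make_shared<FakeTensor>(20));

  std::deque<FakeRow> queue;  // analogue of TensorQTable = std::deque<TensorRow>
  queue.push_back(r);

  // Consume rows in FIFO order; copies are cheap because only shared_ptrs are copied.
  while (!queue.empty()) {
    const FakeRow &front = queue.front();
    std::cout << "row " << front.id << " has " << front.row.size() << " tensors\n";
    queue.pop_front();
  }
  return 0;
}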
@ -94,7 +94,7 @@ class TensorShape {
  // @return
  TensorShape PrependDim(dsize_t dim) const;

  // Insert a new dim at the end of the shape. For example, <2,4> --> PrependDim(4) --> <2,4,4>
  // Insert a new dim at the end of the shape. For example, <2,4> --> AppendDim(4) --> <2,4,4>
  // @param dim
  // @return
  TensorShape AppendDim(dsize_t dim) const;
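For reference, the corrected comment describes adding a dimension at the end of a shape (<2,4> with AppendDim(4) gives <2,4,4>), while PrependDim adds one at the front. The following standalone sketch shows both operations over a plain vector; the Prepend/Append helpers are invented for illustration, and the real methods return a new TensorShape rather than a vector.

#include <cstdint>
#include <iostream>
#include <vector>

// Return a copy of `shape` with `dim` added at the front (prepend) or back (append).
std::vector<int64_t> Prepend(const std::vector<int64_t> &shape, int64_t dim) {
  std::vector<int64_t> out;
  out.reserve(shape.size() + 1);
  out.push_back(dim);
  out.insert(out.end(), shape.begin(), shape.end());
  return out;
}

std::vector<int64_t> Append(const std::vector<int64_t> &shape, int64_t dim) {
  std::vector<int64_t> out(shape);
  out.push_back(dim);
  return out;
}

int main() {
  std::vector<int64_t> shape = {2, 4};
  auto prepended = Prepend(shape, 4);  // {4, 2, 4}
  auto appended = Append(shape, 4);    // {2, 4, 4}
  std::cout << prepended.size() << " " << appended.size() << "\n";  // 3 3
  return 0;
}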
@ -118,7 +118,10 @@ class TensorShape {

  bool operator!=(const TensorShape &rhs) const { return !(rhs == *this); }

  dsize_t operator[](const dsize_t index) const { return raw_shape_[index]; }
  dsize_t operator[](const dsize_t index) const {
    if (index < 0) return raw_shape_[raw_shape_.size() + index];
    return raw_shape_[index];
  }

  // Return the Shape as a vector
  // @return
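The expanded operator[] accepts Python-style negative indices, so -1 refers to the last dimension. A self-contained sketch of the same lookup rule follows; the Dim helper is invented for illustration.

#include <cstdint>
#include <iostream>
#include <vector>

// Look up a dimension with Python-style indexing: -1 is the last dim, -2 the
// second-to-last, and so on. Mirrors the updated operator[] above.
int64_t Dim(const std::vector<int64_t> &shape, int64_t index) {
  if (index < 0) return shape[shape.size() + index];
  return shape[index];
}

int main() {
  std::vector<int64_t> shape = {32, 3, 224, 224};  // e.g. an NCHW batch
  std::cout << Dim(shape, 0) << "\n";   // 32
  std::cout << Dim(shape, -1) << "\n";  // 224 (last dimension)
  return 0;
}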
@ -1,6 +1,7 @@
add_subdirectory(datasetops)
add_subdirectory(opt)
add_subdirectory(gnn)
add_subdirectory(perf)
if (ENABLE_TDTQUE)
    add_subdirectory(tdt)
endif ()

@ -16,7 +17,7 @@ add_library(engine OBJECT
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})

if (ENABLE_TDTQUE)
    add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn)
    add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf)
else()
    add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn)
    add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf)
endif ()
@ -152,6 +152,23 @@ class Connector {
    return out;
  }

  // Get current size of connector.
  int32_t size() const {
    int32_t size = 0;
    for (int32_t i = 0; i < queues_.size(); ++i) {
      size += queues_[i]->size();
    }
    return size;
  }

  int32_t capacity() const {
    int32_t capacity = 0;
    for (int32_t i = 0; i < queues_.size(); ++i) {
      capacity += queues_[i]->capacity();
    }
    return capacity;
  }

  // Register the internal resources with Task group for interruption service.
  // @param vg
  // @return
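The new size() and capacity() accessors simply sum the occupancy and capacity of the connector's internal per-worker queues. Below is a self-contained sketch of that aggregation; BoundedQueue and MiniConnector are invented stand-ins, not the actual Connector or Queue types.

#include <cstdint>
#include <deque>
#include <iostream>
#include <vector>

// A toy bounded queue: a deque plus a fixed capacity.
struct BoundedQueue {
  explicit BoundedQueue(int32_t cap) : cap_(cap) {}
  int32_t size() const { return static_cast<int32_t>(items_.size()); }
  int32_t capacity() const { return cap_; }
  void push(int v) { items_.push_back(v); }
  std::deque<int> items_;
  int32_t cap_;
};

// Aggregates occupancy and capacity across several queues, like Connector::size()/capacity().
struct MiniConnector {
  MiniConnector(int32_t num_queues, int32_t cap) : queues_(num_queues, BoundedQueue(cap)) {}
  int32_t size() const {
    int32_t total = 0;
    for (const auto &q : queues_) total += q.size();
    return total;
  }
  int32_t capacity() const {
    int32_t total = 0;
    for (const auto &q : queues_) total += q.capacity();
    return total;
  }
  std::vector<BoundedQueue> queues_;
};

int main() {
  MiniConnector c(4, 8);  // 4 queues, 8 slots each
  c.queues_[0].push(1);
  c.queues_[2].push(2);
  std::cout << c.size() << "/" << c.capacity() << "\n";  // 2/32
  return 0;
}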
@ -17,8 +17,6 @@
#include "dataset/util/allocator.h"
#include "dataset/core/global_context.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/datasetops/source/storage_client.h"
#include "dataset/engine/datasetops/source/tf_buffer.h"

namespace mindspore {
namespace dataset {
@ -26,37 +24,6 @@ namespace dataset {
// Description: This is the main constructor that is used for making a buffer
DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {}

// Name: CreateDataBuffer()
// Description: A static factory method to create the appropriate type of derived class
//              buffer. Returns the base class reference for DataBuffer.
Status DataBuffer::CreateDataBuffer(
  int32_t id,                                     // In: The id for the new buffer
  std::shared_ptr<StorageClient> storage_client,  // In: The storage client that is related to this buffer type
  std::unique_ptr<DataBuffer> *ptr) {
  std::unique_ptr<DataBuffer> new_data_buffer;
  try {
    DatasetType ds_type = storage_client->schema()->dataset_type();
    switch (ds_type) {
      case DatasetType::kTf: {
        // This type of buffer is for TF record data.
        // Allocate derived class version for a TF buffer
        new_data_buffer = std::make_unique<TFBuffer>(id, kDeBFlagNone, storage_client);
        break;
      }
      default: {
        std::string errMsg("Invalid buffer type");
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
    }
  } catch (std::bad_alloc &e) {
    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, e.what());
  } catch (std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  *ptr = std::move(new_data_buffer);
  return Status::OK();
}

// Name: print()
// Description: A function that prints info about the DataBuffer (base class version)
void DataBuffer::Print(std::ostream &out,  // In: The output stream to print to
@ -98,7 +65,7 @@ Status DataBuffer::GetTensor(std::shared_ptr<Tensor> *ptr, int32_t row_id, int32
// Remove me!! Callers should fetch rows via pop
Status DataBuffer::GetRow(int32_t row_id, TensorRow *ptr) const {
  if (row_id < tensor_table_->size()) {
  if (tensor_table_ && !tensor_table_->empty() && row_id < tensor_table_->size()) {
    *ptr = tensor_table_->at(row_id);
  } else {
    std::string err_msg = "rowId for mTensorTable out of range: " + std::to_string(row_id);
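The strengthened condition guards against a null or empty tensor table as well as an out-of-range row id before indexing. A self-contained sketch of the same defensive lookup follows; GetRowSafe and its types are invented for illustration.

#include <iostream>
#include <memory>
#include <vector>

// Safe lookup into an optionally-allocated table: check for a null table, an
// empty table, and an out-of-range index before dereferencing, mirroring the
// updated check in GetRow above.
bool GetRowSafe(const std::unique_ptr<std::vector<int>> &table, size_t row_id, int *out) {
  if (table && !table->empty() && row_id < table->size()) {
    *out = table->at(row_id);
    return true;
  }
  std::cerr << "row id out of range: " << row_id << "\n";
  return false;
}

int main() {
  auto table = std::make_unique<std::vector<int>>(std::vector<int>{7, 8, 9});
  int value = 0;
  if (GetRowSafe(table, 1, &value)) std::cout << value << "\n";  // 8
  GetRowSafe(nullptr, 0, &value);  // rejected: null table
  return 0;
}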