From 947a93c8644a5b323b60d298a998f468b31c065f Mon Sep 17 00:00:00 2001 From: geekun Date: Tue, 23 Jun 2020 16:29:22 +0800 Subject: [PATCH 01/27] fix infer value bug --- mindspore/common/tensor.py | 8 ++++++++ mindspore/ops/operations/math_ops.py | 6 ++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 0a631b954fd..ddd7cbfabc0 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -22,6 +22,10 @@ from . import dtype as mstype from ._register_for_tensor import tensor_operator_registry __all__ = ['Tensor', 'MetaTensor'] +np_types = (np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, np.float16, + np.float32, np.float64, np.bool_) + class Tensor(Tensor_): @@ -54,6 +58,10 @@ class Tensor(Tensor_): """ def __init__(self, input_data, dtype=None): + # If input data is numpy number, convert it to np array + if isinstance(input_data, np_types): + input_data = np.array(input_data) + # If input_data is tuple/list/numpy.ndarray, it's support in check_type method. check_type('tensor input_data', input_data, (Tensor_, float, int)) if dtype is not None: diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 08cd481582d..5044783ea7e 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -888,7 +888,8 @@ class Neg(PrimitiveWithInfer): def infer_value(self, input_x): if input_x is not None: input_x = input_x.asnumpy() - return Tensor(-input_x) + out = np.array(-input_x, input_x.dtype) + return Tensor(out) return None @@ -1667,7 +1668,8 @@ class Div(_MathBinaryOp): if x is not None and y is not None: x = x.asnumpy() y = y.asnumpy() - return Tensor(x / y) + out = np.array(x / y, x.dtype) + return Tensor(out) return None From 8e2bb7a85cbdceedc1c231ae9547a9cc7b28d68b Mon Sep 17 00:00:00 2001 From: caojian05 Date: Wed, 24 Jun 2020 13:36:04 +0800 Subject: [PATCH 02/27] fix accurancy lower then 92 --- model_zoo/vgg16/src/config.py | 4 +++- model_zoo/vgg16/train.py | 26 ++++++++++++++++---------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/model_zoo/vgg16/src/config.py b/model_zoo/vgg16/src/config.py index 8c6ffee98b4..a34cf7a1d3e 100644 --- a/model_zoo/vgg16/src/config.py +++ b/model_zoo/vgg16/src/config.py @@ -19,7 +19,9 @@ from easydict import EasyDict as edict cifar_cfg = edict({ 'num_classes': 10, - 'lr_init': 0.05, + 'lr_init': 0.01, + 'lr_max': 0.1, + 'warmup_epochs': 5, 'batch_size': 64, 'epoch_size': 70, 'momentum': 0.9, diff --git a/model_zoo/vgg16/train.py b/model_zoo/vgg16/train.py index c582cdd679d..33a4f0310c9 100644 --- a/model_zoo/vgg16/train.py +++ b/model_zoo/vgg16/train.py @@ -38,20 +38,25 @@ random.seed(1) np.random.seed(1) -def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): +def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch): """Set learning rate.""" lr_each_step = [] total_steps = steps_per_epoch * total_epochs - decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps] + warmup_steps = steps_per_epoch * warmup_epochs + if warmup_steps != 0: + inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) + else: + inc_each_step = 0 for i in range(total_steps): - if i < decay_epoch_index[0]: - lr_each_step.append(lr_max) - elif i < decay_epoch_index[1]: - lr_each_step.append(lr_max * 0.1) - elif i < decay_epoch_index[2]: - lr_each_step.append(lr_max * 0.01) + if i < warmup_steps: + lr_value = float(lr_init) + inc_each_step * float(i) else: - lr_each_step.append(lr_max * 0.001) + base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) + lr_value = float(lr_max) * base * base + if lr_value < 0.0: + lr_value = 0.0 + lr_each_step.append(lr_value) + current_step = global_step lr_each_step = np.array(lr_each_step).astype(np.float32) learning_rate = lr_each_step[current_step:] @@ -86,7 +91,8 @@ if __name__ == '__main__': if args_opt.pre_trained: load_param_into_net(net, load_checkpoint(args_opt.pre_trained)) - lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) + lr = lr_steps(0, lr_init=cfg.lr_init, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs, + total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) From 6d491b90739e7dce1c8daf8aba217bcd05d95d87 Mon Sep 17 00:00:00 2001 From: Xian Weizhao Date: Wed, 24 Jun 2020 16:51:26 +0800 Subject: [PATCH 03/27] relax the exception of control depend on value node --- mindspore/ccsrc/transform/convert.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index 32333a06ae5..3f6b31303c3 100644 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -1646,7 +1646,7 @@ bool DfGraphConvertor::GetControlDependList(const CNodePtr &node, dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); } if (src_ops_list->empty() || dst_ops_list->empty()) { - MS_LOG(WARNING) << "Control depend node's src or dest node is not a apply node, ignore it"; + MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it"; error_ = SUCCESS; } return true; @@ -1690,6 +1690,8 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) { }); } else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) { control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]}); + } else if (src_ops_list->empty() || dst_ops_list->empty()) { + MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it"; } else { MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size() << " -> dst:" << dst_ops_list->size(); From 90639a2a44568269d33f9e1397b01eaa78e07fa9 Mon Sep 17 00:00:00 2001 From: guohongzilong <2713219276@qq.com> Date: Fri, 19 Jun 2020 16:23:37 +0800 Subject: [PATCH 04/27] fix params KeyError in group params --- mindspore/nn/optim/optimizer.py | 23 ++++++++++++++++++++--- mindspore/ops/operations/debug_ops.py | 6 ------ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 9bfc3a284b6..40642226adc 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -219,8 +219,28 @@ class Optimizer(Cell): raise TypeError("Learning rate should be float, Tensor or Iterable.") return lr + def _check_group_params(self, parameters): + """Check group params.""" + parse_keys = ['params', 'lr', 'weight_decay', 'order_params'] + for group_param in parameters: + invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys())) + if invalid_key: + raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.') + + if 'order_params' in group_param.keys(): + if len(group_param.keys()) > 1: + raise ValueError("The order params dict in group parameters should " + "only include the 'order_params' key.") + if not isinstance(group_param['order_params'], Iterable): + raise TypeError("The value of 'order_params' should be an Iterable type.") + continue + + if not group_param['params']: + raise ValueError("Optimizer got an empty group parameter list.") + def _parse_group_params(self, parameters, learning_rate): """Parse group params.""" + self._check_group_params(parameters) if self.dynamic_lr: dynamic_lr_length = learning_rate.size() else: @@ -250,9 +270,6 @@ class Optimizer(Cell): if dynamic_lr_length not in (lr_length, 0): raise ValueError("The dynamic learning rate in group should be the same size.") - if not group_param['params']: - raise ValueError("Optimizer got an empty group parameter list.") - dynamic_lr_length = lr_length self.dynamic_lr_length = dynamic_lr_length diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index c4fbddd38e7..47b70688a89 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -302,12 +302,6 @@ class Print(PrimitiveWithInfer): Output tensor or string to stdout. Note: - The print operation cannot support the following cases currently. - - 1. The type of tensor is float64 or bool. - - 2. The data of tensor is a scalar type. - In pynative mode, please use python print function. Inputs: From c7d32e1e5536df03a4afdbad6b9dbace50eb678a Mon Sep 17 00:00:00 2001 From: Jesse Lee Date: Mon, 22 Jun 2020 13:29:15 -0400 Subject: [PATCH 05/27] CacheOp branch infrastructure --- mindspore/ccsrc/dataset/util/CMakeLists.txt | 6 + mindspore/ccsrc/dataset/util/allocator.h | 87 ++++ mindspore/ccsrc/dataset/util/buddy.cc | 388 ++++++++++++++++++ mindspore/ccsrc/dataset/util/buddy.h | 133 ++++++ mindspore/ccsrc/dataset/util/cache_pool.cc | 202 +++++++++ mindspore/ccsrc/dataset/util/cache_pool.h | 139 +++++++ mindspore/ccsrc/dataset/util/list.h | 18 + mindspore/ccsrc/dataset/util/memory_pool.h | 14 - mindspore/ccsrc/dataset/util/path.cc | 118 +++++- mindspore/ccsrc/dataset/util/path.h | 14 + mindspore/ccsrc/dataset/util/semaphore.cc | 41 ++ mindspore/ccsrc/dataset/util/semaphore.h | 54 +++ mindspore/ccsrc/dataset/util/slice.cc | 38 ++ mindspore/ccsrc/dataset/util/slice.h | 122 ++++++ .../ccsrc/dataset/util/storage_container.cc | 164 ++++++++ .../ccsrc/dataset/util/storage_container.h | 79 ++++ .../ccsrc/dataset/util/storage_manager.cc | 167 ++++++++ .../ccsrc/dataset/util/storage_manager.h | 76 ++++ mindspore/ccsrc/dataset/util/system_pool.h | 7 + 19 files changed, 1850 insertions(+), 17 deletions(-) create mode 100644 mindspore/ccsrc/dataset/util/buddy.cc create mode 100644 mindspore/ccsrc/dataset/util/buddy.h create mode 100644 mindspore/ccsrc/dataset/util/cache_pool.cc create mode 100644 mindspore/ccsrc/dataset/util/cache_pool.h create mode 100644 mindspore/ccsrc/dataset/util/semaphore.cc create mode 100644 mindspore/ccsrc/dataset/util/semaphore.h create mode 100644 mindspore/ccsrc/dataset/util/slice.cc create mode 100644 mindspore/ccsrc/dataset/util/slice.h create mode 100644 mindspore/ccsrc/dataset/util/storage_container.cc create mode 100644 mindspore/ccsrc/dataset/util/storage_container.h create mode 100644 mindspore/ccsrc/dataset/util/storage_manager.cc create mode 100644 mindspore/ccsrc/dataset/util/storage_manager.h diff --git a/mindspore/ccsrc/dataset/util/CMakeLists.txt b/mindspore/ccsrc/dataset/util/CMakeLists.txt index b36d612435a..96489add071 100644 --- a/mindspore/ccsrc/dataset/util/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/util/CMakeLists.txt @@ -2,6 +2,8 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_library(utils OBJECT arena.cc + buddy.cc + cache_pool.cc circular_pool.cc memory_pool.cc cond_var.cc @@ -11,7 +13,11 @@ add_library(utils OBJECT service.cc services.cc lock.cc + semaphore.cc status.cc + storage_container.cc + storage_manager.cc + slice.cc path.cc wait_post.cc sig_handler.cc) diff --git a/mindspore/ccsrc/dataset/util/allocator.h b/mindspore/ccsrc/dataset/util/allocator.h index ba6c7786df5..50a9cadbe3f 100644 --- a/mindspore/ccsrc/dataset/util/allocator.h +++ b/mindspore/ccsrc/dataset/util/allocator.h @@ -17,8 +17,10 @@ #define DATASET_UTIL_ALLOCATOR_H_ #include +#include #include #include +#include #include "dataset/util/memory_pool.h" namespace mindspore { @@ -84,6 +86,91 @@ class Allocator { private: std::shared_ptr pool_; }; +/// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will +/// be released when the object goes out of scope \tparam T The type of object to be allocated \tparam C Allocator. +/// Default to std::allocator +template > +class MemGuard { + public: + using allocator = C; + MemGuard() : n_(0) {} + explicit MemGuard(allocator a) : n_(0), alloc_(a) {} + // There is no copy constructor nor assignment operator because the memory is solely owned by this object. + MemGuard(const MemGuard &) = delete; + MemGuard &operator=(const MemGuard &) = delete; + // On the other hand, We can support move constructor + MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {} + MemGuard &operator=(MemGuard &&lhs) noexcept { + if (this != &lhs) { + this->deallocate(); + n_ = lhs.n_; + alloc_ = std::move(lhs.alloc_); + ptr_ = std::move(lhs.ptr_); + } + return *this; + } + /// \brief Explicitly deallocate the memory if allocated + void deallocate() { + if (ptr_) { + auto *p = ptr_.release(); + if (!std::is_arithmetic::value && std::is_destructible::value) { + for (auto i = 0; i < n_; ++i) { + p[i].~T(); + } + } + alloc_.deallocate(p, n_); + n_ = 0; + } + } + /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is + /// allocated. + /// \param n Number of objects of type T to be allocated + /// \tparam Args Extra arguments pass to the constructor of T + template + Status allocate(size_t n, Args &&... args) noexcept { + try { + deallocate(); + if (n > 0) { + T *data = alloc_.allocate(n); + if (!std::is_arithmetic::value) { + for (auto i = 0; i < n; i++) { + std::allocator_traits::construct(alloc_, &(data[i]), std::forward(args)...); + } + } + ptr_ = std::unique_ptr(data); + n_ = n; + } + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory); + } catch (std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + return Status::OK(); + } + ~MemGuard() noexcept { deallocate(); } + /// \brief Getter function + /// \return The pointer to the memory allocated + T *GetPointer() const { return ptr_.get(); } + /// \brief Getter function + /// \return The pointer to the memory allocated + T *GetMutablePointer() { return ptr_.get(); } + /// \brief Overload [] operator to access a particular element + /// \param x index to the element. Must be less than number of element allocated. + /// \return pointer to the x-th element + T *operator[](size_t x) { return GetMutablePointer() + x; } + /// \brief Overload [] operator to access a particular element + /// \param x index to the element. Must be less than number of element allocated. + /// \return pointer to the x-th element + T *operator[](size_t x) const { return GetPointer() + x; } + /// \brief Return how many bytes are allocated in total + /// \return Number of bytes allocated in total + size_t GetSizeInBytes() const { return n_ * sizeof(T); } + + private: + allocator alloc_; + std::unique_ptr> ptr_; + size_t n_; +}; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/buddy.cc b/mindspore/ccsrc/dataset/util/buddy.cc new file mode 100644 index 00000000000..3a14258419a --- /dev/null +++ b/mindspore/ccsrc/dataset/util/buddy.cc @@ -0,0 +1,388 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/util/buddy.h" +#include +#include +#include "dataset/util/de_error.h" +#include "dataset/util/memory_pool.h" +#include "dataset/util/system_pool.h" +#include "./securec.h" + +inline uint64_t BitLeftShift(uint64_t v, uint64_t n) { return (v << n); } + +inline uint64_t BitRightShift(uint64_t v, uint64_t n) { return (v >> n); } + +inline uint64_t BitOr(uint64_t rhs, uint64_t lhs) { return rhs | lhs; } + +inline uint64_t BitEx(uint64_t rhs, uint64_t lhs) { return rhs ^ lhs; } + +inline uint64_t BitAnd(uint64_t rhs, uint64_t lhs) { return rhs & lhs; } + +namespace mindspore { +namespace dataset { +Status BuddySpace::Init() { + if (log_min_ < 0) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "log_min must be positive : " + std::to_string(log_min_)); + } + if (num_lvl_ < 3 || num_lvl_ > 18) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "num_lvl must be between 3 and 18 : " + std::to_string(num_lvl_)); + } + min_ = BitLeftShift(1, log_min_); + max_ = BitLeftShift(1, log_min_ + num_lvl_ - 1); + size_t offset_1 = sizeof(rel_addr_t) * num_lvl_; + size_t offset_2 = sizeof(int) * num_lvl_ + offset_1; + size_t offset_3 = sizeof(char) * BitLeftShift(1, num_lvl_ - 3) + offset_2; + RETURN_IF_NOT_OK(DeMalloc(offset_3, &ptr_, true)); + hint_ = reinterpret_cast(ptr_); + count_ = reinterpret_cast((reinterpret_cast(ptr_) + offset_1)); + map_ = reinterpret_cast(ptr_) + offset_2; + count_[num_lvl_ - 1] = 1; + map_[0] = BitOr(MORE_BIT, num_lvl_ - 3); + return Status::OK(); +} + +Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) noexcept { + std::lock_guard lock(mutex_); + addr_t addr = AllocNoLock(sz, desc); + if (addr != NOSPACE) { + *p = addr; + return Status::OK(); + } else { + return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. Please ignore."); + } +} + +addr_t BuddySpace::AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept { + DS_ASSERT(sz <= max_); + uint32_t reqSize = SizeToBlock(sz); + rel_addr_t rel_addr = AllocBuddySeg(reqSize); + if (rel_addr != static_cast(NOSPACE)) { + (void)memset_s(desc, sizeof(BSpaceDescriptor), 0, sizeof(BSpaceDescriptor)); + desc->sig = static_cast(0xDEADBEEF); + desc->addr = rel_addr; + desc->req_size = reqSize; + desc->blk_size = NextPowerOf2(reqSize); + return static_cast(rel_addr * min_); + } else { + return NOSPACE; + } +} + +void BuddySpace::FreeNoLock(const BSpaceDescriptor *desc) { + DS_ASSERT(desc->sig == 0XDEADBEEF); + rel_addr_t rel_addr = desc->addr; + size_t blk_size = desc->blk_size; + size_t req_size = desc->req_size; + FreeBuddySeg(rel_addr, blk_size, req_size); +} + +void BuddySpace::Free(const BSpaceDescriptor *desc) { + std::lock_guard lock(mutex_); + return FreeNoLock(desc); +} + +std::ostream &operator<<(std::ostream &os, const BuddySpace &s) { + os << "1 unit = " << s.GetMinSize() << "\n" + << "Size of buddy space = " << s.GetMaxSize() << "\n" + << "Number of levels = " << s.num_lvl_ << "\n\n" + << "Percent free = " << s.PercentFree() << "\n" + << "Dumping count array : " + << "\n"; + for (int i = 0; i < s.num_lvl_; i++) { + os << "[" << i << "] = " << s.count_[i] << " "; + if (((i + 1) % 4) == 0) { + os << "\n"; + } + } + os << "\n"; + os << "Dumping allocation info:" + << "\n"; + auto max_addr = static_cast(BitLeftShift(1, s.num_lvl_ - 1)); + rel_addr_t addr = 0; + while (addr < max_addr) { + size_t sz = 0; + BuddySpace::STATE st; + s.GetBuddySegState(addr, &sz, &st); + os << "Address : " << std::left << std::setw(8) << addr << " Size : " << std::setw(8) << sz << " State : " + << ((st == BuddySpace::STATE::kAlloc) ? "ALLOC" : ((st == BuddySpace::STATE::kFree) ? "FREE" : "Unkonwn")) + << "\n"; + addr += sz; + } + return os; +} + +void BuddySpace::GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const { + char byte; + int pos; + int offset; + uint64_t val = 0; + int shift; + pos = BitRightShift(rel_addr, 2); + offset = rel_addr % 4; + shift = offset * 2; + byte = map_[pos]; + switch (offset) { + case 0: + val = byte; + break; + case 1: + case 3: + if (offset == 1) { + val = BitLeftShift(BitAnd(byte, 0x30), shift); + } else { + val = BitLeftShift(BitAnd(byte, 0x03), shift); + } + break; + case 2: + val = BitLeftShift(BitAnd(byte, 0x0F), shift); + break; + } + if (BitAnd(val, ONE_BIT)) { + *rel_sz = 1; + } else if (BitAnd(val, TWO_BIT)) { + *rel_sz = 2; + } else if (BitAnd(val, MORE_BIT)) { + log_t lg = BitAnd(val, 0x0F); + *rel_sz = BitLeftShift(1, lg + 2); + } else { + *st = STATE::kEmpty; + return; + } + *st = BitAnd(val, ALLOC_BIT) ? STATE::kAlloc : STATE::kFree; +} + +void BuddySpace::SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st) { + int clr; + int mask; + int pos; + int offset; + int val = 0; + int shift; + auto log_sz = static_cast(Log2(rel_sz)); + pos = BitRightShift(rel_addr, 2); + offset = rel_addr % 4; + shift = offset * 2; + if (rel_sz == 1) { + val = ONE_BIT; + mask = 0xC0; + } else if (rel_sz == 2) { + val = TWO_BIT; + mask = 0xF0; + } else { + val = BitOr(log_sz - 2, MORE_BIT); + mask = 0xFF; + } + if (st == STATE::kAlloc) { + val = BitOr(val, ALLOC_BIT); + } else if (st == STATE::kFree) { + val = BitAnd(val, ~(static_cast(ALLOC_BIT))); + } else if (st == STATE::kEmpty) { + val = 0; + } + clr = static_cast(~(BitRightShift(mask, shift))); + map_[pos] = static_cast(BitAnd(map_[pos], clr)); + map_[pos] = static_cast(BitOr(map_[pos], BitRightShift(val, shift))); + if (st == STATE::kAlloc) { + count_[log_sz]--; + } else if (st == STATE::kFree) { + count_[log_sz]++; + if (rel_addr < hint_[log_sz]) { + hint_[log_sz] = rel_addr; + } + } +} + +void BuddySpace::JoinBuddySeg(rel_addr_t addr, size_t blk_sz) { + while (blk_sz < BitLeftShift(1, num_lvl_)) { + rel_addr_t buddy = BitEx(addr, blk_sz); + size_t sz = 0; + STATE st; + GetBuddySegState(buddy, &sz, &st); + if (st == STATE::kFree && sz == blk_sz) { + auto log_sz = static_cast(Log2(blk_sz)); + rel_addr_t left = (buddy < addr) ? buddy : addr; + rel_addr_t right = left + blk_sz; + DS_ASSERT(count_[log_sz] >= 2); + count_[log_sz] -= 2; + SetBuddySegState(right, blk_sz, STATE::kEmpty); + SetBuddySegState(left, BitLeftShift(blk_sz, 1), STATE::kFree); + for (int i = 0; i < log_sz; i++) { + if (hint_[i] == right) { + hint_[i] = left; + } + } + addr = left; + blk_sz <<= 1u; + } else { + break; + } + } +} + +void BuddySpace::TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { + DS_ASSERT(ask_sz < blk_sz); + uint32_t inx = Log2(blk_sz); + size_t remaining_sz = ask_sz; + for (int i = inx; i > 0; i--) { + size_t b_size = BitLeftShift(1, i); + size_t half_sz = BitRightShift(b_size, 1); + count_[i]--; + SetBuddySegState(addr, half_sz, STATE::kFree); + SetBuddySegState(addr + half_sz, half_sz, STATE::kFree); + if (remaining_sz >= half_sz) { + SetBuddySegState(addr, half_sz, STATE::kAlloc); + remaining_sz -= half_sz; + if (remaining_sz == 0) { + break; + } + addr += half_sz; + } + } +} + +void BuddySpace::UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { + DS_ASSERT(ask_sz < blk_sz); + uint32_t inx = Log2(blk_sz); + size_t remaining_sz = ask_sz; + for (int i = inx; i > 0; i--) { + size_t b_size = BitLeftShift(1, i); + size_t half_sz = BitRightShift(b_size, 1); + if (remaining_sz >= half_sz) { +#ifdef DEBUG + { + size_t sz = 0; + STATE st; + GetBuddySegState(addr, &sz, &st); + DS_ASSERT(sz == half_sz && st == STATE::kAlloc); + } +#endif + SetBuddySegState(addr, half_sz, STATE::kFree); + remaining_sz -= half_sz; + if (remaining_sz == 0) { + JoinBuddySeg(addr, half_sz); + break; + } + addr += half_sz; + } + } +} + +rel_addr_t BuddySpace::AllocBuddySeg(uint32_t req_size) noexcept { + uint32_t blk_size = NextPowerOf2(req_size); + int start_inx = static_cast(Log2(blk_size)); + bool found = false; + rel_addr_t ask_addr = 0; + auto max_addr = static_cast(BitLeftShift(1, num_lvl_ - 1)); + STATE st; + size_t sz = 0; + for (int i = start_inx; !found && i < num_lvl_; i++) { + DS_ASSERT(count_[i] >= 0); + if (count_[i] == 0) { + continue; + } + auto blk_sz = static_cast(BitLeftShift(1, i)); + ask_addr = hint_[i]; + while (ask_addr < max_addr && !found) { + GetBuddySegState(ask_addr, &sz, &st); + if (st == STATE::kFree && sz == blk_sz) { + found = true; + } else { + DS_ASSERT(st != STATE::kEmpty); + ask_addr += ((sz > blk_sz) ? sz : blk_sz); + } + } + } + if (found) { + if (sz > req_size) { + TrimBuddySeg(ask_addr, sz, req_size); + } else { + SetBuddySegState(ask_addr, sz, STATE::kAlloc); + hint_[start_inx] = ask_addr; + } + return ask_addr; + } else { + return static_cast(NOSPACE); + } +} + +void BuddySpace::FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size) { + if (req_size == blk_size) { +#ifdef DEBUG + { + size_t sz = 0; + STATE st; + GetBuddySegState(addr, &sz, &st); + } +#endif + SetBuddySegState(addr, blk_size, STATE::kFree); + JoinBuddySeg(addr, blk_size); + } else { + UnTrimBuddySeg(addr, blk_size, req_size); + } +} + +int BuddySpace::PercentFree() const { + uint64_t total_free_sz = 0; + uint64_t max_sz_in_unit = BitLeftShift(1, num_lvl_ - 1); + // Go through the count array without lock + for (int i = 0; i < num_lvl_; i++) { + int cnt = count_[i]; + if (cnt == 0) { + continue; + } + uint64_t blk_sz = BitLeftShift(1, i); + total_free_sz += (blk_sz * cnt); + } + return static_cast(static_cast(total_free_sz) / static_cast(max_sz_in_unit) * 100); +} + +BuddySpace::BuddySpace(int log_min, int num_lvl) + : hint_(nullptr), + count_(nullptr), + map_(nullptr), + log_min_(log_min), + num_lvl_(num_lvl), + min_(0), + max_(0), + ptr_(nullptr) {} + +BuddySpace::~BuddySpace() { + if (ptr_ != nullptr) { + free(ptr_); + } + hint_ = nullptr; + count_ = nullptr; + map_ = nullptr; +} + +Status BuddySpace::CreateBuddySpace(std::unique_ptr *out_bs, int log_min, int num_lvl) { + Status rc; + auto bs = new (std::nothrow) BuddySpace(log_min, num_lvl); + if (bs == nullptr) { + return Status(StatusCode::kOutOfMemory); + } + rc = bs->Init(); + if (rc.IsOk()) { + (*out_bs).reset(bs); + } else { + delete bs; + } + return rc; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/buddy.h b/mindspore/ccsrc/dataset/util/buddy.h new file mode 100644 index 00000000000..08c05cbbdbe --- /dev/null +++ b/mindspore/ccsrc/dataset/util/buddy.h @@ -0,0 +1,133 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_BUDDY_H_ +#define DATASET_UTIL_BUDDY_H_ + +#include +#include +#include +#include +#include +#include +#include "dataset/util/status.h" + +using addr_t = int64_t; +using rel_addr_t = int32_t; +using log_t = int; +#define ALLOC_BIT 0x80 +#define ONE_BIT 0x40 +#define TWO_BIT 0x20 +#define MORE_BIT 0x10 +#define NOSPACE ((addr_t)(-1)) +namespace mindspore { +namespace dataset { +struct BSpaceDescriptor { + int32_t sig; + rel_addr_t addr; + size_t req_size; + size_t blk_size; +}; + +class BuddySpace { + public: + // C++11 feature. Change STATE into a type safe class with + // the keyword. Don't take out the keyword 'class' + enum class STATE { kFree, kAlloc, kEmpty }; + + BuddySpace(const BuddySpace &) = delete; + + BuddySpace &operator=(const BuddySpace &) = delete; + + virtual ~BuddySpace(); + + Status Alloc(uint64_t sz, BSpaceDescriptor *desc, addr_t *) noexcept; + + void Free(const BSpaceDescriptor *desc); + + uint64_t GetMinSize() const { return min_; } + + uint64_t GetMaxSize() const { return max_; } + + int PercentFree() const; + + friend std::ostream &operator<<(std::ostream &os, const BuddySpace &s); + + static uint64_t NextPowerOf2(uint64_t n) { + if (n <= 1) { + return 1; + } + n = n - 1; + while (n & (n - 1)) { + n = n & (n - 1); + } + return n << 1; + } + + static uint32_t Log2(uint64_t n) { + uint32_t cnt = 0; + while (n >>= 1) { + cnt++; + } + return cnt; + } + + static Status CreateBuddySpace(std::unique_ptr *out_bs, int log_min = 15, int num_lvl = 18); + + private: + rel_addr_t *hint_; + int *count_; + char *map_; + int log_min_; + int num_lvl_; + uint64_t min_; + uint64_t max_; + void *ptr_; + std::mutex mutex_; + + explicit BuddySpace(int log_min = 15, int num_lvl = 18); + + Status Init(); + + addr_t AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept; + + void FreeNoLock(const BSpaceDescriptor *desc); + + uint32_t SizeToBlock(const uint64_t sz) const { + uint32_t reqSize = (sz / min_); + if (sz % min_) { + reqSize++; + } + return reqSize; + } + + void GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const; + + void SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st); + + void JoinBuddySeg(rel_addr_t addr, size_t blk_sz); + + void TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); + + void UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); + + rel_addr_t AllocBuddySeg(uint32_t req_size) noexcept; + + void FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_BUDDY_H_ diff --git a/mindspore/ccsrc/dataset/util/cache_pool.cc b/mindspore/ccsrc/dataset/util/cache_pool.cc new file mode 100644 index 00000000000..92504cd0634 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/cache_pool.cc @@ -0,0 +1,202 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "common/utils.h" +#include "dataset/util/cache_pool.h" +#include "dataset/util/services.h" + +namespace mindspore { +namespace dataset { +CachePool::CachePool(const value_allocator &alloc, const std::string &root) + : alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} + +Status CachePool::DoServiceStart() { + tree_ = std::make_shared(); + // If we are given a disk path, set up the StorageManager + if (!root_.toString().empty()) { + Path spill = GetSpillPath(); + RETURN_IF_NOT_OK(spill.CreateDirectories()); + sm_ = std::make_shared(spill); + RETURN_IF_NOT_OK(sm_->ServiceStart()); + MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString()); + } + return Status::OK(); +} +Status CachePool::DoServiceStop() { + Status rc; + Status rc2; + if (sm_ != nullptr) { + rc = sm_->ServiceStop(); + if (rc.IsError()) { + rc2 = rc; + } + } + sm_.reset(); + for (auto &bl : *tree_) { + if (bl.ptr != nullptr) { + alloc_.deallocate(bl.ptr, bl.sz); + } + } + tree_.reset(); + if (!root_.toString().empty()) { + Path spill = GetSpillPath(); + auto it = Path::DirIterator::OpenDirectory(&spill); + while (it->hasNext()) { + rc = it->next().Remove(); + if (rc.IsError() && rc2.IsOk()) { + rc2 = rc; + } + } + rc = spill.Remove(); + if (rc.IsError() && rc2.IsOk()) { + rc2 = rc; + } + } + return rc2; +} +CachePool::~CachePool() noexcept { (void)ServiceStop(); } +Status CachePool::Insert(const std::vector &buf, CachePool::key_type *key) { + DataLocator bl; + Status rc; + size_t sz = 0; + // We will consolidate all the slices into one piece. + for (auto &v : buf) { + sz += v.GetSize(); + } + bl.sz = sz; + try { + bl.ptr = alloc_.allocate(sz); + // We will do a piecewise copy. + WritableSlice dest(bl.ptr, bl.sz); + size_t pos = 0; + for (auto &v : buf) { + WritableSlice out(dest, pos); + rc = WritableSlice::Copy(&out, v); + if (rc.IsError()) { + break; + } + pos += v.GetSize(); + } + if (rc.IsError()) { + alloc_.deallocate(bl.ptr, sz); + bl.ptr = nullptr; + return rc; + } + } catch (std::bad_alloc &e) { + if (sm_ != nullptr) { + RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); + // We have an assumption 0 is not a valid key from the design of AutoIndexObj. + // Make sure it is not 0. + if (bl.storage_key == 0) { + RETURN_STATUS_UNEXPECTED("Key 0 is returned which is unexpected"); + } + } else { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } + } + rc = tree_->insert(bl, key); + if (rc.IsError() && bl.ptr != nullptr) { + alloc_.deallocate(bl.ptr, sz); + } + return rc; +} +Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const { + RETURN_UNEXPECTED_IF_NULL(dest); + auto r = tree_->Search(key); + if (r.second) { + auto &it = r.first; + if (it->ptr != nullptr) { + ReadableSlice src(it->ptr, it->sz); + RETURN_IF_NOT_OK(WritableSlice::Copy(dest, src)); + } else if (sm_ != nullptr) { + size_t expectedLength = 0; + RETURN_IF_NOT_OK(sm_->Read(it->storage_key, dest, &expectedLength)); + if (expectedLength != it->sz) { + MS_LOG(ERROR) << "Unexpected length. Read " << expectedLength << ". Expected " << it->sz << "." + << " Internal key: " << key << "\n"; + RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); + } + } + if (bytesRead != nullptr) { + *bytesRead = it->sz; + } + } else { + RETURN_STATUS_UNEXPECTED("Key not found"); + } + return Status::OK(); +} +const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; } +Path CachePool::GetSpillPath() const { + auto spill = Path(root_) / subfolder_; + return spill; +} +CachePool::CacheStat CachePool::GetStat() const { + CacheStat cs{0}; + for (auto &it : *tree_) { + if (it.ptr != nullptr) { + ++cs.num_mem_cached; + } else { + ++cs.num_disk_cached; + } + } + return cs; +} +Status CachePool::Spill(CachePool::DataLocator *dl) { + if (sm_ == nullptr) { + RETURN_STATUS_UNEXPECTED("No disk storage to spill"); + } + RETURN_UNEXPECTED_IF_NULL(dl); + RETURN_UNEXPECTED_IF_NULL(dl->ptr); + if (dl->storage_key == 0) { + ReadableSlice data(dl->ptr, dl->sz); + RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data})); + } + alloc_.deallocate(dl->ptr, dl->sz); + dl->ptr = nullptr; + return Status::OK(); +} +Status CachePool::Locate(CachePool::DataLocator *dl) { + RETURN_UNEXPECTED_IF_NULL(dl); + if (dl->ptr == nullptr) { + if (sm_ == nullptr) { + RETURN_STATUS_UNEXPECTED("No disk storage to locate the data"); + } + try { + dl->ptr = alloc_.allocate(dl->sz); + WritableSlice dest(dl->ptr, dl->sz); + Status rc = Read(dl->storage_key, &dest); + if (rc.IsError()) { + alloc_.deallocate(dl->ptr, dl->sz); + dl->ptr = nullptr; + return rc; + } + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } + } + return Status::OK(); +} +size_t CachePool::GetSize(CachePool::key_type key) const { + auto r = tree_->Search(key); + if (r.second) { + auto &it = r.first; + return it->sz; + } else { + return 0; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/cache_pool.h b/mindspore/ccsrc/dataset/util/cache_pool.h new file mode 100644 index 00000000000..d35617d0e4b --- /dev/null +++ b/mindspore/ccsrc/dataset/util/cache_pool.h @@ -0,0 +1,139 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_CACHE_POOL_H_ +#define DATASET_UTIL_CACHE_POOL_H_ + +#include +#include +#include +#include +#include "dataset/util/allocator.h" +#include "dataset/util/service.h" +#include "dataset/util/slice.h" +#include "dataset/util/storage_manager.h" +#include "dataset/util/auto_index.h" + +namespace mindspore { +namespace dataset { +/// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of +/// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to +/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to +/// restore the buffer. +/// \see ReadableSlice +class CachePool : public Service { + public: + using base_type = uint8_t; + using pointer = base_type *; + using const_pointer = const base_type *; + using reference = base_type &; + using const_reference = const base_type &; + using value_allocator = Allocator; + + // An internal class to locate the whereabouts of a backed up buffer which can be either in + class DataLocator { + public: + DataLocator() : ptr(nullptr), sz(0), storage_key(0) {} + ~DataLocator() = default; + DataLocator(const DataLocator &other) = default; + DataLocator &operator=(const DataLocator &other) = default; + DataLocator(DataLocator &&other) noexcept { + ptr = other.ptr; + sz = other.sz; + storage_key = other.storage_key; + other.ptr = nullptr; + other.sz = 0; + other.storage_key = 0; + } + DataLocator &operator=(DataLocator &&other) noexcept { + if (&other != this) { + ptr = other.ptr; + sz = other.sz; + storage_key = other.storage_key; + other.ptr = nullptr; + other.sz = 0; + other.storage_key = 0; + } + return *this; + } + pointer ptr; + size_t sz; + StorageManager::key_type storage_key; + }; + + using data_index = AutoIndexObj; + using key_type = data_index::key_type; + using bl_alloc_type = typename value_allocator::template rebind::other; + + /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and + /// how many elements are spilled to disk. + struct CacheStat { + int64_t num_mem_cached; + int64_t num_disk_cached; + }; + + /// \brief Constructor + /// \param alloc Allocator to allocate memory from + /// \param root Optional disk folder to spill + explicit CachePool(const value_allocator &alloc, const std::string &root = ""); + + CachePool(const CachePool &) = delete; + CachePool(CachePool &&) = delete; + CachePool &operator=(const CachePool &) = delete; + CachePool &operator=(CachePool &&) = delete; + ~CachePool() noexcept; + + Status DoServiceStart() override; + Status DoServiceStop() override; + + Path GetSpillPath() const; + + /// \brief Insert a sequence of ReadableSlice objects into the pool. + /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk. + /// \param[in] buf A sequence of ReadableSlice objects. + /// \param[out] key Generated key + /// \return Error code + Status Insert(const std::vector &buf, key_type *key); + /// \brief Restore a cached buffer (from memory or disk) + /// \param[in] key A previous key returned from Insert + /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice + /// \param[out] bytesRead Optional. Number of bytes read. + /// \return Error code + Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const; + + Status Spill(DataLocator *dl); + + Status Locate(DataLocator *dl); + + size_t GetSize(key_type key) const; + + /// \brief Get statistics. + /// \return CacheStat object + CacheStat GetStat() const; + + const value_allocator &get_allocator() const; + + std::string MyName() const { return subfolder_; } + + private: + value_allocator alloc_; + Path root_; + const std::string subfolder_; + std::shared_ptr sm_; + std::shared_ptr tree_; +}; +} // namespace dataset +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/dataset/util/list.h b/mindspore/ccsrc/dataset/util/list.h index 5a08f4514e5..a4c15daa0e4 100644 --- a/mindspore/ccsrc/dataset/util/list.h +++ b/mindspore/ccsrc/dataset/util/list.h @@ -106,6 +106,24 @@ struct List { ++count; } + // Insert elem2 before elem1 in the list. + virtual void InsertBefore(pointer elem1, pointer elem2) { + DS_ASSERT(elem1 != elem2); + Node &elem1_node = elem1->*node; + Node &elem2_node = elem2->*node; + elem2_node.next = elem1; + elem2_node.prev = elem1_node.prev; + if (elem1_node.prev != nullptr) { + Node &prev_node = elem1_node.prev->*node; + prev_node.next = elem2; + } + elem1_node.prev = elem2; + if (head == elem1) { + head = elem2; + } + ++count; + } + // Remove an element in the list virtual void Remove(pointer elem) noexcept { Node &elem_node = elem->*node; diff --git a/mindspore/ccsrc/dataset/util/memory_pool.h b/mindspore/ccsrc/dataset/util/memory_pool.h index 70876a81417..ee1da3bda15 100644 --- a/mindspore/ccsrc/dataset/util/memory_pool.h +++ b/mindspore/ccsrc/dataset/util/memory_pool.h @@ -44,20 +44,6 @@ class MemoryPool { virtual ~MemoryPool() {} }; -// Used by unique_ptr -template -class Deleter { - public: - explicit Deleter(std::shared_ptr &mp) : mp_(mp) {} - - ~Deleter() = default; - - void operator()(T *ptr) const { mp_->Deallocate(ptr); } - - private: - std::shared_ptr mp_; -}; - Status DeMalloc(std::size_t s, void **p, bool); } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/path.cc b/mindspore/ccsrc/dataset/util/path.cc index c37fdc17f1d..59e5e5232c5 100644 --- a/mindspore/ccsrc/dataset/util/path.cc +++ b/mindspore/ccsrc/dataset/util/path.cc @@ -16,6 +16,8 @@ #include "dataset/util/path.h" #include +#include +#include #include #include #include @@ -26,7 +28,7 @@ namespace mindspore { namespace dataset { -#ifdef _WIN32 +#if defined(_WIN32) || defined(_WIN64) char Path::separator_ = '\\'; #else char Path::separator_ = '/'; @@ -132,7 +134,7 @@ Status Path::CreateDirectory() { #if defined(_WIN32) || defined(_WIN64) int rc = mkdir(common::SafeCStr(path_)); #else - int rc = mkdir(common::SafeCStr(path_), 0700); + int rc = mkdir(common::SafeCStr(path_), S_IRUSR | S_IWUSR | S_IXUSR); #endif if (rc) { std::ostringstream oss; @@ -182,6 +184,111 @@ Status Path::CreateDirectories() { return Status::OK(); } +Status Path::Remove() { + if (Exists()) { + if (IsDirectory()) { + errno_t err = rmdir(common::SafeCStr(path_)); + if (err == -1) { + std::ostringstream oss; + oss << "Unable to delete directory " << path_ << ". Errno = " << errno; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + } else { + errno_t err = unlink(common::SafeCStr(path_)); + if (err == -1) { + std::ostringstream oss; + oss << "Unable to delete file " << path_ << ". Errno = " << errno; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + } + } + return Status::OK(); +} + +Status Path::CreateFile(int *file_descriptor) { return OpenFile(file_descriptor, true); } + +Status Path::OpenFile(int *file_descriptor, bool create) { + int fd; + if (file_descriptor == nullptr) { + RETURN_STATUS_UNEXPECTED("null pointer"); + } + if (IsDirectory()) { + std::ostringstream oss; + oss << "Unable to create file " << path_ << " which is a directory."; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + // Convert to canonical form. + if (strlen(common::SafeCStr(path_)) > PATH_MAX) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + char canonical_path[PATH_MAX + 1] = {0x00}; +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(canonical_path, common::SafeCStr(path_), PATH_MAX) == nullptr) { +#else + if (realpath(common::SafeCStr(path_), canonical_path) == nullptr) { +#endif + if (errno == ENOENT && create) { + // File doesn't exist and we are to create it. Let's break it down. + auto file_part = Basename(); + auto parent_part = ParentPath(); +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(canonical_path, common::SafeCStr(parent_part), PATH_MAX) == nullptr) { +#else + if (realpath(common::SafeCStr(parent_part), canonical_path) == nullptr) { +#endif + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + auto cur_inx = strlen(canonical_path); + if ((cur_inx + file_part.length() + 1) > PATH_MAX) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + canonical_path[cur_inx++] = separator_; + if (strncpy_s(canonical_path + cur_inx, PATH_MAX - cur_inx, common::SafeCStr(file_part), file_part.length()) != + EOK) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + } else { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + } + if (create) { + fd = open(canonical_path, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP); + } else { + fd = open(canonical_path, O_RDWR); + } + if (fd == -1) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + *file_descriptor = fd; + return Status::OK(); +} + +Status Path::CloseFile(int fd) const { + if (close(fd) < 0) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + return Status::OK(); +} + +Status Path::TruncateFile(int fd) const { + int rc; + rc = ftruncate(fd, 0); + if (rc == 0) { + return Status::OK(); + } else { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } +} + +std::string Path::Basename() { + std::size_t found = path_.find_last_of(separator_); + if (found != std::string::npos) { + return path_.substr(found + 1); + } else { + return path_; + } +} + std::shared_ptr Path::DirIterator::OpenDirectory(Path *f) { auto it = new (std::nothrow) DirIterator(f); @@ -208,7 +315,7 @@ Path::DirIterator::~DirIterator() { Path::DirIterator::DirIterator(Path *f) : dir_(f), dp_(nullptr), entry_(nullptr) { MS_LOG(DEBUG) << "Open directory " << f->toString() << "."; - dp_ = opendir(common::SafeCStr(f->toString())); + dp_ = opendir(f->toString().c_str()); } bool Path::DirIterator::hasNext() { @@ -225,5 +332,10 @@ bool Path::DirIterator::hasNext() { } Path Path::DirIterator::next() { return (*(this->dir_) / Path(entry_->d_name)); } + +std::ostream &operator<<(std::ostream &os, const Path &s) { + os << s.path_; + return os; +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/path.h b/mindspore/ccsrc/dataset/util/path.h index efe01a7d16c..fbf65b8c236 100644 --- a/mindspore/ccsrc/dataset/util/path.h +++ b/mindspore/ccsrc/dataset/util/path.h @@ -90,6 +90,20 @@ class Path { std::string ParentPath(); + Status Remove(); + + Status CreateFile(int *fd); + + Status OpenFile(int *fd, bool create = false); + + Status CloseFile(int fd) const; + + Status TruncateFile(int fd) const; + + std::string Basename(); + + friend std::ostream &operator<<(std::ostream &os, const Path &s); + private: static char separator_; std::string path_; diff --git a/mindspore/ccsrc/dataset/util/semaphore.cc b/mindspore/ccsrc/dataset/util/semaphore.cc new file mode 100644 index 00000000000..36ddf5511d9 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/semaphore.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/util/semaphore.h" +#include "dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +Status Semaphore::P() { + std::unique_lock lck(mutex_); + RETURN_IF_NOT_OK(wait_cond_.Wait(&lck, [this]() { return value_ > 0; })); + --value_; + return Status::OK(); +} +void Semaphore::V() { + std::unique_lock lck(mutex_); + ++value_; + wait_cond_.NotifyOne(); +} +int Semaphore::Peek() { + std::unique_lock lck(mutex_); + return value_; +} +Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); } +Status Semaphore::Deregister() { return (wait_cond_.Deregister()); } +void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); } + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/semaphore.h b/mindspore/ccsrc/dataset/util/semaphore.h new file mode 100644 index 00000000000..07b9e83e7fb --- /dev/null +++ b/mindspore/ccsrc/dataset/util/semaphore.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_SEMAPHORE_H_ +#define DATASET_UTIL_SEMAPHORE_H_ + +#include "dataset/util/cond_var.h" + +namespace mindspore { +namespace dataset { +class TaskGroup; + +/// \brief A counting semaphore. There are two external functions P and V. P decrements the internal count and will be +/// blocked if the count is 0 (zero). V increments the internal count and wake up one of the waiters. +class Semaphore { + public: + /// \brief Constructor + /// \param init Initial value of the internal counter. + explicit Semaphore(int init) : value_(init) {} + + virtual ~Semaphore() {} + /// \brief Decrement the internal counter. Will be blocked if the value is 0. + /// \return Error code. Can get interrupt. + Status P(); + /// \brief Increment the internal counter. Wakeup on of the watiers if any. + void V(); + /// \brief Peek the internal value + /// \return The internal value + int Peek(); + Status Register(TaskGroup *vg); + Status Deregister(); + void ResetIntrpState(); + + private: + int value_; + + std::mutex mutex_; + CondVar wait_cond_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_SEMAPHORE_H_ diff --git a/mindspore/ccsrc/dataset/util/slice.cc b/mindspore/ccsrc/dataset/util/slice.cc new file mode 100644 index 00000000000..f1798b4f44a --- /dev/null +++ b/mindspore/ccsrc/dataset/util/slice.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "dataset/util/slice.h" + +namespace mindspore { +namespace dataset { +WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset, size_t len) : ReadableSlice(src, offset, len) { + mutable_data_ = static_cast(src.mutable_data_) + offset; +} +WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset) + : WritableSlice(src, offset, src.GetSize() - offset) {} +Status WritableSlice::Copy(WritableSlice *dest, const ReadableSlice &src) { + RETURN_UNEXPECTED_IF_NULL(dest); + RETURN_UNEXPECTED_IF_NULL(dest->GetMutablePointer()); + if (dest->GetSize() <= 0) { + RETURN_STATUS_UNEXPECTED("Destination length is non-positive"); + } + auto err = memcpy_s(dest->GetMutablePointer(), dest->GetSize(), src.GetPointer(), src.GetSize()); + if (err) { + RETURN_STATUS_UNEXPECTED(std::to_string(err)); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/slice.h b/mindspore/ccsrc/dataset/util/slice.h new file mode 100644 index 00000000000..127df23cfab --- /dev/null +++ b/mindspore/ccsrc/dataset/util/slice.h @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_SLICE_H_ +#define DATASET_UTIL_SLICE_H_ + +#include +#include +#include +#include "./securec.h" +#include "dataset/util/allocator.h" +#include "dataset/util/status.h" +namespace mindspore { +namespace dataset { +/// \brief A ReadableSlice wraps a const pointer in memory and its size. +/// \see WritableSlice for a non-const version +/// +class ReadableSlice { + public: + ReadableSlice() : ptr_(nullptr), sz_(0) {} + ReadableSlice(const void *ptr, size_t sz) : ptr_(ptr), sz_(sz) {} + ReadableSlice(const ReadableSlice &src, off64_t offset, size_t len) { + ptr_ = static_cast(src.GetPointer()) + offset; + sz_ = len; + } + ReadableSlice(const ReadableSlice &src, off64_t offset) : ReadableSlice(src, offset, src.sz_ - offset) {} + ReadableSlice(const ReadableSlice &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + } + ReadableSlice &operator=(const ReadableSlice &lhs) { + if (this != &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + } + return *this; + } + ReadableSlice(ReadableSlice &&lhs) noexcept { + if (this != &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + lhs.ptr_ = nullptr; + lhs.sz_ = 0; + } + } + ReadableSlice &operator=(ReadableSlice &&lhs) noexcept { + if (this != &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + lhs.ptr_ = nullptr; + lhs.sz_ = 0; + } + return *this; + } + /// \brief Getter function + /// \return Const version of the pointer + const void *GetPointer() const { return ptr_; } + /// \brief Getter function + /// \return Size of the slice + size_t GetSize() const { return sz_; } + bool empty() const { return ptr_ == nullptr; } + + private: + const void *ptr_; + size_t sz_; +}; +/// \brief A WritableSlice inherits from ReadableSlice to allow +/// one to write to the address pointed to by the pointer. +/// +class WritableSlice : public ReadableSlice { + public: + friend class StorageContainer; + /// \brief Default constructor + WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {} + /// \brief This form of a constructor takes a pointer and its size. + WritableSlice(void *ptr, size_t sz) : ReadableSlice(ptr, sz), mutable_data_(ptr) {} + WritableSlice(const WritableSlice &src, off64_t offset, size_t len); + WritableSlice(const WritableSlice &src, off64_t offset); + WritableSlice(const WritableSlice &lhs) : ReadableSlice(lhs) { mutable_data_ = lhs.mutable_data_; } + WritableSlice &operator=(const WritableSlice &lhs) { + if (this != &lhs) { + mutable_data_ = lhs.mutable_data_; + ReadableSlice::operator=(lhs); + } + return *this; + } + WritableSlice(WritableSlice &&lhs) noexcept : ReadableSlice(std::move(lhs)) { + if (this != &lhs) { + mutable_data_ = lhs.mutable_data_; + lhs.mutable_data_ = nullptr; + } + } + WritableSlice &operator=(WritableSlice &&lhs) noexcept { + if (this != &lhs) { + mutable_data_ = lhs.mutable_data_; + lhs.mutable_data_ = nullptr; + ReadableSlice::operator=(std::move(lhs)); + } + return *this; + } + /// \brief Copy the content from one slice onto another. + static Status Copy(WritableSlice *dest, const ReadableSlice &src); + + private: + void *mutable_data_; + void *GetMutablePointer() { return mutable_data_; } +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_SLICE_H_ diff --git a/mindspore/ccsrc/dataset/util/storage_container.cc b/mindspore/ccsrc/dataset/util/storage_container.cc new file mode 100644 index 00000000000..96f5b45d0cc --- /dev/null +++ b/mindspore/ccsrc/dataset/util/storage_container.cc @@ -0,0 +1,164 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/util/storage_container.h" + +#include +#include +#include +#include +#include "common/utils.h" +#include "dataset/util/de_error.h" +#include "dataset/util/path.h" +#include "dataset/util/status.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +Status StorageContainer::Create() { + RETURN_IF_NOT_OK(BuddySpace::CreateBuddySpace(&bs_)); + RETURN_IF_NOT_OK(cont_.CreateFile(&fd_)); + is_open_ = true; + MS_LOG(INFO) << "Container " << cont_ << " created"; + return Status::OK(); +} + +Status StorageContainer::Open() noexcept { + std::lock_guard lck(mutex_); + // Check again + if (!is_open_) { + RETURN_IF_NOT_OK(cont_.OpenFile(&fd_)); + is_open_ = true; + } + return Status::OK(); +} + +Status StorageContainer::Close() noexcept { + if (is_open_) { + std::lock_guard lck(mutex_); + // Check again + if (is_open_) { + RETURN_IF_NOT_OK(cont_.CloseFile(fd_)); + is_open_ = false; + fd_ = -1; + } + } + return Status::OK(); +} + +Status StorageContainer::Read(WritableSlice *dest, off64_t offset) const noexcept { + DS_ASSERT(is_open_); + RETURN_UNEXPECTED_IF_NULL(dest); + auto sz = dest->GetSize(); +#if defined(_WIN32) || defined(_WIN64) + // Doesn't seem there is any pread64 on mingw. + // So we will do a seek and then a read under + // a protection of mutex. + std::lock_guard lck(mutex_); + auto seek_err = lseek(fd_, offset, SEEK_SET); + if (seek_err < 0) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + auto r_sz = read(fd_, dest->GetMutablePointer(), sz); +#else + auto r_sz = pread64(fd_, dest->GetMutablePointer(), sz, offset); +#endif + if (r_sz != sz) { + errno_t err = (r_sz == 0) ? EOF : errno; + RETURN_STATUS_UNEXPECTED(strerror(err)); + } + return Status::OK(); +} + +Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const noexcept { + DS_ASSERT(is_open_); + auto sz = dest.GetSize(); +#if defined(_WIN32) || defined(_WIN64) + // Doesn't seem there is any pwrite64 on mingw. + // So we will do a seek and then a read under + // a protection of mutex. + std::lock_guard lck(mutex_); + auto seek_err = lseek(fd_, offset, SEEK_SET); + if (seek_err < 0) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + auto r_sz = write(fd_, dest.GetPointer(), sz); +#else + auto r_sz = pwrite64(fd_, dest.GetPointer(), sz, offset); +#endif + if (r_sz != sz) { + errno_t err = (r_sz == 0) ? EOF : errno; + RETURN_STATUS_UNEXPECTED(strerror(err)); + } + return Status::OK(); +} + +Status StorageContainer::Insert(const std::vector &buf, off64_t *offset) noexcept { + size_t sz = 0; + for (auto &v : buf) { + sz += v.GetSize(); + } + if (sz == 0) { + RETURN_STATUS_UNEXPECTED("Unexpected 0 length"); + } + if (sz > bs_->GetMaxSize()) { + RETURN_STATUS_UNEXPECTED("Request size too big"); + } + BSpaceDescriptor bspd{0}; + addr_t addr = 0; + RETURN_IF_NOT_OK(bs_->Alloc(sz, &bspd, &addr)); + *offset = static_cast(addr); + // We will do piecewise copy of the data to disk. + for (auto &v : buf) { + RETURN_IF_NOT_OK(Write(v, addr)); + addr += v.GetSize(); + } + return Status::OK(); +} + +Status StorageContainer::Truncate() const noexcept { + if (is_open_) { + RETURN_IF_NOT_OK(cont_.TruncateFile(fd_)); + MS_LOG(INFO) << "Container " << cont_ << " truncated"; + } + return Status::OK(); +} + +StorageContainer::~StorageContainer() noexcept { + (void)Truncate(); + (void)Close(); +} + +std::ostream &operator<<(std::ostream &os, const StorageContainer &s) { + os << "File path : " << s.cont_ << "\n" << *(s.bs_.get()); + return os; +} + +Status StorageContainer::CreateStorageContainer(std::shared_ptr *out_sc, const std::string &path) { + Status rc; + auto sc = new (std::nothrow) StorageContainer(path); + if (sc == nullptr) { + return Status(StatusCode::kOutOfMemory); + } + rc = sc->Create(); + if (rc.IsOk()) { + (*out_sc).reset(sc); + } else { + delete sc; + } + return rc; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/storage_container.h b/mindspore/ccsrc/dataset/util/storage_container.h new file mode 100644 index 00000000000..07e41bd66a7 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/storage_container.h @@ -0,0 +1,79 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_STORAGE_CONTAINER_H_ +#define DATASET_UTIL_STORAGE_CONTAINER_H_ + +#include +#include +#include +#include +#include +#include +#include "dataset/util/system_pool.h" +#include "dataset/util/buddy.h" +#include "dataset/util/path.h" +#include "dataset/util/slice.h" +#include "dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class StorageManager; + +class StorageContainer { + public: + friend class StorageManager; + + ~StorageContainer() noexcept; + + StorageContainer(const StorageContainer &) = delete; + + StorageContainer &operator=(const StorageContainer &) = delete; + + friend std::ostream &operator<<(std::ostream &os, const StorageContainer &s); + + Status Open() noexcept; + + Status Close() noexcept; + + Status Insert(const std::vector &buf, off64_t *offset) noexcept; + + Status Write(const ReadableSlice &dest, off64_t offset) const noexcept; + + Status Read(WritableSlice *dest, off64_t offset) const noexcept; + + Status Truncate() const noexcept; + + bool IsOpen() const { return is_open_; } + + static Status CreateStorageContainer(std::shared_ptr *out_sc, const std::string &path); + + private: + mutable std::mutex mutex_; + Path cont_; + int fd_; + bool is_open_; + std::unique_ptr bs_; + + // Use the default value of BuddySpace + // which can map upto 4G of space. + explicit StorageContainer(const std::string &path) : cont_(path), fd_(-1), is_open_(false), bs_(nullptr) {} + + Status Create(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_STORAGE_CONTAINER_H_ diff --git a/mindspore/ccsrc/dataset/util/storage_manager.cc b/mindspore/ccsrc/dataset/util/storage_manager.cc new file mode 100644 index 00000000000..8b7a6044e93 --- /dev/null +++ b/mindspore/ccsrc/dataset/util/storage_manager.cc @@ -0,0 +1,167 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/util/storage_manager.h" + +#include +#include +#include +#include +#include "common/utils.h" +#include "dataset/util/path.h" +#include "dataset/util/services.h" +#include "dataset/util//de_error.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +std::string StorageManager::GetBaseName(const std::string &prefix, int32_t file_id) { + std::ostringstream oss; + oss << prefix << std::setfill('0') << std::setw(5) << file_id; + return oss.str(); +} + +std::string StorageManager::ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix) { + std::string base_name = GetBaseName(prefix, file_id); + return (base_name + "." + suffix); +} + +Status StorageManager::AddOneContainer() { + const std::string kPrefix = "IMG"; + const std::string kSuffix = "LB"; + Path container_name = root_ / ConstructFileName(kPrefix, file_id_, kSuffix); + std::shared_ptr sc; + RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, container_name.toString())); + containers_.push_back(sc); + file_id_++; + return Status::OK(); +} + +Status StorageManager::DoServiceStart() { + containers_.reserve(1000); + if (root_.IsDirectory()) { + RETURN_IF_NOT_OK(AddOneContainer()); + } else { + RETURN_STATUS_UNEXPECTED("Not a directory"); + } + return Status::OK(); +} + +Status StorageManager::Write(key_type *key, const std::vector &buf) { + RETURN_UNEXPECTED_IF_NULL(key); + size_t sz = 0; + for (auto &v : buf) { + sz += v.GetSize(); + } + if (sz == 0) { + RETURN_STATUS_UNEXPECTED("Unexpected 0 length"); + } + std::shared_ptr cont; + key_type out_key; + value_type out_value; + bool create_new_container = false; + do { + SharedLock lock_s(&rw_lock_); + size_t num_containers = containers_.size(); + if (create_new_container) { + // Upgrade to exclusvie lock. + lock_s.Upgrade(); + create_new_container = false; + // Check again if someone has already added a + // new container after we got the x lock + if (containers_.size() == num_containers) { + RETURN_IF_NOT_OK(AddOneContainer()); + } + // Refresh how many containers there are. + num_containers = containers_.size(); + // Downgrade back to shared lock + lock_s.Downgrade(); + } + if (num_containers == 0) { + RETURN_STATUS_UNEXPECTED("num_containers is zero"); + } + // Go to the last container to insert. + cont = containers_.at(num_containers - 1); + off64_t offset; + Status rc = cont->Insert(buf, &offset); + if (rc.IsNoSpace()) { + create_new_container = true; + } else if (rc.IsOk()) { + out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz)); + RETURN_IF_NOT_OK(index_.insert(out_value, &out_key)); + *key = out_key; + break; + } else { + return rc; + } + } while (true); + return Status::OK(); +} + +Status StorageManager::Read(StorageManager::key_type key, WritableSlice *dest, size_t *bytesRead) const { + RETURN_UNEXPECTED_IF_NULL(dest); + auto r = index_.Search(key); + if (r.second) { + auto &it = r.first; + value_type v = *it; + int container_inx = v.first; + off_t offset = v.second.first; + size_t sz = v.second.second; + if (dest->GetSize() < sz) { + std::string errMsg = "Destination buffer too small. Expect at least " + std::to_string(sz) + + " but length = " + std::to_string(dest->GetSize()); + RETURN_STATUS_UNEXPECTED(errMsg); + } + if (bytesRead != nullptr) { + *bytesRead = sz; + } + auto cont = containers_.at(container_inx); + RETURN_IF_NOT_OK(cont->Read(dest, offset)); + } else { + RETURN_STATUS_UNEXPECTED("Key not found"); + } + return Status::OK(); +} + +Status StorageManager::DoServiceStop() noexcept { + Status rc; + Status rc1; + for (auto const &p : containers_) { + // The destructor of StorageContainer is not called automatically until the use + // count drops to 0. But it is not always the case. We will do it ourselves. + rc = p.get()->Truncate(); + if (rc.IsError()) { + rc1 = rc; + } + } + containers_.clear(); + file_id_ = 0; + return rc1; +} + +StorageManager::StorageManager(const Path &root) : root_(root), file_id_(0), index_() {} + +StorageManager::~StorageManager() { (void)StorageManager::DoServiceStop(); } + +std::ostream &operator<<(std::ostream &os, const StorageManager &s) { + os << "Dumping all containers ..." + << "\n"; + for (auto const &p : s.containers_) { + os << *(p.get()); + } + return os; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/storage_manager.h b/mindspore/ccsrc/dataset/util/storage_manager.h new file mode 100644 index 00000000000..075ac713d2c --- /dev/null +++ b/mindspore/ccsrc/dataset/util/storage_manager.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_STORAGE_MANAGER_H_ +#define DATASET_UTIL_STORAGE_MANAGER_H_ + +#include +#include +#include +#include +#include +#include "dataset/util/allocator.h" +#include "dataset/util/auto_index.h" +#include "dataset/util/lock.h" +#include "dataset/util/memory_pool.h" +#include "dataset/util/path.h" +#include "dataset/util/service.h" +#include "dataset/util/slice.h" +#include "dataset/util/storage_container.h" + +using ListOfContainers = std::vector>; +namespace mindspore { +namespace dataset { +class StorageManager : public Service { + public: + using storage_index = AutoIndexObj>>; + using key_type = storage_index::key_type; + using value_type = storage_index::value_type; + + explicit StorageManager(const Path &); + + ~StorageManager() override; + + StorageManager(const StorageManager &) = delete; + + StorageManager &operator=(const StorageManager &) = delete; + + Status Write(key_type *out_key, const std::vector &buf); + + Status Read(key_type key, WritableSlice *dest, size_t *bytesRead) const; + + Status DoServiceStart() override; + + Status DoServiceStop() noexcept override; + + friend std::ostream &operator<<(std::ostream &os, const StorageManager &s); + + private: + Path root_; + ListOfContainers containers_; + int file_id_; + RWLock rw_lock_; + storage_index index_; + + std::string GetBaseName(const std::string &prefix, int32_t file_id); + + std::string ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix); + + Status AddOneContainer(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_STORAGE_MANAGER_H_ diff --git a/mindspore/ccsrc/dataset/util/system_pool.h b/mindspore/ccsrc/dataset/util/system_pool.h index bd15ad11ddf..286e30a6158 100644 --- a/mindspore/ccsrc/dataset/util/system_pool.h +++ b/mindspore/ccsrc/dataset/util/system_pool.h @@ -19,8 +19,10 @@ #include #include #include +#include #include #include "./securec.h" +#include "dataset/util/allocator.h" #include "dataset/util/memory_pool.h" namespace mindspore { @@ -61,6 +63,11 @@ class SystemPool : public MemoryPool { uint64_t get_max_size() const override { return std::numeric_limits::max(); } int PercentFree() const override { return 100; } + + template + static Allocator GetAllocator() { + return Allocator(std::make_shared()); + } }; } // namespace dataset } // namespace mindspore From ef13a4b6fb7fe8705e09e0851687b67c20ca3100 Mon Sep 17 00:00:00 2001 From: Wei Luning Date: Thu, 25 Jun 2020 00:45:10 +0800 Subject: [PATCH 06/27] adjust cse code when op has side effect. --- mindspore/ccsrc/optimizer/cse.cc | 86 ++++++++++++------- mindspore/ccsrc/optimizer/cse.h | 2 +- .../pass/common_subexpression_elimination.cc | 2 +- .../pass/common_subexpression_elimination.h | 2 +- mindspore/ccsrc/pybind_api/export_flags.cc | 1 + mindspore/ccsrc/pybind_api/export_flags.h | 2 +- mindspore/nn/wrap/cell_wrapper.py | 4 +- mindspore/ops/operations/debug_ops.py | 2 +- tests/ut/python/utils/test_serialize.py | 1 + 9 files changed, 66 insertions(+), 36 deletions(-) diff --git a/mindspore/ccsrc/optimizer/cse.cc b/mindspore/ccsrc/optimizer/cse.cc index 1af08ea3e12..0b675cca721 100644 --- a/mindspore/ccsrc/optimizer/cse.cc +++ b/mindspore/ccsrc/optimizer/cse.cc @@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { return changed; } - +// The op like print, summary, or the op do not has true output, and always as a depend node input. +static bool HasSideEffect(const AnfNodePtr &node) { + auto prim = GetCNodePrimitive(node); + if (prim == nullptr) { + return false; + } + auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT); + if (side_effect_v != nullptr && side_effect_v->isa()) { + return GetValue(side_effect_v); + } + return false; +} +// If true do not merge the node. bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const { bool has_random_effect = false; auto prim_main = GetCNodePrimitive(main); auto prim_node = GetCNodePrimitive(node); - if (prim_main == prim_node) { - return false; - } + // if has random effect, when generate by different op (not same object), do not merge. if (prim_main != nullptr) { + if (prim_main == prim_node) { + return false; + } auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT); if (effect_val != nullptr && effect_val->isa()) { has_random_effect = GetValue(effect_val); @@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons return has_random_effect; } -bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { +bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const { MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(node); - bool replace = false; if (main->isa() && node->isa()) { auto main_value = GetValueNode(main); auto node_value = GetValueNode(node); - replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); + return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); } else if (main->isa() && node->isa()) { auto c_main = main->cast(); auto c_node = node->cast(); + // When appsame is true, check if has side effect, do not merge. + if (check_side_effect && HasSideEffect(main)) { + return false; + } const auto &inp1 = c_main->inputs(); const auto &inp2 = c_node->inputs(); - if (inp1.size() == inp2.size()) { - bool appsame = true; - for (size_t j = 0; j < inp1.size(); j++) { - MS_EXCEPTION_IF_NULL(inp1[j]); - MS_EXCEPTION_IF_NULL(inp2[j]); - if (!(*inp1[j] == *inp2[j])) { - // Handle the case of two different Tensor, but with the same value - if (IsValueNode(inp1[j]) && IsValueNode(inp2[j])) { - auto tensor1 = GetValueNode(inp1[j]); - auto tensor2 = GetValueNode(inp2[j]); - if (tensor1->ValueEqual(*tensor2)) { - continue; - } - } - appsame = false; - break; - } - } - if (CheckRandomEffect(c_main, c_node)) { - appsame = false; - } - replace = appsame; + if (inp1.size() != inp2.size()) { + return false; } + for (size_t j = 0; j < inp1.size(); j++) { + auto inp1_j = inp1[j]; + auto inp2_j = inp2[j]; + MS_EXCEPTION_IF_NULL(inp1_j); + MS_EXCEPTION_IF_NULL(inp2_j); + if (!(*inp1_j == *inp2_j)) { + // Handle the case of two different Tensor, but with the same value + if (IsValueNode(inp1_j) && IsValueNode(inp2_j)) { + auto tensor1 = GetValueNode(inp1_j); + auto tensor2 = GetValueNode(inp2_j); + if (tensor1->ValueEqual(*tensor2)) { + continue; + } + } else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) { + // When the same side effect node as another two nodes' inputs, we still merge the node. + // Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the + // node. + if (CheckReplace(inp1_j, inp2_j, false)) { + continue; + } + } + return false; + } + } + // When appsame is true, check if has random effect do not merge + if (CheckRandomEffect(c_main, c_node)) { + return false; + } + return true; } - return replace; + // a parameter node. + return false; } bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector &order_group, diff --git a/mindspore/ccsrc/optimizer/cse.h b/mindspore/ccsrc/optimizer/cse.h index fd90f61eebc..57163cc5c9d 100644 --- a/mindspore/ccsrc/optimizer/cse.h +++ b/mindspore/ccsrc/optimizer/cse.h @@ -41,7 +41,7 @@ class CSE { return chg && report_changes_; } - virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const; + virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const; virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const; diff --git a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc index 9af50eac330..297a167aa8e 100644 --- a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc +++ b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc @@ -35,7 +35,7 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) { } } // namespace -bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { +bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const { MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(node); diff --git a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h index 8e1768ea99c..18f433ab955 100644 --- a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h +++ b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h @@ -31,7 +31,7 @@ class BackendCSE : public CSE { public: BackendCSE() = default; ~BackendCSE() override = default; - bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const override; + bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pybind_api/export_flags.cc b/mindspore/ccsrc/pybind_api/export_flags.cc index 83392784f3e..253e271e525 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.cc +++ b/mindspore/ccsrc/pybind_api/export_flags.cc @@ -33,5 +33,6 @@ const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll"; const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; +const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect"; } // namespace mindspore diff --git a/mindspore/ccsrc/pybind_api/export_flags.h b/mindspore/ccsrc/pybind_api/export_flags.h index 74c27ff35d2..6ea584e66d8 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.h +++ b/mindspore/ccsrc/pybind_api/export_flags.h @@ -34,7 +34,7 @@ extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[]; extern const char GRAPH_FLAG_HAS_EFFECT[]; extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; extern const char GRAPH_FLAG_RANDOM_EFFECT[]; - +extern const char GRAPH_FLAG_SIDE_EFFECT[]; } // namespace mindspore #endif // PYBIND_API_EXPORT_FLAGS_H_ diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index f0d920f51fa..9e3d00cc959 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -220,7 +220,9 @@ class DataWrapper(Cell): def __init__(self, network, dataset_types, dataset_shapes, queue_name): super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) - + # Also copy the flag in `network` construct + flags = getattr(network.__class__.construct, "_mindspore_flags", {}) + self.add_flags(**flags) self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) self.network = network diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index bafc72897e6..c62b3f1ab8b 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -334,7 +334,7 @@ class Print(PrimitiveWithInfer): @prim_attr_register def __init__(self): - pass + self.add_prim_attr("_side_effect", True) def __call__(self, *args): for arg in args: diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index c5b4586566e..035ea878459 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -336,6 +336,7 @@ class PrintNet(nn.Cell): def construct(self, int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_, scale1, scale2): self.print('============tensor int8:==============', int8) + self.print('============tensor int8:==============', int8) self.print('============tensor uint8:==============', uint8) self.print('============tensor int16:==============', int16) self.print('============tensor uint16:==============', uint16) From aabec55c79139580f4be7b7f274dd03e268fd6c2 Mon Sep 17 00:00:00 2001 From: Giancarlo Colmenares Date: Tue, 23 Jun 2020 09:25:47 -0400 Subject: [PATCH 07/27] Removing TransformFuncType --- mindspore/ccsrc/ir/optimizer_caller.h | 12 +- mindspore/ccsrc/optimizer/irpass.cc | 166 ++++++++++-------- .../optimizer/irpass/arithmetic_simplify.h | 59 +++---- .../ccsrc/optimizer/irpass/cast_eliminate.h | 6 +- .../optimizer/irpass/env_item_eliminate.h | 30 ++-- .../optimizer/irpass/incorporate_getitem.h | 27 +-- .../optimizer/irpass/item_tuple_eliminate.h | 33 ++-- .../ccsrc/optimizer/irpass/ref_eliminate.h | 4 +- .../optimizer/irpass/reshape_eliminate.h | 11 +- .../optimizer/irpass/special_op_eliminate.h | 36 ++-- mindspore/ccsrc/optimizer/opt.cc | 19 +- mindspore/ccsrc/optimizer/opt.h | 24 +-- tests/ut/cpp/optimizer/opt_test.cc | 8 +- 13 files changed, 227 insertions(+), 208 deletions(-) diff --git a/mindspore/ccsrc/ir/optimizer_caller.h b/mindspore/ccsrc/ir/optimizer_caller.h index bd304541473..036f4ab5109 100644 --- a/mindspore/ccsrc/ir/optimizer_caller.h +++ b/mindspore/ccsrc/ir/optimizer_caller.h @@ -17,13 +17,23 @@ #ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ #define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ +#include + #include "ir/anf.h" -#include "optimizer/opt.h" namespace mindspore { +namespace opt { +class Optimizer; +using OptimizerPtr = std::shared_ptr; +using OptimizerWeakPtr = std::weak_ptr; + +using PredicateFuncType = std::function; +} // namespace opt + class OptimizerCaller { public: virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } }; +using OptimizerCallerPtr = std::shared_ptr; } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index 0033e386d8a..0996abee2c2 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -14,140 +14,154 @@ * limitations under the License. */ -#include "optimizer/irpass.h" - #include -#include "optimizer/irpass/symbol_resolver.h" +#include "optimizer/irpass.h" #include "optimizer/irpass/arithmetic_simplify.h" -#include "optimizer/irpass/special_op_eliminate.h" -#include "optimizer/irpass/item_tuple_eliminate.h" -#include "optimizer/irpass/env_item_eliminate.h" -#include "optimizer/irpass/tile_eliminate.h" -#include "optimizer/irpass/cast_eliminate.h" -#include "optimizer/irpass/reshape_eliminate.h" -#include "optimizer/irpass/transpose_eliminate.h" -#include "optimizer/irpass/reduce_eliminate.h" -#include "optimizer/irpass/partial_eliminate.h" -#include "optimizer/irpass/ref_eliminate.h" -#include "optimizer/irpass/merge_addn.h" #include "optimizer/irpass/branch_culling.h" -#include "optimizer/irpass/gradient_eliminate.h" -#include "optimizer/irpass/minmax_grad.h" -#include "optimizer/irpass/inline.h" +#include "optimizer/irpass/cast_eliminate.h" #include "optimizer/irpass/convert.h" -#include "optimizer/irpass/specialize_transform.h" -#include "optimizer/irpass/incorporate_getitem.h" -#include "optimizer/irpass/incorporate_call.h" +#include "optimizer/irpass/env_item_eliminate.h" #include "optimizer/irpass/grad_var_prepare.h" -#include "optimizer/irpass/param_replace.h" +#include "optimizer/irpass/gradient_eliminate.h" +#include "optimizer/irpass/inline.h" +#include "optimizer/irpass/incorporate_call.h" +#include "optimizer/irpass/incorporate_getitem.h" +#include "optimizer/irpass/item_tuple_eliminate.h" #include "optimizer/irpass/mark_interface_fusion.h" +#include "optimizer/irpass/merge_addn.h" +#include "optimizer/irpass/minmax_grad.h" +#include "optimizer/irpass/param_replace.h" +#include "optimizer/irpass/partial_eliminate.h" +#include "optimizer/irpass/reduce_eliminate.h" +#include "optimizer/irpass/ref_eliminate.h" +#include "optimizer/irpass/reshape_eliminate.h" +#include "optimizer/irpass/special_op_eliminate.h" +#include "optimizer/irpass/specialize_transform.h" +#include "optimizer/irpass/symbol_resolver.h" +#include "optimizer/irpass/tile_eliminate.h" +#include "optimizer/irpass/transpose_eliminate.h" #include "optimizer/opt.h" namespace mindspore { namespace opt { namespace irpass { OptimizeIRPassLib::OptimizeIRPassLib() { - arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", + arithmetic_simplify_ = MakeSubstitution(std::make_shared(), "arithmetic_simplify", {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); - arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul}); + arithmetic_simplify2_ = + MakeSubstitution(std::make_shared(), "arithmetic_simplify2", {prim::kPrimMul}); special_op_eliminate_ = - MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", + MakeSubstitution(std::make_shared(), "special_op_eliminate", {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); - zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike); - adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN); + zero_like_fill_zero_ = + MakeSubstitution(std::make_shared(), "zero_like_fill_zero", prim::kPrimZerosLike); + adjust_all_reduce_mul_add_ = + MakeSubstitution(std::make_shared(), "adjust_all_reduce_mul_add", prim::kPrimAddN); // ops eliminate - item_tuple_eliminate_ = - MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); - tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile); - cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast); - reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape); - transpose_eliminate_ = MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose); + item_tuple_eliminate_ = MakeSubstitution(std::make_shared(), "item_tuple_eliminate", + {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); + tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); + cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); + reshape_eliminate_ = MakeSubstitution(std::make_shared(), "reshape_eliminate", prim::kPrimReshape); + transpose_eliminate_ = + MakeSubstitution(std::make_shared(), "transpose_eliminate", prim::kPrimTranspose); reduce_eliminate_ = MakeSubstitution( - ReduceOneEliminater(), "reduce_eliminate", + std::make_shared(), "reduce_eliminate", {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); - partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup); - same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); - check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); - reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode); - depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend); + partial_eliminate_ = MakeSubstitution(std::make_shared(), "partial_eliminate", IsCNodeDup); + same_eliminate_ = MakeSubstitution(std::make_shared(), "same_eliminate", prim::kPrimSameTypeShape); + check_bprop_eliminate_ = + MakeSubstitution(std::make_shared(), "check_bprop_eliminate", prim::kPrimCheckBprop); + reset_defer_inline_ = + MakeSubstitution(std::make_shared(), "reset_defer_inline", IsValueNode); + depend_value_elim_ = MakeSubstitution(std::make_shared(), "depend_value_elim", prim::kPrimDepend); // Env Item Eliminate - env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); - new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem); + env_get_item_eliminate_ = + MakeSubstitution(std::make_shared(), "env_get_item_eliminate", prim::kPrimEnvGetItem); + new_env_get_item_ = MakeSubstitution(std::make_shared(), "new_env_get_item", prim::kPrimEnvGetItem); incorporate_env_getitem_ = - MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem); - incorporate_env_getitem_switch_ = - MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); + MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); + incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared(), + "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); // Ref eliminate - make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); - get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate", + make_ref_eliminate_ = + MakeSubstitution(std::make_shared(), "make_ref_eliminate", prim::kPrimMakeRef); + get_ref_param_eliminate_ = MakeSubstitution(std::make_shared(), "get_ref_param_eliminate", {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); - get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", + get_make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); - replace_refkey_by_param_ = - MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode, opt::FORCE_RENORM); - replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); + replace_refkey_by_param_ = MakeSubstitution(std::make_shared(), "replace_refkey_by_param", + IsValueNode, opt::FORCE_RENORM); + replace_old_param_ = MakeSubstitution(std::make_shared(), "replace_old_param", IsParam); // Gradient transforms - expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); - minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); + expand_jprim_ = MakeSubstitution(std::make_shared(), "expand_jprim", prim::kPrimJ); + minmaximum_grad_ = MakeSubstitution(std::make_shared(), "minmaximum_grad", prim::kPrimTupleGetItem); // branch culling - switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch); - float_tuple_getitem_switch_ = - MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem); + switch_simplify_ = MakeSubstitution(std::make_shared(), "switch_simplify", prim::kPrimSwitch); + float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared(), + "float_tuple_getitem_switch", prim::kPrimTupleGetItem); float_env_getitem_switch_ = - MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem); - convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup); + MakeSubstitution(std::make_shared(), "float_env_getitem_switch", prim::kPrimEnvGetItem); + convert_switch_replacement_ = + MakeSubstitution(std::make_shared(), "convert_switch_replacement", IsCNodeDup); // Addn - merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN); - addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN); + merge_addn_ = MakeSubstitution(std::make_shared(), "merge_addn", prim::kPrimAddN); + addn_zero_filter_ = MakeSubstitution(std::make_shared(), "addn_zero_filter", prim::kPrimAddN); // inline - inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph); - replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode); - specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph); + inline_ = MakeSubstitution(std::make_shared(), "inline", IsCNodeGraph); + replace_applicator_ = + MakeSubstitution(std::make_shared(), "replace_applicator", IsValueNode); + specialize_transform_ = + MakeSubstitution(std::make_shared(), "specialize_transform", IsCNodeGraph); // Incorporation incorporate_getitem_set_ = - MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); - incorporate_getitem_from_param_ = - MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel); - incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); - incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); + MakeSubstitution(std::make_shared(), "incorporate_getitem_set", prim::kPrimTupleGetItem); + incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared(), + "incorporate_getitem_from_param", IsCNodeGraphKernel); + incorporate_call_ = MakeSubstitution(std::make_shared(), "incorporate_call", IsCNodeDup); + incorporate_call_switch_ = + MakeSubstitution(std::make_shared(), "incorporate_call_switch", IsCNodeDup); // Virtual Dataset - virtual_dataset_eliminate_ = - MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset); + virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared(), + "virtual_dataset_eliminate", prim::kPrimVirtualDataset); // Convert - print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); + print_tuple_wrapper_ = + MakeSubstitution(std::make_shared(), "print_tuple_wrapper", prim::kPrimPrint); // Unused parameter eliminate unused_parameter_eliminate_ = - MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel); - unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel); + MakeSubstitution(std::make_shared(), "unused_parameter_eliminate", IsCNodeGraphKernel); + unused_output_eliminate_ = + MakeSubstitution(std::make_shared(), "unused_output_eliminate", IsCNodeGraphKernel); // AddN eliminate - addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel); + addn_eliminate_ = MakeSubstitution(std::make_shared(), "addn_eliminate", IsCNodeGraphKernel); // Mark interface fusion - mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect); + mark_interface_fusion_ = + MakeSubstitution(std::make_shared(), "mark_interface_fusion", prim::kPrimSelect); } ResolveIRPassLib::ResolveIRPassLib() { - resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve); - resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr); + resolver_resolve_ = MakeSubstitution(std::make_shared(), "resolver_resolve", prim::kPrimResolve); + resolver_getattr_ = MakeSubstitution(std::make_shared(), "resolver_getattr", prim::kPrimGetAttr); } InferenceOptPrepareLib::InferenceOptPrepareLib() { - grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode); + grad_var_prepare_ = MakeSubstitution(std::make_shared(), "grad_var_prepare", IsCNode); } } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h index 270db8305f6..a26b81e9529 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h @@ -17,15 +17,16 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ -#include -#include #include +#include +#include -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "optimizer/irpass/prim_eliminate.h" +#include "ir/optimizer_caller.h" #include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/irpass/prim_eliminate.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { @@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor { FuncGraphPtr all_reduce_fg_{nullptr}; }; -class ArithmeticSimplify { +class ArithmeticSimplify : public OptimizerCaller { public: ArithmeticSimplify() - : multiply_by_zero_or_one_(), - tensor_multiply_by_one_(), - add_by_zero_(), - tensor_add_by_zero_(), - identity_(prim::kPrimIdentity), - opt_update_zero_tensor_(), - constant_duplicate_mul_(), - power_one_() { + : multiply_by_zero_or_one_(std::make_shared()), + tensor_multiply_by_one_(std::make_shared()), + add_by_zero_(std::make_shared()), + tensor_add_by_zero_(std::make_shared()), + identity_(std::make_shared(prim::kPrimIdentity)), + opt_update_zero_tensor_(std::make_shared()), + constant_duplicate_mul_(std::make_shared()), + power_one_(std::make_shared()) { eliminaters_.emplace_back(multiply_by_zero_or_one_); eliminaters_.emplace_back(tensor_multiply_by_one_); eliminaters_.emplace_back(add_by_zero_); @@ -761,10 +762,10 @@ class ArithmeticSimplify { } ~ArithmeticSimplify() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -773,15 +774,9 @@ class ArithmeticSimplify { } private: - MultiplyByZeroOrOne multiply_by_zero_or_one_; - TensorMultiplyByOne tensor_multiply_by_one_; - AddByZero add_by_zero_; - TensorAddByZero tensor_add_by_zero_; - PrimEliminater identity_; - OptUpdateZeroTensor opt_update_zero_tensor_; - ConstantDuplicateMul constant_duplicate_mul_; - PowerOneEliminate power_one_; - std::vector eliminaters_{}; + OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_, + opt_update_zero_tensor_, constant_duplicate_mul_, power_one_; + std::vector eliminaters_{}; }; // Arithmetic Simplifications should be done after step_parallel. @@ -789,15 +784,17 @@ class ArithmeticSimplify { // with shape(weight), but after step_parallel, shape of weight may be changed, so the // shape of the constant tensor should also be changed. So this pass is seperated from // ArithmeticSimplify and deferred until step_parallel. -class ArithmeticSimplify2 { +class ArithmeticSimplify2 : public OptimizerCaller { public: - ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); } + ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared()) { + eliminaters_.emplace_back(tensor_multiply_by_zero_); + } ~ArithmeticSimplify2() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -806,8 +803,8 @@ class ArithmeticSimplify2 { } private: - TensorMultiplyByZero tensor_multiply_by_zero_; - std::vector eliminaters_{}; + OptimizerCallerPtr tensor_multiply_by_zero_; + std::vector eliminaters_{}; }; } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h b/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h index 734d88cb10f..d98d0b677b3 100644 --- a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h @@ -17,9 +17,9 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ +#include "ir/visitor.h" #include "optimizer/irpass.h" #include "optimizer/optimizer.h" -#include "ir/visitor.h" namespace mindspore { namespace opt { @@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor { AnfNodePtr x_{nullptr}, t_{nullptr}; }; -class CastEliminater { +class CastEliminater : public OptimizerCaller { public: CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} ~CastEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { auto new_node = cast_same_type_eliminater_(optimizer, node); if (new_node != nullptr) { return new_node; diff --git a/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h index 0f59c69fef8..3f100dcaec3 100644 --- a/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h @@ -17,18 +17,19 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ -#include -#include #include -#include #include +#include +#include +#include -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" #include "utils/symbolic.h" namespace mindspore { @@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor { bool is_match_{false}; }; -class EnvGetItemEliminater { +class EnvGetItemEliminater : public OptimizerCaller { public: - EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() { + EnvGetItemEliminater() + : new_env_get_item_(std::make_shared()), + add_env_get_item_(std::make_shared()), + env_get_set_item_(std::make_shared()) { eliminaters_.emplace_back(new_env_get_item_); eliminaters_.emplace_back(add_env_get_item_); eliminaters_.emplace_back(env_get_set_item_); } ~EnvGetItemEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -246,10 +250,8 @@ class EnvGetItemEliminater { } private: - NewEnvGetItem new_env_get_item_; - AddEnvGetItem add_env_get_item_; - EnvGetSetItem env_get_set_item_; - std::vector eliminaters_{}; + OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_; + std::vector eliminaters_{}; }; // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} diff --git a/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h index 5afee45e95f..b6c8fb0e18e 100644 --- a/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h @@ -17,18 +17,20 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ -#include #include -#include #include +#include #include +#include -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" + namespace mindspore { namespace opt { namespace irpass { @@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor { internal::GetitemTransform getitem_transform_; }; -class IncorporateGetitemSet { +class IncorporateGetitemSet : public OptimizerCaller { public: - IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() { + IncorporateGetitemSet() + : incorporate_getitem_(std::make_shared()), + incorporate_getitem_switch_(std::make_shared()) { eliminaters_.emplace_back(incorporate_getitem_); eliminaters_.emplace_back(incorporate_getitem_switch_); } ~IncorporateGetitemSet() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -403,9 +407,8 @@ class IncorporateGetitemSet { } private: - IncorporateGetitem incorporate_getitem_; - IncorporateGetitemSwitch incorporate_getitem_switch_; - std::vector eliminaters_{}; + OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_; + std::vector eliminaters_{}; }; } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h b/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h index 21cdff51ad0..202951a2541 100644 --- a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h @@ -17,13 +17,15 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ -#include #include +#include +#include -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" +#include "ir/optimizer_caller.h" #include "ir/visitor.h" #include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { @@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor { AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; }; -class ItemTupleEliminater { +class ItemTupleEliminater : public OptimizerCaller { public: ItemTupleEliminater() - : get_item_eliminater_(), - get_item_const_eliminater_(), - set_item_eliminater_(), - get_set_item_eliminater_(), - get_item_depend_reorder_() { + : get_item_eliminater_(std::make_shared()), + get_item_const_eliminater_(std::make_shared()), + set_item_eliminater_(std::make_shared()), + get_set_item_eliminater_(std::make_shared()), + get_item_depend_reorder_(std::make_shared()) { eliminaters_.emplace_back(get_item_eliminater_); eliminaters_.emplace_back(get_item_const_eliminater_); eliminaters_.emplace_back(set_item_eliminater_); @@ -277,10 +279,10 @@ class ItemTupleEliminater { } ~ItemTupleEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -289,12 +291,9 @@ class ItemTupleEliminater { } private: - GetitemEliminater get_item_eliminater_; - GetitemConstEliminater get_item_const_eliminater_; - SetitemEliminater set_item_eliminater_; - GetSetitemEliminater get_set_item_eliminater_; - GetitemDependReorder get_item_depend_reorder_; - std::vector eliminaters_{}; + OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_, + get_item_depend_reorder_; + std::vector eliminaters_{}; }; } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h index 41f379221c6..6d81b401c3c 100644 --- a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h @@ -19,9 +19,9 @@ #include -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" #include "ir/pattern_matcher.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { diff --git a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h b/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h index fb43f6ffd8a..cafc8b796c4 100644 --- a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h @@ -19,11 +19,12 @@ #include +#include "ir/func_graph.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "operator/ops.h" #include "optimizer/irpass.h" #include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "ir/func_graph.h" -#include "operator/ops.h" #include "pipeline/static_analysis/dshape.h" namespace mindspore { @@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor { AnfNodePtr x_{nullptr}, shape_{nullptr}; }; -class ReshapeEliminater { +class ReshapeEliminater : public OptimizerCaller { public: ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} ~ReshapeEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { auto new_node = reshape_same_shape_eliminater_(optimizer, node); if (new_node != nullptr) { return new_node; diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h index dcba80431ad..b6a4e1c8523 100644 --- a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h @@ -18,31 +18,31 @@ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ #include -#include -#include #include +#include +#include -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" #include "ir/optimizer_caller.h" -#include "optimizer/irpass/prim_eliminate.h" +#include "ir/pattern_matcher.h" #include "ir/visitor.h" #include "operator/ops.h" -#include "ir/pattern_matcher.h" +#include "optimizer/irpass.h" +#include "optimizer/irpass/prim_eliminate.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { namespace irpass { -class SpecialOpEliminater { +class SpecialOpEliminater : public OptimizerCaller { public: SpecialOpEliminater() - : insert_gradient_of_(prim::kPrimInsertGradientOf), - stop_gradient_(prim::kPrimStopGradient), - hook_backward_(prim::kPrimHookBackward), - print_shape_type_(prim::kPrimPrintShapeType), - get_ref_value_(prim::kPrimGetRefValue), - mirror_(prim::kPrimMirror), - virtual_div_(prim::kPrimVirtualDiv) { + : insert_gradient_of_(std::make_shared(prim::kPrimInsertGradientOf)), + stop_gradient_(std::make_shared(prim::kPrimStopGradient)), + hook_backward_(std::make_shared(prim::kPrimHookBackward)), + print_shape_type_(std::make_shared(prim::kPrimPrintShapeType)), + get_ref_value_(std::make_shared(prim::kPrimGetRefValue)), + mirror_(std::make_shared(prim::kPrimMirror)), + virtual_div_(std::make_shared(prim::kPrimVirtualDiv)) { eliminaters_.emplace_back(insert_gradient_of_); eliminaters_.emplace_back(stop_gradient_); eliminaters_.emplace_back(hook_backward_); @@ -53,10 +53,10 @@ class SpecialOpEliminater { } ~SpecialOpEliminater() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { - new_node = eliminater(optimizer, node); + new_node = (*eliminater)(optimizer, node); if (new_node != nullptr) { return new_node; } @@ -65,9 +65,9 @@ class SpecialOpEliminater { } private: - PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, + OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, virtual_div_; - std::vector eliminaters_{}; + std::vector eliminaters_{}; }; // {PrimVirtualDataset, X} -> X diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index 82fbcc2036b..4c2e85157f0 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -16,28 +16,27 @@ #include "optimizer/opt.h" +#include +#include #include #include -#include -#include #include "ir/anf.h" #include "ir/manager.h" -#include "utils/ordered_set.h" - -#include "utils/log_adapter.h" #include "optimizer/optimizer.h" +#include "utils/log_adapter.h" +#include "utils/ordered_set.h" namespace mindspore { /* namespace to support opt */ namespace opt { -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, const RenormAction &renorm_action) { auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; return std::make_shared(transform, name, fn, renorm_action); } -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const std::vector &prims, const RenormAction &renorm_action) { auto fn = [prims](const AnfNodePtr &node) -> bool { if (!node->isa()) { @@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std:: return std::make_shared(transform, name, fn, renorm_action); } -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, const RenormAction &renorm_action) { return std::make_shared(transform, name, predicate, renorm_action); } -AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { +AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { #ifdef ENABLE_PROFILE double t = GetTime(); #endif - AnfNodePtr result = transform_(optimizer, node); + AnfNodePtr result = (*transform_)(optimizer, node); #ifdef ENABLE_PROFILE if (optimizer != nullptr) { auto time = GetTime(); diff --git a/mindspore/ccsrc/optimizer/opt.h b/mindspore/ccsrc/optimizer/opt.h index fb0bdc58be9..6601d969d28 100644 --- a/mindspore/ccsrc/optimizer/opt.h +++ b/mindspore/ccsrc/optimizer/opt.h @@ -17,24 +17,18 @@ #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ #define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ -#include -#include #include +#include +#include #include "ir/anf.h" #include "ir/func_graph.h" +#include "ir/optimizer_caller.h" #include "operator/ops.h" namespace mindspore { /* namespace to support opt */ namespace opt { -class Optimizer; - -using OptimizerPtr = std::shared_ptr; -using OptimizerWeakPtr = std::weak_ptr; - -using PredicateFuncType = std::function; -using TransformFuncType = std::function; // Define the interaction mode between an Optimize pass and Renormalize pass // FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed @@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM }; class Substitution { public: - TransformFuncType transform_{nullptr}; + OptimizerCallerPtr transform_; std::string name_; PredicateFuncType predicate_{nullptr}; // an enum to mark this Substitution relation to renormalize pass RenormAction renorm_action_; - Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, + Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, const RenormAction &renorm_action) : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} ~Substitution() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node); }; using SubstitutionPtr = std::shared_ptr; -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, const RenormAction &action_renorm = CHECK_RENORM); -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const std::vector &prims, const RenormAction &action_renorm = CHECK_RENORM); -SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); class SubstitutionList { diff --git a/tests/ut/cpp/optimizer/opt_test.cc b/tests/ut/cpp/optimizer/opt_test.cc index 05e7e6b9788..2428d0dddb3 100644 --- a/tests/ut/cpp/optimizer/opt_test.cc +++ b/tests/ut/cpp/optimizer/opt_test.cc @@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common { }; void SetUp() { - elim_Z = MakeSubstitution(irpass::AddByZero(), "elim_Z", prim::kPrimScalarAdd); - elim_R = MakeSubstitution(irpass::PrimEliminater(R), "elim_R", R); - idempotent_P = MakeSubstitution(IdempotentEliminater(), "idempotent_P", P); - Qct_to_P = MakeSubstitution(QctToP(), "Qct_to_P", Q); + elim_Z = MakeSubstitution(std::make_shared(), "elim_Z", prim::kPrimScalarAdd); + elim_R = MakeSubstitution(std::make_shared(R), "elim_R", R); + idempotent_P = MakeSubstitution(std::make_shared(), "idempotent_P", P); + Qct_to_P = MakeSubstitution(std::make_shared(), "Qct_to_P", Q); } bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) { From 3d1ecaaeb58f18bdfebd56248f0d69c250cb1091 Mon Sep 17 00:00:00 2001 From: Danish Farid Date: Thu, 25 Jun 2020 03:19:14 -0400 Subject: [PATCH 08/27] updated UT test for Python (3) AugOps with BBox - MD5 checks + imrpv comments --- .../random_crop_and_resize_with_bbox_op.cc | 3 +- .../kernels/image/random_crop_with_bbox_op.cc | 3 +- .../random_vertical_flip_with_bbox_op.cc | 3 +- .../random_crop_with_bbox_01_c_result.npz | Bin 0 -> 1654 bytes ...dom_resized_crop_with_bbox_01_c_result.npz | Bin 0 -> 1654 bytes ...om_vertical_flip_with_bbox_01_c_result.npz | Bin 0 -> 1654 bytes .../test_random_crop_and_resize_with_bbox.py | 216 +++++------------- .../dataset/test_random_crop_with_bbox.py | 209 +++++------------ .../test_random_vertical_flip_with_bbox.py | 95 ++++++-- tests/ut/python/dataset/util.py | 31 +-- 10 files changed, 210 insertions(+), 350 deletions(-) create mode 100644 tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz create mode 100644 tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz create mode 100644 tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc index b820779ed1a..fbaf2c9326d 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc @@ -30,8 +30,7 @@ Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow BOUNDING_BOX_CHECK(input); CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal"); - (*output).push_back(nullptr); // init memory for return vector - (*output).push_back(nullptr); + output->resize(2); (*output)[1] = std::move(input[1]); // move boxes over to output size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc index 2be37f1da36..c873307afdd 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc @@ -36,8 +36,7 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) int32_t padded_image_h; int32_t padded_image_w; - (*output).push_back(nullptr); - (*output).push_back(nullptr); + output->resize(2); (*output)[1] = std::move(input[1]); // since some boxes may be removed bool crop_further = true; // Whether further cropping will be required or not, true unless required size matches diff --git a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc index c6aa8450a8d..ffea851eac1 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc @@ -45,8 +45,7 @@ Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow * RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y)); } - (*output).push_back(nullptr); - (*output).push_back(nullptr); + output->resize(2); (*output)[1] = std::move(input[1]); return VerticalFlip(input[0], &(*output)[0]); diff --git a/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz new file mode 100644 index 0000000000000000000000000000000000000000..0c220fd09d2f82888b93437e370325ab758da34d GIT binary patch literal 1654 zcmbW&dr(wW90%}w7Z6wzP*-^+uClpHZVRIF5=8>hys<0?$o0Wv9~YO$K7Mx(;tJ@Z zJS3nXKtv^JYMRoaQmo*gv@x8VCJGuSGwrV?njTixRKwWVJ%{UU`lEmP-PwER{_gpl z@9xf-drVwFpo*EVRm`o_!@j#g3`0JljEb=cLV1$GpjS7k|9>g{ z0XA2i)8jMP+yZZK*45%F@%vhQ%5s6Xxm{i?)Y-5SnE6pkmwo;zupHA0iX7Go0qqco zAV|ltoPtxd!2)Th9s4}I5`xVN@n;sKK?uiiY&(P^Q0lk}L3v7j@j`@Jf4WOD2ScYKv!+=bP z{C5C@AURa4iEt^%utVR$o84Wd_vp3dQPQ~~?7<_5wlibcj@EC?Qve6QwLy#bu z$EgX9+IRcYvZ5_k?MU9unk7lF7K2_kPY`23kSLiXYGU6FENbR1s5HlV-}KyAx(U`{ zNS4idVm2VyD47&$ULU)ZRq&2_a(GMwrN76OK`Mqc*`yPbfnbwlGO4kz)O=}pEo5d} z#_`;S;TBbkS&=fsd>3~{d*-D?p7nWmPKc#e}fzhxw6?tOdf*mlG#B` zx_PPhr?88;+Q#5#&q8NX=l_MB~wr3quW2_bzHSj^2X;3?|tW z5`z#FNv4>ZZ<1nXXM=rs@RuLnO21Zbf)WhRSb4KjTzwXo%Z)9#%I$((rC^qq;BGS0 z9t3-(nfB3{BJYGoj>kq!ef!ajFy3&<2BjFvStZIi!-4CL!)e(h&hblh-CcK>~Ow+a&){p;f=TL`Q4hvBxuHP zR5r(mX+h8`nKo)tGE)W$W2a^>{>hA}YNAg-JBH)3=^*A=1f7!UqQ>j^@YdPxw#_}} zR|X8)`y6y*cuqFY6VrpBS28b9qi>#^IM?sLj)$&QU;V=GhCU4avKb)e1cE`yoTR4a z{P*WS&xm)Xt#0Zci@DweLl{n3c^<=QtHH#v3RS>=Z5860lgdxtj{j^NCN5|JX|Upv JDqa)l-QR!@8nyrc literal 0 HcmV?d00001 diff --git a/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz new file mode 100644 index 0000000000000000000000000000000000000000..a909cbe88c5719a72a75f24017e303be51301f54 GIT binary patch literal 1654 zcmbW&dr(wW90%}wdCG?I2vK?HqQEYS2q;NN0fLs?faSsfg4x)ucgn&*6BRX6&DSclO@7zk5FC zySsDd9)sZR6UHr849A~6t(o)XIPwYL!Z@=mS0*Rf8$3K0Oq*mNc}VJbVQGP3H_y4b zMx9M`T4h~^R@YRdOVsMB9Ws{9_DY9r6G?lAxy~sneWzq@5S4pRolXM)|&GkET)0tNlG#i6O#rKO~ZbL?MV~A*-bD2bt9vw5o|A zW(|V1%*0X?m^U3aGF~-Vf8p$Fx8@d15QiaNH9BJS2ojiCM@Gnv$+84hzDp1`qBOMY*;n!WxiJdWWB)npNq zjUb1aZPZkJ5Hr|u=gY3~+H00kbKBRDiy==n+lkqM;7MklqUO$Q&}>!78QY}3ZCB+? zF@l{K@>L@c^E842W(ujv?zsL7dW|5;3r zEd4HrpI-aZfZ% zwdd}>JaYQ2w`;+KVV{~_NlX<2Gc%S48!HBzYD8kH5$tD1qUQYNerI&arq~j@3@_Z%9or+#N4;+S_~wnJqu{{Mpc)`XM&M)yQ*$XlGCu!M z>tLf)KH^SU*$OTUZq*zh=2-+DW}c%auJ3x;k%eSONl&~hFT@=SjTo9#(@abYf`iPo zQqw-*xiDsuLbgXVkA6GdSq^O&4ymS{m<|M;%ydz6vFBve-uzbsJ_`yx@0q?g3Edca zRMSh$^9cHwIZVyOo5s5Fb7?;Aj_Fs%fBDe_M=-Ioc+F`Q7%5HTkayui$h)TB*bLCZo`Uw%Jg`1W#XHV}p}oU(`_hL)ZZomfZdfN--)!SctfOUXl zFbX(uDg$9APG*SFKb$%DqQsc&56{b3# z>*Tt1cFALN>vA-@o=RPcM%O62vD@mj$Zoqt+C|njkL2%r99Ea)zt^XyXwnl?G`*Vt zU#h-9o3o?cYdRF_#8!v@_ywE1tU%t$P7Q+LI(GjCZkjP9DH6Iy$R-f-?X4=N; z(1Z+gE+k_}QOssy3>2@SGS~l=gd4YyR_ED8R5&F@?kwAt+|1gqnQs>T^Zst76}WnUhgx zzIH(=hBCzn#FQhbV8%#|7JomW9Vq)Q#;!fJGXDBkP>I2$m?~lrf@)?SrDh>K?u*Sw zFGTJTnkI_PubzNi78j)=ll702zZzF+hwPfrlAPP2u6W{2PWT zIqb_i-)m8ebI^gIQ!!n{>_^~b<^VMbt-qS(Okdx7^Hnp^B@>I#jiE;|y~I3*;2<-7 z)Wp>+sNz59ujroc>I~g^WCI++(65*QVh$r1WX4Ag-#>i)himo6jhA1UyIpgs9)>Us zD`te4#}OQ1W|W!>Cnm3KnYYijl+0%?yjgS+#xOjgn4`oziQp+_o~C9`!P`}Q%jc(L z=jm@tlNYbSIEG`2nIL8o!86QEQIk@gI(B<&i1^!|r432oIs?-fj*F6n;aSmO68Hdh g;D2ou{x!#{kGviK**Hu>a0qGe{v*qOO{8~!10tFi!T plot_rows: - orig = np.split(orig[:comp_set*plot_rows], comp_set) + [orig[comp_set*plot_rows:]] - aug = np.split(aug[:comp_set*plot_rows], comp_set) + [aug[comp_set*plot_rows:]] + # Create batches of required size and add remainder to last batch + orig = np.split(orig[:split_point], batch_size) + ([orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added + aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else []) else: orig = [orig] aug = [aug] for ix, allData in enumerate(zip(orig, aug)): - base_ix = ix * plot_rows # will signal what base level we're on + base_ix = ix * plot_rows # current batch starting index + curPlot = len(allData[0]) - sub_plot_count = 2 if (len(allData[0]) < 2) else len(allData[0]) # if 1 image remains, create subplot for 2 to simplify axis selection - fig, axs = plt.subplots(sub_plot_count, 2) + fig, axs = plt.subplots(curPlot, 2) fig.tight_layout(pad=1.5) for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): cur_ix = base_ix + x + (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) # select plotting axes based on number of image rows on plot - else case when 1 row + + axA.imshow(dataA["image"]) + add_bounding_boxes(axA, dataA["annotation"]) + axA.title.set_text("Original" + str(cur_ix+1)) + + axB.imshow(dataB["image"]) + add_bounding_boxes(axB, dataB["annotation"]) + axB.title.set_text("Augmented" + str(cur_ix+1)) - axs[x, 0].imshow(dataA["image"]) - add_bounding_boxes(axs[x, 0], dataA["annotation"]) - axs[x, 0].title.set_text("Original" + str(cur_ix+1)) logger.info("Original **\n{} : {}".format(str(cur_ix+1), dataA["annotation"])) - - axs[x, 1].imshow(dataB["image"]) - add_bounding_boxes(axs[x, 1], dataB["annotation"]) - axs[x, 1].title.set_text("Augmented" + str(cur_ix+1)) logger.info("Augmented **\n{} : {}\n".format(str(cur_ix+1), dataB["annotation"])) plt.show() From 0f58f0338e61469b7ef2c08be87847a5ce40eb82 Mon Sep 17 00:00:00 2001 From: islam_amin Date: Wed, 24 Jun 2020 15:36:12 -0400 Subject: [PATCH 09/27] updating ut for RandomHorizontalFlipWithBBox and BBoxAugment --- .../bounding_box_augment_crop_c_result.npz | Bin 0 -> 1654 bytes ...bounding_box_augment_rotation_c_result.npz | Bin 0 -> 1654 bytes ...unding_box_augment_valid_edge_c_result.npz | Bin 0 -> 1654 bytes ...nding_box_augment_valid_ratio_c_result.npz | Bin 0 -> 1654 bytes ..._horizontal_flip_with_bbox_01_c_result.npz | Bin 0 -> 1654 bytes .../dataset/test_bounding_box_augment.py | 370 +++++++++--------- .../test_random_horizontal_flip_bbox.py | 266 ------------- .../test_random_horizontal_flip_with_bbox.py | 229 +++++++++++ 8 files changed, 411 insertions(+), 454 deletions(-) create mode 100644 tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz create mode 100644 tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz create mode 100644 tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz create mode 100644 tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz create mode 100644 tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz delete mode 100644 tests/ut/python/dataset/test_random_horizontal_flip_bbox.py create mode 100644 tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py diff --git a/tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz b/tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz new file mode 100644 index 0000000000000000000000000000000000000000..e4e92210d7a4b8d01498673fad591c90deed965c GIT binary patch literal 1654 zcmbW&eM}Q)90%~bmZB7=C@;c>vusFEPu9F>B| z1Vo&3<7L7ezM{zzmk3!lqR2wbW^looEiiQR4>RT%!zGgt6SKK{uIDXFwtx1!^e*?i z=kt8q=E*$DqQl0vOJVOah z?V;+`7S?5Us?$_zPrf=JA6RfFpP zm$FH2=4xvkZmrqwWVJQ5)!4>=U!z-D=48!w&V`+|W~>CtJ4(q}yib8PD>PQhDntfY`Mi|~R*^iP76`O@1Im*i$QuxhAmmBF zCRqK?0188>sXZfLBxRz0ZONrRL+90#n@5WRAq>OQrY9lKkdSZ$Ys8Rets%VW8n_p0 z%JKMoHk_yFdXyasYcZ%K6G6;#2%Z;Bq+qr^X#8?<;oP#o9px#^L{|x{!w@AIH8C0l z(V|%|nB$6uuOjK(jk*C%gs1cR6vSYNl}sElS_JW;Nf69l`=IO(V}FtEc3XVZ@TGut;1w&KWx zx^)?zdwxf+CBbG4TO^ZC%*zP2ie{T&A_9-o7aGD(XI-lpJb96dhYSp_NG6jQ9fIwm z$r8-+{bRWSm5RAM|HAc#q2?~g#*iZ!hL~Igd7{w^W^s1+jj5gHkA5Hisi4&{G6VS- z43c@37=++8(d-b+1;1hU{mW-V;>L!8Lb0?2 zOUO(+5$qCYdR>^QYRn$PT~Vak#;UH2gnUOsDTX(s^fF?0BQT1_^zX)u!6F%!m~sRa zqOl6*^TLI}+``EI%(Q{%mhV8c);nJQwc5!8r=6U?yweEls&-G+t-%qOU_=hfC zm`*q|YWNZAF?b}?KujZoeWGddnxe*2H@2L+JM}}>&z;wU6JbAwX34xo%mD-~qG=UO zGxx>(_Op&(`+FV-cN9&ZhBgcbCDTsKAp{+wIV_l)CwlHCPg({lvnG>neX5&=w=o=% z%sa#!Mewd@-V@A+`BUg{$K0bOx2OMHbExbgbYkd|OgAw-2;LXXF~Jo6(tFkY?aHtj z<>b)Khf$-@i=ofNvKT%vX$=f53zGlWR>7}1v6{%+@$Zepz$ko4gXWJE{u(XZ{RNrK BCVBt> literal 0 HcmV?d00001 diff --git a/tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz b/tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz new file mode 100644 index 0000000000000000000000000000000000000000..8cc7e15e31b116565a79f65dbef223ce5d920953 GIT binary patch literal 1654 zcmbW&dr(wW90%~byF3>mz+D7NSw-DNa0LQ+BoZNyyt=HXfTcv*$Hhh9?*5ju#0DxU z4|(*&NNf@J}8_Kn<27+u*Wng6I0pv2Q^eP*gbcTw`r#S>33)Eo%_4z zbH2MfXYMhw3Lh2Yt(A=WRC(|(z6?V?fsBeVyId8Cdfw@gF`;6U3?vUp^%#n`8_Q%& zJ=3VQadlRgHdCW*D$*usw3QAQcA5DKhs(y1cD}i$juZNIcC(Wc?sdrtn&h|y&0fv_ zFJ-gB%DZcw9=+A!;`Fud8eAoOU$aM9;o__gz7D(GR;&cZJ4(shyiY;KGOZxbm03X1 z4n7Ecb*zk)v+_3ZqeE@jrCYTsj4;a<{+dT9@5zvyBHa27xKaogp)jj z6D<2@0D~dY(vcgwWua5g8y8y!`-}PywGJJIC=Ah-2O$rWkVg*P zd`8g3RQS6sAJyN37z`T8tR!X?f>>%+izZ_*W<%`h56-pBEx9v%bgT~6U|1^|EipO- zan!66P1^N{tC1(O0&KUxh;k?fVjvzvf@IbcqerlTnncmerGz&5hs`_lW>xu%eHV^H z5{6{SJW5Oof>dfYie}Li*}8Bs5|DHCKJJ9)NB&XuWqhk4wS3L ztqp5phpx#X3&Uo~WE1l^f*fj|5Y6xHE57f@`hy+riQC9!caK9ZhAooGBPJg~0X2o9 zX^7^lYmz=&$L8hcAD{Rn9G=9mRWd9w+YoH0#vq!R(Ovgqi-$UXuyux94^!=cA`C{! z>>vgqc#4{*Me})b(v9GEZVpGkcQxnW?Z7iojG@HBnUuonGxNLyn{kzchthH|NlUPd z%=8R`XX#8k#hI!{9P#`m-}I`H>PsI-d?kZ&47;TC3Syo^V5Y|M@5YM3CK-;HN(8&9 zv5V$P$^84(4sTc0|6Vyy2ssKT&EGS$S?AgHB=7tM!;vyHcW8L5Yvm`O=^yarsT)IPd#6TMa?QF z&rd9@y$pRAj#@Yl!&?@;k(J3+ivQXwgf&O2nY(Lff7T+P_UrInBbj6eid$M#iRz#s@1NvU@Ja+ceQjzn!}~_jk|d zeD9w#d#fdLe6lcGD}=Oh=SROJ2m;#{2+4w1RU2}gN=vIrNYa~ZAUnwFxVvUsb)8Ar zCmgUf%2B^+TW__sSJ|?ywx)s>b|OgfKfwR2pZG!Xybjh$}PsVD3dc zL4sW}NpVu#AtdsljkL8z7BR0jPP@#i0?d~L$<&Jl1fpFs*P66$JMM71-Icer(!Zn< ziaULM9^7TGW%nl-_BP8{=)od;KnUoONdYZyF~Jg+7vF=s1G525c}SiONFi8yJD>=u z{|pc)EWW<7q{O3FLdBMXuP1gcyk0i)=`EyDEc4wCNoOJV5M=O>djm`7GRrBfhFQVP zN`m{iS*4qz=+yV#0n5a&=ve+%PFf=Fr^qynjTt*Z7B{PPGhDQxxxVv~rL1I)F#64~ z3l55G!#u!@lVA-uIl7rX^Ht>+RT*`Ee%yI3;cDAfJV=pin6=F05j@1rI^8T@*+5Xr z%|_k0PIWIiEP6gXt3C4Hf9`5S8AZ8aDwuH*Y~p6KZnm6THt_c8&Y80o<=XyW=4EW5 zs5Fek%vOSJ+_-f^rxn+=-cP1>*W9>RL64n86-BjSwlhNrc5t&(Hy3_eH?`Nk=gs&2 zNGZN~KGTmHibs8Nt*EU&ucAb#mj)vW9`m5qSb}wIrd@&=h%WMqPJ*Yn>C#Q; zAXQJT7VFD5y`Hpk?8sFdq8xo;NdM`>C|tg(DO%7^a_@0fM949MjFnyRV(;C_1lLP7Dr~ z554~Zj#CU8<^(e@5}f4bCEfgM9lT62ElAn^`3@#lU<03MN8ol7ZwQsX}6VYF0R?>(`9RQZDw7PR#z+du+Prh1fPo|?E-saGbi?&-F7c0-s@A7w5bV6+JoBv zU#fPclW%GA`VCIO#~GSh8nH+GzIMOL=Hr|K-;8}NPOJhZFiOR{0#8A~GTjiAr*MF> z7lILl=vf6D#0K?1C>`p;elMqjHP#^UXPUEMEz7WqUI;^=(z8mdLhQCdxXENHUKL9^ zu|^CJIr^&MVZD|7Kb|Ig9S4t8!=rjPL++yB31d${!xErF}fq>dHPzM4u2%_%? z;r0 zHgUvkM(`vxTO`w&r0W1j-L6REf!k-yv!k#TL%eKs#OM(uQ1g^zZVhfKEK|qlRHnyP zEi9XtAQ3~7Y?6sFAlODtie#n}2IfjX2|KGE)P#ElCKI+}NR`dg#H1lerzS%(6Z-m# zN8FP&)~}X|zDeFZ0ht)GWb+I$*$AGcCPy;$Q|<7_FB!L*R?@ncqN6Xs4h%bGVFAvHyk$@NEFEtoFdFa_hMH5cc- z@B)Tn*|5ayMo>bHNitQo!aaFxUDxFKyh|m+kw!3Mu*jyA7=++OYF?7eU6u)dZ|e@b za`ZxIW@Pw2D8o?h;H)Zf_1Sq|z;^5rc&MlbtGom&$xJUJ*h6Qkl4feSCM5DxAz7Yl z4O5q+zg0jrhF9ct8!~owsTr0``EMiB{;$`qzp0x0@cO;@D=>oLZ3o9;7=f^_#cz-b>n literal 0 HcmV?d00001 diff --git a/tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz new file mode 100644 index 0000000000000000000000000000000000000000..d360bb98ec7b551d207f73903b519123b5244939 GIT binary patch literal 1654 zcmbW&drVtZ90%~*mNM7r#>#Udql}Ff!44R&F$PW;OfIzDnFu&o>ElAn!{l&ynJ z9OE@mm^g-yIVCLM_J>9j4Piqw$&6VvF(#Wpe2l1s=r-dYKC(FMIh+(|dD& z_k7N`ZBFhn@rs}@Zl&V5{vQhUx0D=5KC8Jf&f@WyQ}m+SFXKXJlMEyeNevmw%S=@= zu7zvW+6Ax8qs`H1+l<;|jkeC^!5)ifc6saqX%|_Vyh5Pwby(a&;9i%OtVv5u)*R9N z|5CLpY@)B(?bq8}9zoyiYr@XJ_qF&{W{+TViC*mS*{}+@l~F3uzVZ}gJl6^G0+|&Q z-4KL8spDn5oR@b&FdJ&eez%~4Rh9C<&opMk13bsex?wc}m5x_b$^zXsSYt34N|pmf zPkcUx2d%v|@Q|*O{6EfSOT7SVYhazu!8vGTh$E1<9)X(V1$9EGV-W-~Rp-oToR{oZFKSB)`Mcc+d{D(4TELMnzd$!sGg9YF>&nbfRb zm$duiE741ZzhrG&n1LIRg&|uq+lk3Ru!EUgYMQ@}UifvNV<2t&=E!e4ZimB840)2t zC+2YkyO?=`noB)j&Y6BxeYbUf*_j{LkOBo53MI3fm?8wl%sKIbRN;eZzi@?H+_5Q|&!7dqrm^uXY%s8kKK7q)!E%uC3JbC8%r2>i?(rbc_QZm}hG%ekv*UJcJ-I3}46VvZx|WTuOnW8#&Y#bfTDUmso$@2i*{g>DQz zlIbPp1cE+h`lnE_%35u9Y^MQY9%ucQ9HpYE)h zpS>G%!u$seVR%V0r-&IwaGIGBYHaVQ8`WcfmF~PSkkNDU)LD2L!z)%n!0@V7Z{lV0 hFvWju6@fL!s-3(Y|JgW9yfT Date: Fri, 26 Jun 2020 11:47:40 +0800 Subject: [PATCH 10/27] Make assign-node to be before jump-node, ensure child graph can get its input Signed-off-by: zhoufeng --- .../ccsrc/session/ascend_control_parser.cc | 38 +++++++++++++++---- .../ccsrc/session/ascend_control_parser.h | 5 ++- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc index 868b968d9e4..573c1c1d356 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/session/ascend_control_parser.cc @@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3; namespace mindspore { namespace session { +static CNodePtr GetJumpNode(NotNull parent_graph, NotNull child_graph) { + auto &nodes = parent_graph->execution_order(); + for (auto &node : nodes) { + if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) { + return node; + } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) && + (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) || + child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) { + return node; + } + } + MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString(); + return nullptr; +} + static void InitUnionFindSet(NotNull kg, const NotNull *> union_find_set, const NotNull *> memo) { if (memo->find(kg.get()) != memo->end()) { @@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::mapsecond), NOT_NULL(arg), NOT_NULL(parameter)); + InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg), + NOT_NULL(parameter)); } } } @@ -433,7 +449,8 @@ std::tuple AscendControlParser::ParsePartial(NotNull kg, NotNull from, +void AscendControlParser::InsertMultipleAssignToGraph(NotNull from_graph, + NotNull to_graph, NotNull from, NotNull to) { std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); @@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull kg << to_outputs.size() << "]"; } for (size_t i = 0; i < from_outputs.size(); i++) { - InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); + auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); + if (assign_node != nullptr) { + auto jump_node = GetJumpNode(from_graph, to_graph); + if (jump_node != nullptr) { + InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); + } + } } } -void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, - NotNull to) { +AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, + NotNull to) { if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { - return; + return nullptr; } if (from.get() == to.get()) { - return; + return nullptr; } MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " << to->DebugString(); @@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNul assign_node->set_abstract(to->abstract()); // append the assign at the end of from graph InsertDependToGraph(kg, NOT_NULL(assign_node)); + return assign_node; } std::vector AscendControlParser::RecurseGraph(NotNull graph, diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h index 73d68449b31..0cf7069046d 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.h +++ b/mindspore/ccsrc/session/ascend_control_parser.h @@ -52,8 +52,9 @@ class AscendControlParser { const CNodePtr &last_label); static std::tuple ParsePartial(NotNull node); - static void InsertMultipleAssignToGraph(NotNull kg, NotNull from, NotNull to); - static void InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); + static void InsertMultipleAssignToGraph(NotNull from_graph, NotNull to_graph, + NotNull from, NotNull to); + static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); // root graph order static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, From 617eb5510a426d62367278e8a28a55606421f83d Mon Sep 17 00:00:00 2001 From: dayschan <6573942+dayschan@user.noreply.gitee.com> Date: Tue, 23 Jun 2020 19:20:58 +0800 Subject: [PATCH 11/27] GraphKernel support akg batchmatmul --- akg | 2 +- mindspore/ccsrc/kernel/kernel_query.cc | 7 +++ mindspore/ops/_op_impl/akg/__init__.py | 1 + mindspore/ops/_op_impl/akg/batchmatmul.py | 73 +++++++++++++++++++++++ 4 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 mindspore/ops/_op_impl/akg/batchmatmul.py diff --git a/akg b/akg index c460176523d..df57a6cf945 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit c460176523d039c8995f1d71089753725ebc0792 +Subproject commit df57a6cf9450e347d1854687d1fe66a420ee3b35 diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc index 5eda8479170..4a8ae81afa4 100755 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ b/mindspore/ccsrc/kernel/kernel_query.cc @@ -23,6 +23,7 @@ #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" #include "kernel/akg/akg_kernel_metadata.h" #include "session/anf_runtime_algorithm.h" +#include "utils/context/ms_context.h" namespace mindspore { namespace kernel { @@ -97,6 +98,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vectorenable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) { + kernel_type = KernelType::AKG_KERNEL; + } + switch (kernel_type) { case KernelType::AKG_KERNEL: AkgMetadataInfo(kernel_node, kernel_info_list); diff --git a/mindspore/ops/_op_impl/akg/__init__.py b/mindspore/ops/_op_impl/akg/__init__.py index f38b99f5e4f..fd86dbf9991 100644 --- a/mindspore/ops/_op_impl/akg/__init__.py +++ b/mindspore/ops/_op_impl/akg/__init__.py @@ -47,6 +47,7 @@ from .gather_v2 import _gather_v2_akg from .less import _less_akg from .log import _log_akg from .matmul import _matmul_akg +from .batchmatmul import _batchmatmul_akg from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_akg from .max_pool_with_argmax import _max_pool_with_argmax_akg from .max import _max_akg diff --git a/mindspore/ops/_op_impl/akg/batchmatmul.py b/mindspore/ops/_op_impl/akg/batchmatmul.py new file mode 100644 index 00000000000..f5da71aa25e --- /dev/null +++ b/mindspore/ops/_op_impl/akg/batchmatmul.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BatchMatMul op""" +from mindspore.ops.op_info_register import op_info_register + + +@op_info_register("""{ + "op_name": "BatchMatMul", + "imply_type": "AutoDiff", + "fusion_type": "OPAQUE", + "attr": [ + { + "name": "transpose_a", + "param_type": "optional", + "type": "bool" + }, + { + "name": "transpose_b", + "param_type": "optional", + "type": "bool" + } + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "x1" + }, + { + "index": 1, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "x2" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "output" + } + ] +}""") +def _batchmatmul_akg(): + """BatchMatMul AKG register""" + return From 81bf4bde1d09d13463efe499f4fa81a495350ff6 Mon Sep 17 00:00:00 2001 From: Jesse Lee Date: Fri, 26 Jun 2020 12:22:44 -0400 Subject: [PATCH 12/27] AutoIndexObj primary should start with 0 --- mindspore/ccsrc/dataset/util/auto_index.h | 2 +- tests/ut/cpp/dataset/btree_test.cc | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/dataset/util/auto_index.h b/mindspore/ccsrc/dataset/util/auto_index.h index 11a2e90b00d..5c43ecfd80b 100644 --- a/mindspore/ccsrc/dataset/util/auto_index.h +++ b/mindspore/ccsrc/dataset/util/auto_index.h @@ -91,7 +91,7 @@ class AutoIndexObj : public BPlusTree { } private: - static constexpr key_type kMinKey = 1; + static constexpr key_type kMinKey = 0; std::atomic inx_; }; } // namespace dataset diff --git a/tests/ut/cpp/dataset/btree_test.cc b/tests/ut/cpp/dataset/btree_test.cc index 168f550f349..75d5133e58a 100644 --- a/tests/ut/cpp/dataset/btree_test.cc +++ b/tests/ut/cpp/dataset/btree_test.cc @@ -190,9 +190,9 @@ TEST_F(MindDataTestBPlusTree, Test3) { EXPECT_TRUE(rc.IsOk()); uint64_t min = ai.min_key(); uint64_t max = ai.max_key(); - EXPECT_EQ(min, 1); - EXPECT_EQ(max, 4); - auto r = ai.Search(3); + EXPECT_EQ(min, 0); + EXPECT_EQ(max, 3); + auto r = ai.Search(2); auto &it = r.first; EXPECT_EQ(it.value(), "b"); MS_LOG(INFO) << "Dump all the values using [] operator."; From 277aba5326b03763579996ac3629cdcfb21be62b Mon Sep 17 00:00:00 2001 From: Cathy Wong Date: Thu, 25 Jun 2020 21:41:42 -0400 Subject: [PATCH 13/27] dataset: Fixup docs; remove pylint disabled messages in UT --- mindspore/dataset/engine/datasets.py | 6 +- .../dataset/transforms/vision/c_transforms.py | 6 +- tests/ut/data/dataset/declient.cfg | 3 +- tests/ut/python/dataset/test_batch.py | 8 +-- tests/ut/python/dataset/test_center_crop.py | 11 +--- tests/ut/python/dataset/test_config.py | 7 ++- tests/ut/python/dataset/test_filterop.py | 57 +++++-------------- tests/ut/python/dataset/test_pad.py | 14 ++--- 8 files changed, 39 insertions(+), 73 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index ca6f7ca33e5..360cdb1860e 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1040,7 +1040,7 @@ class Dataset: Args: columns (list[str], optional): List of columns to be used to specify the order of columns - (defaults=None, means all columns). + (default=None, means all columns). Returns: Iterator, list of ndarray. @@ -3382,7 +3382,7 @@ class ManifestDataset(MappableDataset): class_indexing (dict, optional): A str-to-int mapping from label name to index (default=None, the folder names will be sorted alphabetically and each class will be given a unique index starting from 0). - decode (bool, optional): decode the images after reading (defaults=False). + decode (bool, optional): decode the images after reading (default=False). num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This @@ -4760,7 +4760,7 @@ class _NumpySlicesDataset: def process_dict(self, input_data): """ - Convert the dict like data into tuple format, when input is a tuple of dict then compose it into a dict first. + Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first. """ # Convert pandas like dict(has "values" column) into General dict data_keys = list(input_data.keys()) diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index aef714953f0..3fdf7795d0f 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -202,7 +202,7 @@ class RandomHorizontalFlip(cde.RandomHorizontalFlipOp): Flip the input image horizontally, randomly with a given probability. Args: - prob (float): Probability of the image being flipped (default=0.5). + prob (float, optional): Probability of the image being flipped (default=0.5). """ @check_prob @@ -217,7 +217,7 @@ class RandomHorizontalFlipWithBBox(cde.RandomHorizontalFlipWithBBoxOp): Maintains data integrity by also flipping bounding boxes in an object detection pipeline. Args: - prob (float): Probability of the image being flipped (default=0.5). + prob (float, optional): Probability of the image being flipped (default=0.5). """ @check_prob @@ -231,7 +231,7 @@ class RandomVerticalFlip(cde.RandomVerticalFlipOp): Flip the input image vertically, randomly with a given probability. Args: - prob (float): Probability of the image being flipped (default=0.5). + prob (float, optional): Probability of the image being flipped (default=0.5). """ @check_prob diff --git a/tests/ut/data/dataset/declient.cfg b/tests/ut/data/dataset/declient.cfg index b657ead6d5f..e09b24812ad 100644 --- a/tests/ut/data/dataset/declient.cfg +++ b/tests/ut/data/dataset/declient.cfg @@ -4,6 +4,7 @@ "numParallelWorkers": 4, "workerConnectorSize": 16, "opConnectorSize": 16, - "seed": 5489 + "seed": 5489, + "monitor_sampling_interval": 15 } diff --git a/tests/ut/python/dataset/test_batch.py b/tests/ut/python/dataset/test_batch.py index 07eba394f19..9b9baeec33e 100644 --- a/tests/ut/python/dataset/test_batch.py +++ b/tests/ut/python/dataset/test_batch.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from util import save_and_check - import mindspore.dataset as ds from mindspore import log as logger +from util import save_and_check # Note: Number of rows in test.data dataset: 12 DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"] @@ -434,7 +433,6 @@ def test_batch_exception_11(): assert "drop_remainder" in str(e) -# pylint: disable=redundant-keyword-arg def test_batch_exception_12(): """ Test batch exception: wrong input order, drop_remainder wrongly used as batch_size @@ -447,12 +445,12 @@ def test_batch_exception_12(): # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR) try: - data1 = data1.batch(drop_remainder, batch_size=batch_size) + data1 = data1.batch(drop_remainder, batch_size) sum([1 for _ in data1]) except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "batch_size" in str(e) + assert "drop_remainder" in str(e) def test_batch_exception_13(): diff --git a/tests/ut/python/dataset/test_center_crop.py b/tests/ut/python/dataset/test_center_crop.py index d4f8735fb0b..6dfa9fc7c30 100644 --- a/tests/ut/python/dataset/test_center_crop.py +++ b/tests/ut/python/dataset/test_center_crop.py @@ -109,23 +109,18 @@ def test_center_crop_comp(height=375, width=375, plot=False): visualize_list(image_c_cropped, image_py_cropped, visualize_mode=2) -# pylint: disable=unnecessary-lambda def test_crop_grayscale(height=375, width=375): """ Test that centercrop works with pad and grayscale images """ - def channel_swap(image): - """ - Py func hack for our pytransforms to work with c transforms - """ - return (image.transpose(1, 2, 0) * 255).astype(np.uint8) - + # Note: image.transpose performs channel swap to allow py transforms to + # work with c transforms transforms = [ py_vision.Decode(), py_vision.Grayscale(1), py_vision.ToTensor(), - (lambda image: channel_swap(image)) + (lambda image: (image.transpose(1, 2, 0) * 255).astype(np.uint8)) ] transform = py_vision.ComposeOp(transforms) diff --git a/tests/ut/python/dataset/test_config.py b/tests/ut/python/dataset/test_config.py index c4d665b3917..59be886c23d 100644 --- a/tests/ut/python/dataset/test_config.py +++ b/tests/ut/python/dataset/test_config.py @@ -37,6 +37,7 @@ def test_basic(): num_parallel_workers_original = ds.config.get_num_parallel_workers() prefetch_size_original = ds.config.get_prefetch_size() seed_original = ds.config.get_seed() + monitor_sampling_interval_original = ds.config.get_monitor_sampling_interval() ds.config.load('../data/dataset/declient.cfg') @@ -45,23 +46,27 @@ def test_basic(): # assert ds.config.get_worker_connector_size() == 16 assert ds.config.get_prefetch_size() == 16 assert ds.config.get_seed() == 5489 + # assert ds.config.get_monitor_sampling_interval() == 15 # ds.config.set_rows_per_buffer(1) ds.config.set_num_parallel_workers(2) # ds.config.set_worker_connector_size(3) ds.config.set_prefetch_size(4) ds.config.set_seed(5) + ds.config.set_monitor_sampling_interval(45) # assert ds.config.get_rows_per_buffer() == 1 assert ds.config.get_num_parallel_workers() == 2 # assert ds.config.get_worker_connector_size() == 3 assert ds.config.get_prefetch_size() == 4 assert ds.config.get_seed() == 5 + assert ds.config.get_monitor_sampling_interval() == 45 # Restore original configuration values ds.config.set_num_parallel_workers(num_parallel_workers_original) ds.config.set_prefetch_size(prefetch_size_original) ds.config.set_seed(seed_original) + ds.config.set_monitor_sampling_interval(monitor_sampling_interval_original) def test_get_seed(): @@ -150,7 +155,7 @@ def test_deterministic_run_fail(): def test_deterministic_run_pass(): """ - Test deterministic run with with setting the seed + Test deterministic run with setting the seed """ logger.info("test_deterministic_run_pass") diff --git a/tests/ut/python/dataset/test_filterop.py b/tests/ut/python/dataset/test_filterop.py index 015d5803799..876278571de 100644 --- a/tests/ut/python/dataset/test_filterop.py +++ b/tests/ut/python/dataset/test_filterop.py @@ -50,9 +50,7 @@ def test_diff_predicate_func(): def filter_func_ge(data): - if data > 10: - return False - return True + return data <= 10 def generator_1d(): @@ -108,15 +106,11 @@ def test_filter_by_generator_with_repeat_after(): def filter_func_batch(data): - if data[0] > 8: - return False - return True + return data[0] <= 8 def filter_func_batch_after(data): - if data > 20: - return False - return True + return data <= 20 # test with batchOp before @@ -152,9 +146,7 @@ def test_filter_by_generator_with_batch_after(): def filter_func_shuffle(data): - if data > 20: - return False - return True + return data <= 20 # test with batchOp before @@ -169,9 +161,7 @@ def test_filter_by_generator_with_shuffle(): def filter_func_shuffle_after(data): - if data > 20: - return False - return True + return data <= 20 # test with batchOp after @@ -197,15 +187,11 @@ def generator_1d_zip2(): def filter_func_zip(data1, data2): _ = data2 - if data1 > 20: - return False - return True + return data1 <= 20 def filter_func_zip_after(data1): - if data1 > 20: - return False - return True + return data1 <= 20 # test with zipOp before @@ -247,16 +233,11 @@ def test_filter_by_generator_with_zip_after(): def filter_func_map(col1, col2): _ = col2 - if col1[0] > 8: - return True - return False + return col1[0] > 8 -# pylint: disable=simplifiable-if-statement def filter_func_map_part(col1): - if col1 < 3: - return True - return False + return col1 < 3 def filter_func_map_all(col1, col2): @@ -311,9 +292,7 @@ def test_filter_by_generator_with_map_part_col(): def filter_func_rename(data): - if data > 8: - return True - return False + return data > 8 # test with rename before @@ -334,15 +313,11 @@ def test_filter_by_generator_with_rename(): # test input_column def filter_func_input_column1(col1, col2): _ = col2 - if col1[0] < 8: - return True - return False + return col1[0] < 8 def filter_func_input_column2(col1): - if col1[0] < 8: - return True - return False + return col1[0] < 8 def filter_func_input_column3(col1): @@ -439,9 +414,7 @@ def test_filter_by_generator_Partial2(): def filter_func_Partial(col1, col2): _ = col2 - if col1[0] % 3 == 0: - return True - return False + return col1[0] % 3 == 0 def generator_big(maxid=20): @@ -461,9 +434,7 @@ def test_filter_by_generator_Partial(): def filter_func_cifar(col1, col2): _ = col1 - if col2 % 3 == 0: - return True - return False + return col2 % 3 == 0 # test with cifar10 diff --git a/tests/ut/python/dataset/test_pad.py b/tests/ut/python/dataset/test_pad.py index 1b3882cd54b..7b66b6b36b5 100644 --- a/tests/ut/python/dataset/test_pad.py +++ b/tests/ut/python/dataset/test_pad.py @@ -16,12 +16,12 @@ Testing Pad op in DE """ import numpy as np -from util import diff_mse import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as c_vision import mindspore.dataset.transforms.vision.py_transforms as py_vision from mindspore import log as logger +from util import diff_mse DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" @@ -69,23 +69,19 @@ def test_pad_op(): assert mse < 0.01 -# pylint: disable=unnecessary-lambda + def test_pad_grayscale(): """ Tests that the pad works for grayscale images """ - def channel_swap(image): - """ - Py func hack for our pytransforms to work with c transforms - """ - return (image.transpose(1, 2, 0) * 255).astype(np.uint8) - + # Note: image.transpose performs channel swap to allow py transforms to + # work with c transforms transforms = [ py_vision.Decode(), py_vision.Grayscale(1), py_vision.ToTensor(), - (lambda image: channel_swap(image)) + (lambda image: (image.transpose(1, 2, 0) * 255).astype(np.uint8)) ] transform = py_vision.ComposeOp(transforms) From 1e869146e992e2f9ee3408dff09c4e1a85929b2e Mon Sep 17 00:00:00 2001 From: avakh Date: Thu, 25 Jun 2020 18:23:12 -0400 Subject: [PATCH 14/27] applying comments removing VOC --- ...random_resize_with_bbox_op_01_c_result.npz | Bin 0 -> 1654 bytes .../resize_with_bbox_op_01_c_result.npz | Bin 0 -> 1654 bytes .../dataset/test_random_resize_with_bbox.py | 333 +++++++---------- .../python/dataset/test_resize_with_bbox.py | 342 ++++++------------ 4 files changed, 237 insertions(+), 438 deletions(-) create mode 100644 tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_result.npz create mode 100644 tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_result.npz diff --git a/tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_result.npz b/tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_result.npz new file mode 100644 index 0000000000000000000000000000000000000000..a5623304ba43eddb4656aafa663157636dae5528 GIT binary patch literal 1654 zcmbW&drVVT90%~*2P&NpaG>L%76jVPQV^jaAc!!WTxj7aP;iy?aj`gP`L(C84wwu? zKqrc*h(wl!Fat)mnEkO0wrtC?5Efr7vKX?FusB@ih{kN$!pL^d<#?AR+duo=^xoXx zJ)iS!o0EGAS(!YDS*{3XAoT5_1wV!%AAcr@vAEq8@j7R%Pr?KUO)`)?B-LXmE+{OM zFm+6W%FcOgZdIyM)o4^{m8wdY8@nye3YXi?k#?@7+QadEkHb>S@%QQktujHQRW>RA ze<_+}HmA3y)~B<%+?=k)TaByu?`!cXD%_mS<@8{;*M=3qERRw+?aNO=!ZK|j&6Zd} z)(&z6eri_2N?GY~SV4!{v9Fd>z)F*p|Cz=VSj94|q#gVbDAcUXB;mV_@QlG=$orEo z;lyTQSZ(bz!y2`T{6AGSmctyZHN!f!gK-Fv0S-Pd5W#wqCvSsi9m@fgi4e3L5R72M z(|~l?_|E_a!*kZo?10>mBOlKVeRFF(_o`YGqF-oO-RjsniDz|=>)VVMV(HLSxqasF) zKts**f?1mq>~dWFJuPb_B64oJU=X%p(26FO7#)H*YT^Z>-!wFG-Fn*aR(IETlGy1< z*p4AVG&_h%M36*{UNA*#x^Hj)tfqe{ZOei0OSzMfj3GrdJBdj}kVZ|qVDft}RmsDP zG>eBcGfDc*6R-hCI=*#N;C=pvEAWg}NMD^p2lyW;_b{R{nY4C>Svm zil&GdgkV242Lz+-nGNkS`x_s3&P~MJ&mVwd3?)|1q~KSd#p!fm3$AiGq0|f}aS4`@ znaU9yq%*xF%yjGPP~{z4YS{W;TYreV-%$o;42Q(@3SwSHV4=qP@5Y9~E*g%QN(6_g zaR_FmB=Lcy=%k<#F zP%9cB#*M&34His`D*ayM7p|$S8xKw4KaP9Bi=j?5^~Ai2z(>t%g1Of8WsK(9=r75o z^P?+c>m|^Dp;0tV#JrB+C^gN389>2&=G*fPx~VxUn5B-XosX09S4Sn72Xjlu9zqX>bD}v<%moC!)LazIUCY6~8!0Bsr|V)X yM}`{%p%23)E5~8zx9SR6i8M&|Ut0yg=4dsOx8vU%MTM29i~wII@VmLiD!o~R&G4qJRDP;ec)mcz0^+o!w*D>#wg zOhp(|VPq@|Ap{d}AtuYj8C{mek+?aGCd6zmE)JJDrZMgxWZS(jueU65|LnQ+F8AE~ zdB1J*=H7h4&p(7)s7UVE`A?3_25=nt1aTpp=<$?pG&=B)%E>D?6+BsrnwIuhe9im&3?{x_VO@iK_Ii&gj zrEc)EJH1tIpV98}NX9B}C9aUaujo^kdL+BcS&cnjJ5~d?FiPz#TX+h*z_oxXi?@Ma z8~7s#&%}a$cC0-pIcJHB8y7?ndTIDPT&N-4T2D;b%LLTm%H`wyvbzB{Yx%- z5;8F?wslxxiOxd)pNcARzXUH>VX4l+IcQ|CL(W@R3s z0T~eb?*IpNKnisViFN-W+sW6JxjW; z#b2)KpHE-A_opK1C?sP@QA{c^X$aDp$)Lu#X5ipS_hj$H$e$6l($F}(f?h^rso60U^!3Gx-+v5fy{YSXGCBuY7_t?!gP0ryuQIcfnv99KH0$JpzRujxx!Kd} z>R=a!T*U~)}O}~D#<-w}&{lDlN12cwv#q1#lA$X0Mz0??bZijVR zgUk;*?p%qvn>PRj7z%BYMJ=yB(dl$y5m&gJP-F#*vIL9COeF~Rv6)_{GktqKO!J*R zEqwVOO+UxoZ7&8ZhBuV-QexgjATncnwy|R&TO{D~uN(85+#Pr%cd> zp>> # Adds a box that covers the whole image. Good for testing edge cases - >>> de = de.map(input_columns=["image", "annotation"], - >>> output_columns=["image", "annotation"], - >>> operations=AddBadAnnotation(BoxType.OnEdge)) + Prints images and bboxes side by side with and without RandomresizeWithBBox Op applied, + applied on dynamically generated edge case, expected to pass. edge case is when bounding + box has dimensions as the image itself. """ + logger.info("test_random_resize_with_bbox_op_edge_c") + dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + decode=True, shuffle=False) - def __init__(self, box_type): - self.box_type = box_type + dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + decode=True, shuffle=False) - def __call__(self, img, bboxes): - """ - Used to generate erroneous bounding box examples on given img. - :param img: image where the bounding boxes are. - :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format - :return: bboxes with bad examples added - """ - height = img.shape[0] - width = img.shape[1] - if self.box_type == BoxType.WidthOverflow: - # use box that overflows on width - return img, np.array([[0, 0, width + 1, height - 1, 0, 0, 0]]).astype(np.uint32) + test_op = c_vision.RandomResizeWithBBox(500) - if self.box_type == BoxType.HeightOverflow: - # use box that overflows on height - return img, np.array([[0, 0, width - 1, height + 1, 0, 0, 0]]).astype(np.uint32) + dataVoc1 = dataVoc1.map(input_columns=["annotation"], + output_columns=["annotation"], + operations=fix_annotate) + dataVoc2 = dataVoc2.map(input_columns=["annotation"], + output_columns=["annotation"], + operations=fix_annotate) - if self.box_type == BoxType.NegativeXY: - # use box with negative xy - return img, np.array([[-10, -10, width - 1, height - 1, 0, 0, 0]]).astype(np.uint32) + # maps to convert data into valid edge case data + dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], + output_columns=["image", "annotation"], + columns_order=["image", "annotation"], + operations=[lambda img, bboxes: ( + img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) - if self.box_type == BoxType.OnEdge: - # use box that covers the whole image - return img, np.array([[0, 0, width - 1, height - 1, 0, 0, 0]]).astype(np.uint32) + dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], + output_columns=["image", "annotation"], + columns_order=["image", "annotation"], + operations=[lambda img, bboxes: ( + img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) - if self.box_type == BoxType.WrongShape: - # use box that covers the whole image - return img, np.array([[0, 0, width - 1]]).astype(np.uint32) - return img, bboxes + unaugSamp, augSamp = [], [] + + for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): + unaugSamp.append(unAug) + augSamp.append(Aug) + + if plot_vis: + visualize_with_bounding_boxes(unaugSamp, augSamp) -def check_bad_box(data, box_type, expected_error): +def test_random_resize_with_bbox_op_invalid_c(): + """ + Test RandomResizeWithBBox Op on invalid constructor parameters, expected to raise ValueError + """ + logger.info("test_random_resize_with_bbox_op_invalid_c") + try: - test_op = c_vision.RandomResizeWithBBox(100) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM) - data = data.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - # map to use width overflow - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=AddBadAnnotation(box_type)) # Add column for "annotation" - # map to apply ops - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" - for _, _ in enumerate(data.create_dict_iterator()): - break - except RuntimeError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert expected_error in str(e) + # zero value for resize + c_vision.RandomResizeWithBBox(0) + except ValueError as err: + logger.info("Got an exception in DE: {}".format(str(err))) + assert "Input is not" in str(err) -def add_bounding_boxes(axis, bboxes): - """ - :param axis: axis to modify - :param bboxes: bounding boxes to draw on the axis - :return: None - """ - for bbox in bboxes: - rect = patches.Rectangle((bbox[0], bbox[1]), - bbox[2], bbox[3], - linewidth=1, edgecolor='r', facecolor='none') - # Add the patch to the Axes - axis.add_patch(rect) - - -def visualize(unaugmented_data, augment_data): - for idx, (un_aug_item, aug_item) in \ - enumerate(zip(unaugmented_data.create_dict_iterator(), augment_data.create_dict_iterator())): - axis = plt.subplot(141) - plt.imshow(un_aug_item["image"]) - add_bounding_boxes(axis, un_aug_item["annotation"]) # add Orig BBoxes - plt.title("Original" + str(idx + 1)) - logger.info("Original ", str(idx + 1), " :", un_aug_item["annotation"]) - - axis = plt.subplot(142) - plt.imshow(aug_item["image"]) - add_bounding_boxes(axis, aug_item["annotation"]) # add AugBBoxes - plt.title("Augmented" + str(idx + 1)) - logger.info("Augmented ", str(idx + 1), " ", aug_item["annotation"], "\n") - plt.show() - - -def test_random_resize_with_bbox_op(plot=False): - """ - Test random_resize_with_bbox_op - """ - logger.info("Test random resize with bbox") - - # original images - data_original = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - - # augmented images - data_augmented = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - - data_original = data_original.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - - data_augmented = data_augmented.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - - # define map operations - test_op = c_vision.RandomResizeWithBBox(100) # input value being the target size of resizeOp - - data_augmented = data_augmented.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], operations=[test_op]) - if plot: - visualize(data_original, data_augmented) - - -def test_random_resize_with_bbox_invalid_bounds(): - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - check_bad_box(data_voc2, BoxType.WidthOverflow, "bounding boxes is out of bounds of the image") - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - check_bad_box(data_voc2, BoxType.HeightOverflow, "bounding boxes is out of bounds of the image") - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - check_bad_box(data_voc2, BoxType.NegativeXY, "min_x") - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - check_bad_box(data_voc2, BoxType.WrongShape, "4 features") - - -def test_random_resize_with_bbox_invalid_size(): - """ - Test random_resize_with_bbox_op - """ - logger.info("Test random resize with bbox with invalid target size") - - # original images - data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - - data = data.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - - # negative target size as input try: - test_op = c_vision.RandomResizeWithBBox(-10) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM) + # one of the size values is zero + c_vision.RandomResizeWithBBox((0, 100)) - # map to apply ops - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + except ValueError as err: + logger.info("Got an exception in DE: {}".format(str(err))) + assert "Input is not" in str(err) - for _, _ in enumerate(data.create_dict_iterator()): - break - - except ValueError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - print(e) - assert "Input is not" in str(e) - - # zero target size as input try: - test_op = c_vision.RandomResizeWithBBox(0) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM) + # negative value for resize + c_vision.RandomResizeWithBBox(-10) - # map to apply ops - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + except ValueError as err: + logger.info("Got an exception in DE: {}".format(str(err))) + assert "Input is not" in str(err) - for _, _ in enumerate(data.create_dict_iterator()): - break - - except ValueError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not" in str(e) - - # invalid input shape try: - test_op = c_vision.RandomResizeWithBBox((10, 10, 10)) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM) + # invalid input shape + c_vision.RandomResizeWithBBox((100, 100, 100)) - # map to apply ops - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + except TypeError as err: + logger.info("Got an exception in DE: {}".format(str(err))) + assert "Size should be" in str(err) - for _, _ in enumerate(data.create_dict_iterator()): - break - except TypeError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "Size should be" in str(e) +def test_random_resize_with_bbox_op_bad_c(): + """ + Tests RandomResizeWithBBox Op with invalid bounding boxes, expected to catch multiple errors + """ + logger.info("test_random_resize_with_bbox_op_bad_c") + test_op = c_vision.RandomResizeWithBBox((400, 300)) + + data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") + data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") + data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") + data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") + if __name__ == "__main__": - test_random_resize_with_bbox_op(plot=False) - test_random_resize_with_bbox_invalid_bounds() - test_random_resize_with_bbox_invalid_size() + test_random_resize_with_bbox_op_rand_c(plot_vis=False) + test_random_resize_with_bbox_op_edge_c(plot_vis=False) + test_random_resize_with_bbox_op_invalid_c() + test_random_resize_with_bbox_op_bad_c() diff --git a/tests/ut/python/dataset/test_resize_with_bbox.py b/tests/ut/python/dataset/test_resize_with_bbox.py index 8b07f17f1a8..75500de6532 100644 --- a/tests/ut/python/dataset/test_resize_with_bbox.py +++ b/tests/ut/python/dataset/test_resize_with_bbox.py @@ -15,281 +15,151 @@ """ Testing the resize with bounding boxes op in DE """ -from enum import Enum import numpy as np -import matplotlib.patches as patches -import matplotlib.pyplot as plt -import mindspore.dataset.transforms.vision.c_transforms as c_vision -from mindspore import log as logger import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as c_vision + +from mindspore import log as logger +from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ + save_and_check_md5 GENERATE_GOLDEN = False -DATA_DIR = "../data/dataset/testVOC2012" +DATA_DIR = "../data/dataset/testVOC2012_2" def fix_annotate(bboxes): """ + Fix annotations to format followed by mindspore. :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format """ - for bbox in bboxes: - tmp = bbox[0] - bbox[0] = bbox[1] - bbox[1] = bbox[2] - bbox[2] = bbox[3] - bbox[3] = bbox[4] - bbox[4] = tmp + for (i, box) in enumerate(bboxes): + bboxes[i] = np.roll(box, -1) return bboxes -class BoxType(Enum): +def test_resize_with_bbox_op_c(plot_vis=False): """ - Defines box types for test cases + Prints images and bboxes side by side with and without ResizeWithBBox Op applied, + tests with MD5 check, expected to pass """ - WidthOverflow = 1 - HeightOverflow = 2 - NegativeXY = 3 - OnEdge = 4 - WrongShape = 5 + logger.info("test_resize_with_bbox_op_c") + + # Load dataset + dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + decode=True, shuffle=False) + + dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + decode=True, shuffle=False) + + test_op = c_vision.ResizeWithBBox(200) + + dataVoc1 = dataVoc1.map(input_columns=["annotation"], + output_columns=["annotation"], + operations=fix_annotate) + dataVoc2 = dataVoc2.map(input_columns=["annotation"], + output_columns=["annotation"], + operations=fix_annotate) + # map to apply ops + dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], + output_columns=["image", "annotation"], + columns_order=["image", "annotation"], + operations=[test_op]) + + filename = "resize_with_bbox_op_01_c_result.npz" + save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) + + unaugSamp, augSamp = [], [] + + for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): + unaugSamp.append(unAug) + augSamp.append(Aug) + + if plot_vis: + visualize_with_bounding_boxes(unaugSamp, augSamp) -class AddBadAnnotation: # pylint: disable=too-few-public-methods +def test_resize_with_bbox_op_edge_c(plot_vis=False): """ - Used to add erroneous bounding boxes to object detection pipelines. - Usage: - >>> # Adds a box that covers the whole image. Good for testing edge cases - >>> de = de.map(input_columns=["image", "annotation"], - >>> output_columns=["image", "annotation"], - >>> operations=AddBadAnnotation(BoxType.OnEdge)) + Prints images and bboxes side by side with and without ResizeWithBBox Op applied, + applied on dynamically generated edge case, expected to pass. edge case is when bounding + box has dimensions as the image itself. """ + logger.info("test_resize_with_bbox_op_edge_c") + dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + decode=True, shuffle=False) - def __init__(self, box_type): - self.box_type = box_type + dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + decode=True, shuffle=False) - def __call__(self, img, bboxes): - """ - Used to generate erroneous bounding box examples on given img. - :param img: image where the bounding boxes are. - :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format - :return: bboxes with bad examples added - """ - height = img.shape[0] - width = img.shape[1] - if self.box_type == BoxType.WidthOverflow: - # use box that overflows on width - return img, np.array([[0, 0, width + 1, height - 1, 0, 0, 0]]).astype(np.uint32) + test_op = c_vision.ResizeWithBBox(500) - if self.box_type == BoxType.HeightOverflow: - # use box that overflows on height - return img, np.array([[0, 0, width - 1, height + 1, 0, 0, 0]]).astype(np.uint32) + dataVoc1 = dataVoc1.map(input_columns=["annotation"], + output_columns=["annotation"], + operations=fix_annotate) + dataVoc2 = dataVoc2.map(input_columns=["annotation"], + output_columns=["annotation"], + operations=fix_annotate) - if self.box_type == BoxType.NegativeXY: - # use box with negative xy - return img, np.array([[-10, -10, width - 1, height - 1, 0, 0, 0]]).astype(np.uint32) + # maps to convert data into valid edge case data + dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], + output_columns=["image", "annotation"], + columns_order=["image", "annotation"], + operations=[lambda img, bboxes: ( + img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) - if self.box_type == BoxType.OnEdge: - # use box that covers the whole image - return img, np.array([[0, 0, width - 1, height - 1, 0, 0, 0]]).astype(np.uint32) + # Test Op added to list of Operations here + dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], + output_columns=["image", "annotation"], + columns_order=["image", "annotation"], + operations=[lambda img, bboxes: ( + img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op]) - if self.box_type == BoxType.WrongShape: - # use box that covers the whole image - return img, np.array([[0, 0, width - 1]]).astype(np.uint32) - return img, bboxes + unaugSamp, augSamp = [], [] + + for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()): + unaugSamp.append(unAug) + augSamp.append(Aug) + + if plot_vis: + visualize_with_bounding_boxes(unaugSamp, augSamp) -def check_bad_box(data, box_type, expected_error): +def test_resize_with_bbox_op_invalid_c(): + """ + Test ResizeWithBBox Op on invalid constructor parameters, expected to raise ValueError + """ + logger.info("test_resize_with_bbox_op_invalid_c") + try: - test_op = c_vision.ResizeWithBBox(100) - data = data.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - # map to use width overflow - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=AddBadAnnotation(box_type)) # Add column for "annotation" - # map to apply ops - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" - for _, _ in enumerate(data.create_dict_iterator()): - break - except RuntimeError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert expected_error in str(e) + # invalid interpolation value + c_vision.ResizeWithBBox(400, interpolation="invalid") + + except ValueError as err: + logger.info("Got an exception in DE: {}".format(str(err))) + assert "interpolation" in str(err) -def add_bounding_boxes(axis, bboxes): +def test_resize_with_bbox_op_bad_c(): """ - :param axis: axis to modify - :param bboxes: bounding boxes to draw on the axis - :return: None + Tests ResizeWithBBox Op with invalid bounding boxes, expected to catch multiple errors """ - for bbox in bboxes: - rect = patches.Rectangle((bbox[0], bbox[1]), - bbox[2], bbox[3], - linewidth=1, edgecolor='r', facecolor='none') - # Add the patch to the Axes - axis.add_patch(rect) + logger.info("test_resize_with_bbox_op_bad_c") + test_op = c_vision.ResizeWithBBox((200, 300)) - -def visualize(unaugmented_data, augment_data): - for idx, (un_aug_item, aug_item) in enumerate( - zip(unaugmented_data.create_dict_iterator(), augment_data.create_dict_iterator())): - axis = plt.subplot(141) - plt.imshow(un_aug_item["image"]) - add_bounding_boxes(axis, un_aug_item["annotation"]) # add Orig BBoxes - plt.title("Original" + str(idx + 1)) - logger.info("Original ", str(idx + 1), " :", un_aug_item["annotation"]) - - axis = plt.subplot(142) - plt.imshow(aug_item["image"]) - add_bounding_boxes(axis, aug_item["annotation"]) # add AugBBoxes - plt.title("Augmented" + str(idx + 1)) - logger.info("Augmented ", str(idx + 1), " ", aug_item["annotation"], "\n") - plt.show() - - -def test_resize_with_bbox_op(plot=False): - """ - Test resize_with_bbox_op - """ - logger.info("Test resize with bbox") - - # original images - data_original = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - - # augmented images - data_augmented = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - - data_original = data_original.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - - data_augmented = data_augmented.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - - # define map operations - test_op = c_vision.ResizeWithBBox(100) # input value being the target size of resizeOp - - data_augmented = data_augmented.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], operations=[test_op]) - if plot: - visualize(data_original, data_augmented) - - -def test_resize_with_bbox_invalid_bounds(): data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - check_bad_box(data_voc2, BoxType.WidthOverflow, "bounding boxes is out of bounds of the image") + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - check_bad_box(data_voc2, BoxType.HeightOverflow, "bounding boxes is out of bounds of the image") + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - check_bad_box(data_voc2, BoxType.NegativeXY, "min_x") + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - check_bad_box(data_voc2, BoxType.WrongShape, "4 features") + check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") -def test_resize_with_bbox_invalid_size(): - """ - Test resize_with_bbox_op - """ - logger.info("Test resize with bbox with invalid target size") - - # original images - data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - - data = data.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - - # negative target size as input - try: - test_op = c_vision.ResizeWithBBox(-10) - - # map to apply ops - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" - - for _, _ in enumerate(data.create_dict_iterator()): - break - - except ValueError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not" in str(e) - - # zero target size as input - try: - test_op = c_vision.ResizeWithBBox(0) - - # map to apply ops - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" - - for _, _ in enumerate(data.create_dict_iterator()): - break - - except ValueError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not" in str(e) - - # invalid input shape - try: - test_op = c_vision.ResizeWithBBox((10, 10, 10)) - - # map to apply ops - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" - - for _, _ in enumerate(data.create_dict_iterator()): - break - - except TypeError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "Size should be" in str(e) - - -def test_resize_with_bbox_invalid_interpolation(): - """ - Test resize_with_bbox_op - """ - logger.info("Test resize with bbox with invalid interpolation size") - - # original images - data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - - data = data.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - - # invalid interpolation - try: - test_op = c_vision.ResizeWithBBox(100, interpolation="invalid") - - # map to apply ops - data = data.map(input_columns=["image", "annotation"], - output_columns=["image", "annotation"], - columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" - - for _, _ in enumerate(data.create_dict_iterator()): - break - - except ValueError as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "interpolation" in str(e) - if __name__ == "__main__": - test_resize_with_bbox_op(plot=False) - test_resize_with_bbox_invalid_bounds() - test_resize_with_bbox_invalid_size() - test_resize_with_bbox_invalid_interpolation() + test_resize_with_bbox_op_c(plot_vis=False) + test_resize_with_bbox_op_edge_c(plot_vis=False) + test_resize_with_bbox_op_invalid_c() + test_resize_with_bbox_op_bad_c() From 52c58735fc6711169cfba4992578abe2e960e4b4 Mon Sep 17 00:00:00 2001 From: dinghao Date: Sat, 27 Jun 2020 10:01:06 +0800 Subject: [PATCH 15/27] fix serving bugs --- mindspore/ccsrc/CMakeLists.txt | 13 +- mindspore/ccsrc/session/session.cc | 112 ++++++++++++------ mindspore/ccsrc/utils/log_adapter.cc | 32 +++-- serving/core/server.cc | 36 ++++-- serving/core/util/file_system_operation.cc | 5 +- serving/core/util/option_parser.cc | 40 ++++--- serving/core/util/option_parser.h | 3 +- serving/core/version_control/model.cc | 1 - .../version_control/version_controller.cc | 14 +-- .../core/version_control/version_controller.h | 1 - serving/cpp_example/ms_client.cc | 2 +- 11 files changed, 157 insertions(+), 102 deletions(-) diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 8109e608c5c..cc5845cbf15 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -277,10 +277,11 @@ endif () if (USE_GLOG) target_link_libraries(inference PRIVATE mindspore::glog) -else() - if (CMAKE_SYSTEM_NAME MATCHES "Linux") - target_link_options(inference PRIVATE -Wl,-init,mindspore_log_init) - elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") - set_target_properties(inference PROPERTIES MACOSX_RPATH ON) - endif () endif() + +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + target_link_options(inference PRIVATE -Wl,-init,common_log_init) +elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") + set_target_properties(inference PROPERTIES MACOSX_RPATH ON) +endif () + diff --git a/mindspore/ccsrc/session/session.cc b/mindspore/ccsrc/session/session.cc index 90e02b37ff1..ae70fc77aa5 100644 --- a/mindspore/ccsrc/session/session.cc +++ b/mindspore/ccsrc/session/session.cc @@ -33,9 +33,14 @@ namespace py = pybind11; namespace mindspore::inference { std::shared_ptr LoadModel(const char *model_buf, size_t size, const std::string &device) { - inference::Session::RegAllOp(); - auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); - return anf_graph; + try { + inference::Session::RegAllOp(); + auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); + return anf_graph; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference LoadModel failed"; + return nullptr; + } } void ExitInference() { @@ -51,12 +56,17 @@ void ExitInference() { } std::shared_ptr MSSession::CreateSession(const std::string &device, uint32_t device_id) { - auto session = std::make_shared(); - auto ret = session->Init(device, device_id); - if (ret != 0) { + try { + auto session = std::make_shared(); + auto ret = session->Init(device, device_id); + if (ret != 0) { + return nullptr; + } + return session; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference CreatSession failed"; return nullptr; } - return session; } void Session::RegAllOp() { @@ -113,47 +123,71 @@ void Session::RegAllOp() { uint32_t Session::CompileGraph(std::shared_ptr funcGraphPtr) { MS_ASSERT(session_impl_ != nullptr); - auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); - py::gil_scoped_release gil_release; - return graph_id; + try { + auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); + py::gil_scoped_release gil_release; + return graph_id; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference CompileGraph failed"; + return static_cast(-1); + } } MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector> &inputs) { - std::vector inTensors; - inTensors.resize(inputs.size()); - bool has_error = false; - std::transform(inputs.begin(), inputs.end(), inTensors.begin(), - [&has_error](const std::shared_ptr &tensor_ptr) -> tensor::TensorPtr { - if (tensor_ptr == nullptr) { - MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr"; - has_error = true; - return nullptr; - } - auto tensor = static_cast(tensor_ptr.get()); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Can not cast input MSTensor to tensor"; - has_error = true; - return nullptr; - } - return tensor->tensor(); - }); - if (has_error) { - MS_LOG(ERROR) << "Init Tensor failed, returning empty result"; - std::vector> multiTensor; - return multiTensor; + try { + std::vector inTensors; + inTensors.resize(inputs.size()); + bool has_error = false; + std::transform(inputs.begin(), inputs.end(), inTensors.begin(), + [&has_error](const std::shared_ptr &tensor_ptr) -> tensor::TensorPtr { + if (tensor_ptr == nullptr) { + MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr"; + has_error = true; + return nullptr; + } + auto tensor = static_cast(tensor_ptr.get()); + if (tensor == nullptr) { + MS_LOG(ERROR) << "Can not cast input MSTensor to tensor"; + has_error = true; + return nullptr; + } + return tensor->tensor(); + }); + if (has_error) { + MS_LOG(ERROR) << "Init Tensor failed, returning empty result"; + std::vector> multiTensor; + return multiTensor; + } + VectorRef outputs; + session_impl_->RunGraph(graph_id, inTensors, &outputs); + + return TransformVectorRefToMultiTensor(outputs); + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference Rungraph failed"; + return MultiTensor(); } - VectorRef outputs; - session_impl_->RunGraph(graph_id, inTensors, &outputs); - - return TransformVectorRefToMultiTensor(outputs); } - +namespace { +string AjustTargetName(const std::string &device) { + if (device == kAscendDevice) { + return std::string(kAscendDevice) + "Inference"; + } else { + MS_LOG(ERROR) << "Only support device Ascend right now"; + return ""; + } +} +} // namespace int Session::Init(const std::string &device, uint32_t device_id) { RegAllOp(); auto ms_context = MsContext::GetInstance(); ms_context->set_execution_mode(kGraphMode); - ms_context->set_device_target(kAscendDevice); - session_impl_ = session::SessionFactory::Get().Create(device); + ms_context->set_device_id(device_id); + auto ajust_device = AjustTargetName(device); + if (ajust_device == "") { + return -1; + } + ms_context->set_device_target(device); + session_impl_ = session::SessionFactory::Get().Create(ajust_device); if (session_impl_ == nullptr) { MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; return -1; diff --git a/mindspore/ccsrc/utils/log_adapter.cc b/mindspore/ccsrc/utils/log_adapter.cc index d16fbead9bc..3588754dae1 100644 --- a/mindspore/ccsrc/utils/log_adapter.cc +++ b/mindspore/ccsrc/utils/log_adapter.cc @@ -463,7 +463,7 @@ void InitSubModulesLogLevel() { // set submodule's log level auto submodule = GetEnv("MS_SUBMODULE_LOG_v"); - MS_LOG(INFO) << "MS_SUBMODULE_LOG_v=`" << submodule << "`"; + MS_LOG(DEBUG) << "MS_SUBMODULE_LOG_v=`" << submodule << "`"; LogConfigParser parser(submodule); auto configs = parser.Parse(); for (const auto &cfg : configs) { @@ -489,22 +489,14 @@ void InitSubModulesLogLevel() { } // namespace mindspore extern "C" { -// shared lib init hook #if defined(_WIN32) || defined(_WIN64) -__attribute__((constructor)) void mindspore_log_init(void) { +__attribute__((constructor)) void common_log_init(void) { #else -void mindspore_log_init(void) { +void common_log_init(void) { #endif #ifdef USE_GLOG // do not use glog predefined log prefix FLAGS_log_prefix = false; - static bool is_glog_initialzed = false; - if (!is_glog_initialzed) { -#if !defined(_WIN32) && !defined(_WIN64) - google::InitGoogleLogging("mindspore"); -#endif - is_glog_initialzed = true; - } // set default log level to WARNING if (mindspore::GetEnv("GLOG_v").empty()) { FLAGS_v = mindspore::WARNING; @@ -525,4 +517,22 @@ void mindspore_log_init(void) { #endif mindspore::InitSubModulesLogLevel(); } + +// shared lib init hook +#if defined(_WIN32) || defined(_WIN64) +__attribute__((constructor)) void mindspore_log_init(void) { +#else +void mindspore_log_init(void) { +#endif +#ifdef USE_GLOG + static bool is_glog_initialzed = false; + if (!is_glog_initialzed) { +#if !defined(_WIN32) && !defined(_WIN64) + google::InitGoogleLogging("mindspore"); +#endif + is_glog_initialzed = true; + } +#endif + common_log_init(); +} } diff --git a/serving/core/server.cc b/serving/core/server.cc index add9d16bee5..4a3a3b59eb5 100644 --- a/serving/core/server.cc +++ b/serving/core/server.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include "mindspore/ccsrc/utils/log_adapter.h" #include "serving/ms_service.grpc.pb.h" @@ -40,7 +41,7 @@ namespace serving { using MSTensorPtr = std::shared_ptr; Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) { - session_ = inference::MSSession::CreateSession(device + "Inference", device_id); + session_ = inference::MSSession::CreateSession(device, device_id); if (session_ == nullptr) { MS_LOG(ERROR) << "Creat Session Failed"; return FAILED; @@ -67,6 +68,7 @@ Status Session::Predict(const std::vector &inputs, inference::Multi MS_LOG(INFO) << "run Predict"; *outputs = session_->RunGraph(graph_id_, inputs); + MS_LOG(INFO) << "run Predict finished"; return SUCCESS; } @@ -80,12 +82,16 @@ Status Session::Warmup(const MindSporeModelPtr model) { std::string file_name = model->GetModelPath() + '/' + model->GetModelName(); char *graphBuf = ReadFile(file_name.c_str(), &size); if (graphBuf == nullptr) { - MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); + MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str(); return FAILED; } last_graph_ = inference::LoadModel(graphBuf, size, device_type_); + if (last_graph_ == nullptr) { + MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); + return FAILED; + } graph_id_ = session_->CompileGraph(last_graph_); - MS_LOG(INFO) << "Session Warmup"; + MS_LOG(INFO) << "Session Warmup finished"; return SUCCESS; } @@ -95,6 +101,9 @@ Status Session::Clear() { } namespace { +static const uint32_t uint32max = 0x7FFFFFFF; +std::promise exit_requested; + const std::map type2id_map{ {ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool}, {ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8}, @@ -141,7 +150,7 @@ MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) { } TypeId type = iter->second; auto ms_tensor = std::shared_ptr(inference::MSTensor::CreateTensor(type, shape)); - memcpy_s(ms_tensor->MutableData(), tensor.data().size(), tensor.data().data(), tensor.data().size()); + memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size()); return ms_tensor; } @@ -166,10 +175,7 @@ void ClearEnv() { Session::Instance().Clear(); inference::ExitInference(); } -void HandleSignal(int sig) { - ClearEnv(); - exit(0); -} +void HandleSignal(int sig) { exit_requested.set_value(); } #ifdef ENABLE_D static rtContext_t g_ctx = nullptr; @@ -247,6 +253,7 @@ Status Server::BuildAndStart() { rtError_t rt_ret = rtCtxGetCurrent(&ctx); if (rt_ret != RT_ERROR_NONE || ctx == nullptr) { MS_LOG(ERROR) << "the ascend device context is null"; + ClearEnv(); return FAILED; } g_ctx = ctx; @@ -258,6 +265,7 @@ Status Server::BuildAndStart() { auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0); grpc::ServerBuilder builder; builder.SetOption(std::move(option)); + builder.SetMaxMessageSize(uint32max); // Listen on the given address without any authentication mechanism. builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); // Register "service" as the instance through which we'll communicate with @@ -265,13 +273,15 @@ Status Server::BuildAndStart() { builder.RegisterService(&service); // Finally assemble the server. std::unique_ptr server(builder.BuildAndStart()); + auto grpc_server_run = [&server]() { server->Wait(); }; + std::thread serving_thread(grpc_server_run); MS_LOG(INFO) << "Server listening on " << server_address << std::endl; - - // Wait for the server to shutdown. Note that some other thread must be - // responsible for shutting down the server for this call to ever return. - server->Wait(); + auto exit_future = exit_requested.get_future(); + exit_future.wait(); + ClearEnv(); + server->Shutdown(); + serving_thread.join(); return SUCCESS; } - } // namespace serving } // namespace mindspore diff --git a/serving/core/util/file_system_operation.cc b/serving/core/util/file_system_operation.cc index a5143995dec..1af512a54c0 100644 --- a/serving/core/util/file_system_operation.cc +++ b/serving/core/util/file_system_operation.cc @@ -29,7 +29,6 @@ namespace mindspore { namespace serving { - char *ReadFile(const char *file, size_t *size) { if (file == nullptr) { MS_LOG(ERROR) << "file is nullptr"; @@ -70,8 +69,8 @@ bool DirOrFileExist(const std::string &file_path) { } std::vector GetAllSubDirs(const std::string &dir_path) { - DIR *dir; - struct dirent *ptr; + DIR *dir = nullptr; + struct dirent *ptr = nullptr; std::vector SubDirs; if ((dir = opendir(dir_path.c_str())) == NULL) { diff --git a/serving/core/util/option_parser.cc b/serving/core/util/option_parser.cc index 9cbd7eaee8f..c7f00e37338 100644 --- a/serving/core/util/option_parser.cc +++ b/serving/core/util/option_parser.cc @@ -36,17 +36,16 @@ bool RemovePrefix(std::string *str, const std::string &prefix) { bool Option::ParseInt32(std::string *arg) { if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) { - char extra; int32_t parsed_value; - if (sscanf(arg->data(), "%d%c", &parsed_value, &extra) != 1) { - std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl; + try { + parsed_value = std::stoi(arg->data()); + } catch (std::invalid_argument) { + std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl; return false; - } else { - *int32_default_ = parsed_value; } + *int32_default_ = parsed_value; return true; } - return false; } @@ -76,17 +75,16 @@ bool Option::ParseString(std::string *arg) { bool Option::ParseFloat(std::string *arg) { if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) { - char extra; float parsed_value; - if (sscanf(arg->data(), "%f%c", &parsed_value, &extra) != 1) { - std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl; + try { + parsed_value = std::stof(arg->data()); + } catch (std::invalid_argument) { + std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl; return false; - } else { - *float_default_ = parsed_value; } + *float_default_ = parsed_value; return true; } - return false; } @@ -159,10 +157,11 @@ Options::Options() : args_(nullptr) { CreateOptions(); } void Options::CreateOptions() { args_ = std::make_shared(); std::vector