diff --git a/example/resnet50_imagenet2012_THOR/config_imagenet.py b/example/resnet50_imagenet2012_THOR/config.py similarity index 100% rename from example/resnet50_imagenet2012_THOR/config_imagenet.py rename to example/resnet50_imagenet2012_THOR/config.py diff --git a/example/resnet50_imagenet2012_THOR/lr_generator.py b/example/resnet50_imagenet2012_THOR/lr_generator.py deleted file mode 100644 index a447daedd3..0000000000 --- a/example/resnet50_imagenet2012_THOR/lr_generator.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""learning rate generator""" -import math - -import numpy as np - - -def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): - """linear_warmup_lr""" - lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) - lr = float(init_lr) + lr_inc * current_step - return lr - - -def cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0, num_periods=0.5): - """linear_warmup_lr""" - base_lr = lr - warmup_init_lr = 0 - total_steps = int(max_epoch * steps_per_epoch) - warmup_steps = int(warmup_epochs * steps_per_epoch) - decay_steps = total_steps - warmup_steps - lr_each_step = [] - for i in range(total_steps): - if i < warmup_steps: - lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) - else: - # linear_decay = (total_steps - i) / decay_steps - cosine_decay = 0.5 * (1 + math.cos(math.pi * i / decay_steps)) - decayed = cosine_decay - lr = base_lr * decayed - lr_each_step.append(lr) - return np.array(lr_each_step).astype(np.float32) - - -def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0, num_periods=0.5): - """warmup_cosine_annealing_lr""" - base_lr = lr - warmup_init_lr = 0 - total_steps = int(max_epoch * steps_per_epoch * 0.99) - warmup_steps = int(warmup_epochs * steps_per_epoch) - decay_steps = total_steps - warmup_steps - lr_each_step = [] - for i in range(total_steps): - if i < warmup_steps: - lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) - else: - linear_decay = (total_steps - i) / decay_steps - cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * num_periods * i / decay_steps)) - decayed = linear_decay * cosine_decay - lr = base_lr * decayed + 0.000005 - lr_each_step.append(lr) - return np.array(lr_each_step).astype(np.float32) - - -def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode): - """ - generate learning rate array - - Args: - global_step(int): total steps of the training - lr_init(float): init learning rate - lr_end(float): end learning rate - lr_max(float): max learning rate - warmup_epochs(int): number of warmup epochs - total_epochs(int): total epoch of training - steps_per_epoch(int): steps of one epoch - lr_decay_mode(string): learning rate decay mode, including steps, poly or default - - Returns: - np.array, learning rate array - """ - lr_each_step = [] - total_steps = steps_per_epoch * total_epochs - warmup_steps = steps_per_epoch * warmup_epochs - if lr_decay_mode == 'steps': - decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps] - for i in range(total_steps): - if i < decay_epoch_index[0]: - lr = lr_max - elif i < decay_epoch_index[1]: - lr = lr_max * 0.1 - elif i < decay_epoch_index[2]: - lr = lr_max * 0.01 - else: - lr = lr_max * 0.001 - lr_each_step.append(lr) - elif lr_decay_mode == 'poly': - if warmup_steps != 0: - inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) - else: - inc_each_step = 0 - for i in range(total_steps): - if i < warmup_steps: - lr = float(lr_init) + inc_each_step * float(i) - else: - base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) - lr = float(lr_max) * base * base - if lr < 0.0: - lr = 0.0 - lr_each_step.append(lr) - else: - for i in range(total_steps): - if i < warmup_steps: - lr = lr_init + (lr_max - lr_init) * i / warmup_steps - else: - lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps) - lr_each_step.append(lr) - - current_step = global_step - lr_each_step = np.array(lr_each_step).astype(np.float32) - learning_rate = lr_each_step[current_step:] - - return learning_rate diff --git a/example/resnet50_imagenet2012_THOR/model/dataset_helper.py b/example/resnet50_imagenet2012_THOR/model/dataset_helper.py index b8efd9f682..474bccf42f 100644 --- a/example/resnet50_imagenet2012_THOR/model/dataset_helper.py +++ b/example/resnet50_imagenet2012_THOR/model/dataset_helper.py @@ -13,12 +13,10 @@ # limitations under the License. # ============================================================================ """Dataset help for minddata dataset""" -from mindspore import context from mindspore._checkparam import check_bool -from mindspore.nn.wrap import GetNextSingleOp -from mindspore.parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode -from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \ - _construct_tensor_list, _to_full_shapes, _to_full_tensor +from mindspore.parallel._utils import _get_device_num, _get_parallel_mode +from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \ + _to_full_shapes from mindspore.train.parallel_utils import ParallelMode @@ -42,19 +40,9 @@ class DatasetHelper: >>> outputs = network(*inputs) """ - def __init__(self, dataset, first_order_iter=0, dataset_sink_mode=True): + def __init__(self, dataset, dataset_sink_mode=True, iter_first_order=0): check_bool(dataset_sink_mode) - - iterclass = _DatasetIterGE - if not dataset_sink_mode: - iterclass = _DatasetIterFeed - elif not context.get_context("enable_ge"): - if context.get_context("enable_loop_sink"): - iterclass = _DatasetIterMSLoopSink - else: - iterclass = _DatasetIterMS - - self.iter = iterclass(dataset, first_order_iter) + self.iter = _DatasetIterMSLoopSink(dataset, iter_first_order) def __iter__(self): return self.iter.__iter__() @@ -85,12 +73,6 @@ class _DatasetIter: self.dataset = dataset dataset_types, dataset_shapes = _get_types_and_shapes(dataset) self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes - # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to - # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number - # times the batch dimension of tensors for run - if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): - device_num = _get_device_num() - self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num) def __iter__(self): self.ind = 0 @@ -109,83 +91,28 @@ class _DatasetIter: loop_count = 1 if hasattr(dataset, '__loop_size__'): loop_size = dataset.__loop_size__ + if dataset.get_dataset_size() % loop_size != 0: + raise ValueError(f'Dataset size {dataset.get_dataset_size()} and ' + f'loop_size {loop_size} are not matched.') loop_count = int(dataset.get_dataset_size() / loop_size) return loop_count class _DatasetIterMSLoopSink(_DatasetIter): - """Iter for context (enable_loop_sink=True)""" + """Iter for context (device_target=Ascend)""" - def __init__(self, dataset, first_order_iter): + def __init__(self, dataset, iter_first_order): super(_DatasetIterMSLoopSink, self).__init__(dataset) - # self.loop_count = self.get_loop_count(dataset) - loop_size = dataset.__loop_size__ + first_order_iter + loop_size = dataset.__loop_size__ + iter_first_order self.loop_count = int(dataset.get_dataset_size() / loop_size) * 2 + # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to + # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number + # times the batch dimension of tensors for run. Now only support LoopSink. + if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + device_num = _get_device_num() + self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num) def op(): return tuple() self.op = op - - -class _DatasetIterMS(_DatasetIter): - """Iter for context (enable_loop_sink=False)""" - - def __init__(self, dataset, first_order_order): - super(_DatasetIterMS, self).__init__(dataset) - self.loop_count = dataset.get_dataset_size() - self.loop_size = 1 - queue_name = dataset.__ME_INITED__ - self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) - - -class _DatasetIterGE(_DatasetIter): - """Iter for ge""" - - def __init__(self, dataset): - super(_DatasetIterGE, self).__init__(dataset) - self.loop_count = self.get_loop_count(dataset) - parallel_mode = _get_parallel_mode() - self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) - batch_expand_num = 1 - if self.need_to_full: - batch_expand_num = _get_device_num() - tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num) - - def op(): - return tensor_list_run - - self.op = op - - -class _DatasetIterFeed: - """Iter for feed data""" - - def __init__(self, dataset, first_order_order): - self.dataset = dataset - self.device_num = _get_device_num() - self.global_rank = _get_global_rank() - self.repeat_count = dataset.get_repeat_count() - self.repeat_ind = 0 - self.loop_count = dataset.get_dataset_size() - self.ind = 0 - - parallel_mode = context.get_auto_parallel_context("parallel_mode") - self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) - - def __iter__(self): - if self.repeat_ind % self.repeat_count == 0: - self.iter = self.dataset.__iter__() - - self.repeat_ind += 1 - self.ind = 0 - return self - - def __next__(self): - if self.ind >= self.loop_count: - raise StopIteration() - self.ind += 1 - data = self.iter.__next__() - if self.need_to_full: - return _to_full_tensor(data, self.device_num, self.global_rank) - return _to_tensor(data) diff --git a/example/resnet50_imagenet2012_THOR/model/model_thor.py b/example/resnet50_imagenet2012_THOR/model/model_thor.py index 613d15468f..f3418437a3 100644 --- a/example/resnet50_imagenet2012_THOR/model/model_thor.py +++ b/example/resnet50_imagenet2012_THOR/model/model_thor.py @@ -13,8 +13,11 @@ # limitations under the License. # ============================================================================ """Model.""" + +import numpy as np from mindspore import context from mindspore import log as logger +from mindspore import nn from mindspore._c_expression import init_exec_dataset from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool from mindspore.common import dtype as mstype @@ -28,9 +31,9 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_ from mindspore.train import amp from mindspore.train.callback import _InternalCallbackParam, RunContext, _build_callbacks from mindspore.train.parallel_utils import ParallelMode -import mindspore.nn as nn -from second_order.dataset_helper import DatasetHelper -import numpy as np + +from model.dataset_helper import DatasetHelper + def _convert_type(types): """ @@ -69,7 +72,8 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): dataset_types, dataset_shapes, input_indexs, - phase=phase) + phase=phase, + need_run=False) class Model: @@ -123,7 +127,7 @@ class Model: >>> return out >>> >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) >>> dataset = get_dataset() @@ -131,29 +135,35 @@ class Model: """ def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, - eval_indexes=None, amp_level="O0", frequency=278, **kwargs): + eval_indexes=None, amp_level="O0", frequency=278, stop_epoch=100, **kwargs): self._network = network self._loss_fn = loss_fn self._optimizer = optimizer self._loss_scale_manager = None self._loss_scale_manager_set = False self._keep_bn_fp32 = True - self._frequency = frequency self._check_kwargs(kwargs) + self._amp_level = amp_level + self._process_amp_args(kwargs) + self._parallel_mode = _get_parallel_mode() + self._device_number = _get_device_num() + self._global_rank = _get_global_rank() + self._parameter_broadcast = _get_parameter_broadcast() + self._frequency = frequency + self._stop_epoch = stop_epoch + + self._train_network = self._build_train_network() + self._build_eval_network(metrics, eval_network, eval_indexes) + self._build_predict_network() + + def _process_amp_args(self, kwargs): + if self._amp_level == "O0": + self._keep_bn_fp32 = False if 'keep_batchnorm_fp32' in kwargs: self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32'] if 'loss_scale_manager' in kwargs: self._loss_scale_manager = kwargs['loss_scale_manager'] self._loss_scale_manager_set = True - self._amp_level = amp_level - self._parallel_mode = _get_parallel_mode() - self._device_number = _get_device_num() - self._global_rank = _get_global_rank() - self._parameter_broadcast = _get_parameter_broadcast() - - self._train_network = self._build_train_network() - self._build_eval_network(metrics, eval_network, eval_indexes) - self._build_predict_network() def _check_kwargs(self, kwargs): for arg in kwargs: @@ -180,6 +190,9 @@ class Model: elif self._loss_fn: network = nn.WithLossCell(network, self._loss_fn) # If need to check if loss_fn is not None, but optimizer is None + + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + network.set_auto_parallel() return network def _build_eval_network(self, metrics, eval_network, eval_indexes): @@ -198,14 +211,18 @@ class Model: else: if self._loss_fn is None: raise ValueError("loss_fn can not be None.") - self._eval_network = nn.WithEvalCell(self._network, self._loss_fn) + self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2") self._eval_indexes = [0, 1, 2] + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + self._eval_network.set_auto_parallel() + def _build_predict_network(self): """Build the network for prediction.""" self._predict_network = self._network if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): self._predict_network = _VirtualDatasetCell(self._network) + self._predict_network.set_auto_parallel() def _clear_metrics(self): """Clear metrics local values.""" @@ -246,6 +263,94 @@ class Model: scaling_sens /= self._device_number return scaling_sens + def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order): + """Initializes dataset.""" + need_wrap = False + if dataset_sink_mode: + # remove later to deal with loop sink + if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \ + and not context.get_context("enable_ge"): + need_wrap = True + + if not is_train: + dataset.__loop_size__ = 1 + + dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order) + + # remove later to deal with loop sink + if need_wrap: + network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__) + network.set_train(is_train) + network.phase = phase + + return dataset_helper, network + + def init(self, train_dataset=None, valid_dataset=None): + """ + Initializes compute graphs and data graphs with sink mode. + + Note: + Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently. + + Args: + train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be + initialized. Default: None. + valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will + be initialized, and `metrics` in `Model` can not be None. Default: None. + + Examples: + >>> train_dataset = get_train_dataset() + >>> valid_dataset = get_valid_dataset() + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'}) + >>> model.init(train_dataset, valid_dataset) + >>> model.train(2, train_dataset) + >>> model.eval(valid_dataset) + """ + if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend": + raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.') + + if not train_dataset and not valid_dataset: + raise ValueError('Both train_dataset and valid_dataset can not be None or empty.') + + _device_number_check(self._parallel_mode, self._device_number) + + if train_dataset: + _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) + self._train_network.set_train() + self._train_network.phase = 'train' + + if self._parameter_broadcast: + self._train_network.set_broadcast_flag() + + train_dataset_helper, train_network = self._exec_preprocess(self._train_network, + is_train=True, + phase='train', + dataset=train_dataset, + dataset_sink_mode=True) + self._train_network = train_network + for inputs in train_dataset_helper: + self._train_network.compile(*inputs) + break + + if valid_dataset: + if not self._metric_fns: + raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.') + + self._eval_network.set_train(False) + self._eval_network.phase = 'eval' + valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network, + is_train=False, + phase='eval', + dataset=valid_dataset, + dataset_sink_mode=True) + self._eval_network = eval_network + for inputs in valid_dataset_helper: + self._eval_network.compile(*inputs) + break + def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True): """ Training. @@ -306,32 +411,27 @@ class Model: list_callback (_ListCallback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. """ - # remove later to deal with loop sink - iter_first_order = 277 + iter_first_order = self._frequency - 1 iter_second_order = 1 train_dataset.__loop_size__ = iter_second_order - need_wrap = False - if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \ - and not context.get_context("enable_ge"): - need_wrap = True - - dataset_helper = DatasetHelper(train_dataset, iter_first_order) - # remove later to deal with loop sink - if need_wrap: - self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()), - train_dataset.__ME_INITED__) - cb_params.train_network = self._train_network - self._train_network.set_train() - + dataset_helper, train_network = self._exec_preprocess(self._train_network, + is_train=True, + phase='train', + dataset=train_dataset, + dataset_sink_mode=True, + iter_first_order=iter_first_order) + self._train_network = train_network + cb_params.train_network = self._train_network cb_params.cur_step_num = 0 + loop_size = dataset_helper.loop_size() run_context = RunContext(cb_params) list_callback.begin(run_context) # used to stop training for early stop, such as stopAtTIme or stopATStep should_stop = False - has_do_train1_dataset = False - checkpoint_branch_one = True + has_do_dataset_init = False + switch_branch_one = True for i in range(epoch): cb_params.cur_epoch_num = i + 1 list_callback.epoch_begin(run_context) @@ -339,18 +439,18 @@ class Model: # for data sink dataset_helper only iter once, other wise iter epoch_size times. for inputs in dataset_helper: list_callback.step_begin(run_context) - if checkpoint_branch_one: + if switch_branch_one: cb_params.cur_step_num += loop_size - self._train_network.set_second_order(True) + self._train_network.add_flags_recursive(thor=True) self._train_network.phase = 'train0' else: cb_params.cur_step_num += iter_first_order - self._train_network.set_second_order(False) + self._train_network.add_flags_recursive(thor=False) self._train_network.phase = 'train1' - if not has_do_train1_dataset: + if not has_do_dataset_init: _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset') - has_do_train1_dataset = True - checkpoint_branch_one = not checkpoint_branch_one + has_do_dataset_init = True + switch_branch_one = not switch_branch_one outputs = self._train_network(*inputs) cb_params.net_outputs = outputs list_callback.step_end(run_context) @@ -376,17 +476,21 @@ class Model: list_callback (_ListCallback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. """ - dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False) + dataset_helper, _ = self._exec_preprocess(self._train_network, + is_train=True, + phase='train', + dataset=train_dataset, + dataset_sink_mode=False) cb_params.cur_step_num = 0 run_context = RunContext(cb_params) - _callback_wrapper(list_callback, run_context, "begin") + list_callback.begin(run_context) # used to stop training for early stop, such as stopAtTIme or stopATStep should_stop = False for i in range(epoch): cb_params.cur_epoch_num = i + 1 - _callback_wrapper(list_callback, run_context, "epoch_begin") + list_callback.epoch_begin(run_context) for next_element in dataset_helper: len_element = len(next_element) @@ -394,7 +498,7 @@ class Model: raise ValueError("when loss_fn is not None, train_dataset should" "return two elements, but got {}".format(len_element)) cb_params.cur_step_num += 1 - _callback_wrapper(list_callback, run_context, "step_begin") + list_callback.step_begin(run_context) overflow = False if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): @@ -408,19 +512,19 @@ class Model: overflow = np.all(overflow.asnumpy()) self._loss_scale_manager.update_loss_scale(overflow) - _callback_wrapper(list_callback, run_context, "step_end") + list_callback.step_end(run_context) should_stop = should_stop or run_context.get_stop_requested() if should_stop: break train_dataset.reset() - _callback_wrapper(list_callback, run_context, "epoch_end") + list_callback.epoch_end(run_context) should_stop = should_stop or run_context.get_stop_requested() if should_stop: break - _callback_wrapper(list_callback, run_context, "end") + list_callback.end(run_context) def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True): """ @@ -452,7 +556,7 @@ class Model: Examples: >>> dataset = get_dataset() >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) >>> loss_scale_manager = FixedLossScaleManager() >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager) @@ -465,9 +569,6 @@ class Model: _device_number_check(self._parallel_mode, self._device_number) _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) - if context.get_context("device_target") in ["CPU", "GPU"] and context.get_context("enable_loop_sink"): - raise ValueError("CPU and GPU can't support loop sink, please set enable_loop_sink=False.") - self._train(epoch, train_dataset, callbacks=callbacks, @@ -485,25 +586,15 @@ class Model: Returns: Dict, returns the loss value & metrics values for the model in test mode. """ - _device_number_check(self._parallel_mode, self._device_number) - run_context = RunContext(cb_params) - # remove later to deal with loop sink - need_wrap = False - if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \ - and not context.get_context("enable_ge"): - need_wrap = True - - valid_dataset.__loop_size__ = 1 - dataset_helper = DatasetHelper(valid_dataset) - - # remove later to deal with loop sink - if need_wrap: - self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()), - valid_dataset.__ME_INITED__) - self._eval_network.set_train(mode=False) - self._eval_network.phase = 'eval' + dataset_helper, eval_network = self._exec_preprocess(self._eval_network, + is_train=False, + phase='eval', + dataset=valid_dataset, + dataset_sink_mode=True) + self._eval_network = eval_network + cb_params.eval_network = self._eval_network list_callback.begin(run_context) for inputs in dataset_helper: @@ -537,7 +628,11 @@ class Model: run_context = RunContext(cb_params) list_callback.begin(run_context) - dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False) + dataset_helper, _ = self._exec_preprocess(self._eval_network, + is_train=False, + phase='eval', + dataset=valid_dataset, + dataset_sink_mode=False) for next_element in dataset_helper: cb_params.cur_step_num += 1 list_callback.step_begin(run_context) @@ -574,11 +669,12 @@ class Model: Examples: >>> dataset = get_dataset() >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'}) >>> model.eval(dataset) """ check_bool(dataset_sink_mode) + _device_number_check(self._parallel_mode, self._device_number) if not self._metric_fns: raise ValueError("metric fn can not be None or empty.") diff --git a/example/resnet50_imagenet2012_THOR/model/resnet.py b/example/resnet50_imagenet2012_THOR/model/resnet.py index fb05341348..f3305022e8 100644 --- a/example/resnet50_imagenet2012_THOR/model/resnet.py +++ b/example/resnet50_imagenet2012_THOR/model/resnet.py @@ -14,22 +14,24 @@ # ============================================================================ """ResNet.""" import math - -import mindspore.nn as nn import numpy as np +import mindspore.nn as nn from mindspore.common.tensor import Tensor from mindspore.ops import operations as P -from second_order.thor_layer import Conv2d_Thor, Dense_Thor + +from model.thor_layer import Conv2d_Thor, Dense_Thor def calculate_gain(nonlinearity, param=None): + """calculate_gain""" linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] + res = 0 if nonlinearity in linear_fns or nonlinearity == 'sigmoid': - return 1 + res = 1 elif nonlinearity == 'tanh': - return 5.0 / 3 + res = 5.0 / 3 elif nonlinearity == 'relu': - return math.sqrt(2.0) + res = math.sqrt(2.0) elif nonlinearity == 'leaky_relu': if param is None: negative_slope = 0.01 @@ -38,16 +40,17 @@ def calculate_gain(nonlinearity, param=None): negative_slope = param else: raise ValueError("negative_slope {} not a valid number".format(param)) - return math.sqrt(2.0 / (1 + negative_slope ** 2)) + res = math.sqrt(2.0 / (1 + negative_slope ** 2)) else: raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + return res def _calculate_fan_in_and_fan_out(tensor): + """_calculate_fan_in_and_fan_out""" dimensions = len(tensor) if dimensions < 2: raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") - if dimensions == 2: # Linear fan_in = tensor[1] fan_out = tensor[0] @@ -67,7 +70,6 @@ def _calculate_correct_fan(tensor, mode): valid_modes = ['fan_in', 'fan_out'] if mode not in valid_modes: raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) return fan_in if mode == 'fan_in' else fan_out @@ -93,8 +95,6 @@ def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, freq return Conv2d_Thor(in_channel, out_channel, kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight, damping=damping, loss_scale=loss_scale, frequency=frequency) - # return nn.Conv2d(in_channel, out_channel, - # kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight) def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278): @@ -125,7 +125,7 @@ def _bn_last(channel): def _fc(in_channel, out_channel, damping, loss_scale, frequency): weight_shape = (out_channel, in_channel) - weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)) + weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5))) return Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight, bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency) @@ -133,15 +133,15 @@ def _fc(in_channel, out_channel, damping, loss_scale, frequency): class ResidualBlock(nn.Cell): """ ResNet V1 residual block definition. - + Args: in_channel (int): Input channel. out_channel (int): Output channel. stride (int): Stride size for the first convolutional layer. Default: 1. - + Returns: Tensor, output tensor. - + Examples: >>> ResidualBlock(3, 256, stride=2) """ @@ -210,7 +210,7 @@ class ResidualBlock(nn.Cell): class ResNet(nn.Cell): """ ResNet architecture. - + Args: block (Cell): Block for network. layer_nums (list): Numbers of block in different layers. @@ -220,7 +220,7 @@ class ResNet(nn.Cell): num_classes (int): The number of classes that the training images are belonging to. Returns: Tensor, output tensor. - + Examples: >>> ResNet(ResidualBlock, >>> [3, 4, 6, 3], @@ -290,17 +290,17 @@ class ResNet(nn.Cell): damping, loss_scale, frequency): """ Make stage network of ResNet. - + Args: block (Cell): Resnet block. layer_num (int): Layer number. in_channel (int): Input channel. out_channel (int): Output channel. stride (int): Stride size for the first convolutional layer. - + Returns: SequentialCell, the output layer. - + Examples: >>> _make_layer(ResidualBlock, 3, 128, 256, 2) """ @@ -321,7 +321,7 @@ class ResNet(nn.Cell): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) - c1, argmax = self.maxpool(x) + c1, _ = self.maxpool(x) c2 = self.layer1(c1) c3 = self.layer2(c2) @@ -338,13 +338,13 @@ class ResNet(nn.Cell): def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278): """ Get ResNet50 neural network. - + Args: class_num (int): Class number. - + Returns: Cell, cell instance of ResNet50 neural network. - + Examples: >>> net = resnet50(10) """ diff --git a/example/resnet50_imagenet2012_THOR/run_distribute_train_new.sh b/example/resnet50_imagenet2012_THOR/run_distribute_train.sh similarity index 92% rename from example/resnet50_imagenet2012_THOR/run_distribute_train_new.sh rename to example/resnet50_imagenet2012_THOR/run_distribute_train.sh index 3179a5b3a8..ae05c45dfe 100644 --- a/example/resnet50_imagenet2012_THOR/run_distribute_train_new.sh +++ b/example/resnet50_imagenet2012_THOR/run_distribute_train.sh @@ -51,6 +51,6 @@ do echo "start training for rank $RANK_ID, device $DEVICE_ID" env > env.log - python train_0517_1.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 & + python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 & cd .. done diff --git a/example/resnet50_imagenet2012_THOR/train.py b/example/resnet50_imagenet2012_THOR/train.py index 3843338a51..b98d13b8a0 100644 --- a/example/resnet50_imagenet2012_THOR/train.py +++ b/example/resnet50_imagenet2012_THOR/train.py @@ -17,7 +17,6 @@ import argparse import os import random -import mindspore.dataset.engine as de from mindspore import Tensor from mindspore import context from mindspore.communication.management import init @@ -25,19 +24,17 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.model import ParallelMode -from second_order.model_second_order import Model -from second_order.resnet import resnet50 -from second_order.thor import THOR +from model.model_thor import Model +from model.resnet import resnet50 +from model.thor import THOR import numpy as np -from config_imagenet import config +from config import config from crossentropy import CrossEntropy from dataset_imagenet import create_dataset -from lr_generator import warmup_cosine_annealing_lr random.seed(1) np.random.seed(1) -de.config.set_seed(1) parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') @@ -50,29 +47,29 @@ args_opt = parser.parse_args() device_id = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=device_id) -context.set_context(enable_task_sink=True) -context.set_context(enable_loop_sink=True) -context.set_context(enable_mem_reuse=True) -def get_second_order_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch): - """get_second_order_lr""" +def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch): + """get_model_lr""" lr_each_step = [] total_steps = steps_per_epoch * total_epochs for i in range(total_steps): epoch = (i + 1) / steps_per_epoch base = (1.0 - float(epoch) / total_epochs) ** decay lr_local = lr_init * base + if epoch >= 39: + lr_local = lr_local * 0.5 + if epoch >= 40: + lr_local = lr_local * 0.5 lr_each_step.append(lr_local) current_step = global_step lr_each_step = np.array(lr_each_step).astype(np.float32) - print("learning_rate_is=====", lr_each_step) learning_rate = lr_each_step[current_step:] return learning_rate -def get_second_order_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch): - """get_second_order_damping""" +def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch): + """get_model_damping""" damping_each_step = [] total_steps = steps_per_epoch * total_epochs for step in range(total_steps): @@ -83,26 +80,23 @@ def get_second_order_damping(global_step, damping_init, decay_rate, total_epochs current_step = global_step damping_each_step = np.array(damping_each_step).astype(np.float32) damping_now = damping_each_step[current_step:] - print("damping_is=========", damping_now) return damping_now if __name__ == '__main__': - if args_opt.do_eval: - print("eval") - else: - if args_opt.run_distribute: - context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True, parameter_broadcast=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([80], "hccl_world_groupsum1") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4") - init() - else: - print(" ") + if not args_opt.do_eval and args_opt.run_distribute: + context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True, parameter_broadcast=True) + auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1") + auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2") + auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3") + auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4") + auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5") + + init() epoch_size = config.epoch_size - damping = get_second_order_damping(0, 0.03, 0.87, 50, 5004) + damping = get_model_damping(0, 0.03, 0.87, 50, 5004) net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale, frequency=config.frequency) @@ -115,17 +109,12 @@ if __name__ == '__main__': step_size = dataset.get_dataset_size() loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) - lr = Tensor(warmup_cosine_annealing_lr(0.035, - step_size, - config.warmup_epochs, - 50, - config.T_max, - config.eta_min)) - opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, - config.momentum, damping, config.frequency, + lr = Tensor(get_model_lr(0, 0.05, 6, 70, 5004)) + opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, filter(lambda x: 'matrix_A' in x.name, net.get_parameters()), filter(lambda x: 'matrix_G' in x.name, net.get_parameters()), - filter(lambda x: 'spatial_norm' in x.name, net.get_parameters()), + filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()), + filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()), config.weight_decay, config.loss_scale) model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale, diff --git a/mindspore/ops/_op_impl/custom_op/batch_matmul_impl.py b/mindspore/ops/_op_impl/custom_op/batch_matmul_impl.py new file mode 100644 index 0000000000..e2afa96a7d --- /dev/null +++ b/mindspore/ops/_op_impl/custom_op/batch_matmul_impl.py @@ -0,0 +1,76 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""batch_matmul_impl""" +from mindspore.ops.op_info_register import op_info_register + + +@op_info_register("""{ + "op_name": "CusBatchMatMul", + "imply_type": "TBE", + "fusion_type": "OPAQUE", + "async_flag": false, + "binfile_name": "batchmatmul.so", + "compute_cost": 10, + "kernel_name": "CusBatchMatMul", + "partial_flag": true, + "attr": [ + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float32" + ], + "format": [ + "DefaultFormat" + ], + "name": "x1", + "need_compile": false, + "param_type": "required", + "shape": "all" + }, + { + "index": 1, + "dtype": [ + "float32" + ], + "format": [ + "DefaultFormat" + ], + "name": "x2", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float32" + ], + "format": [ + "DefaultFormat" + ], + "name": "y", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ] +}""") +def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"): + """CusBatchMatMul""" + return diff --git a/mindspore/ops/_op_impl/custom_op/cholesky_trsm.py b/mindspore/ops/_op_impl/custom_op/cholesky_trsm.py new file mode 100644 index 0000000000..5c38dfc25d --- /dev/null +++ b/mindspore/ops/_op_impl/custom_op/cholesky_trsm.py @@ -0,0 +1,64 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusCholeskyTrsm""" +from mindspore.ops.op_info_register import op_info_register + + +@op_info_register("""{ + "op_name": "CusCholeskyTrsm", + "imply_type": "TBE", + "fusion_type": "OPAQUE", + "async_flag": false, + "binfile_name": "choleskytrsm.so", + "compute_cost": 10, + "kernel_name": "CusCholeskyTrsm", + "partial_flag": true, + "attr": [ + + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float32" + ], + "format": [ + "DefaultFormat" + ], + "name": "x1", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float32" + ], + "format": [ + "DefaultFormat" + ], + "name": "y", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ] +}""") +def CusCholeskyTrsm(input_x, output, kernel_name): + """CusCholeskyTrsm""" + return diff --git a/mindspore/ops/_op_impl/custom_op/fused_abs_max1.py b/mindspore/ops/_op_impl/custom_op/fused_abs_max1.py new file mode 100644 index 0000000000..b9a0d45273 --- /dev/null +++ b/mindspore/ops/_op_impl/custom_op/fused_abs_max1.py @@ -0,0 +1,69 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusFusedAbsMax1""" +from mindspore.ops.op_info_register import op_info_register + + +@op_info_register("""{ + "op_name": "CusFusedAbsMax1", + "imply_type": "TBE", + "fusion_type": "OPAQUE", + "async_flag": false, + "binfile_name": "fusedabsmax1.so", + "compute_cost": 10, + "kernel_name": "CusFusedAbsMax1", + "partial_flag": true, + "attr": [ + { + "name": "origin_shape", + "param_type": "required", + "type": "listInt", + "value": "all" + } + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float32" + ], + "format": [ + "DefaultFormat" + ], + "name": "x1", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float32" + ], + "format": [ + "DefaultFormat" + ], + "name": "y", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ] +}""") +def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"): + """CusFusedAbsMax1""" + return diff --git a/mindspore/ops/_op_impl/custom_op/img2col_impl.py b/mindspore/ops/_op_impl/custom_op/img2col_impl.py new file mode 100644 index 0000000000..5137d4d7e7 --- /dev/null +++ b/mindspore/ops/_op_impl/custom_op/img2col_impl.py @@ -0,0 +1,87 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusImg2ColNC1HWC0""" +from mindspore.ops.op_info_register import op_info_register + + +@op_info_register("""{ + "op_name": "CusImg2ColNC1HWC0", + "imply_type": "TBE", + "fusion_type": "OPAQUE", + "async_flag": false, + "binfile_name": "img2colnc1hwc0.so", + "compute_cost": 10, + "kernel_name": "CusImg2ColNC1HWC0", + "partial_flag": true, + "attr": [ + { + "name": "ksizes", + "param_type": "required", + "type": "listInt", + "value": "all" + }, + { + "name": "strides", + "param_type": "required", + "type": "listInt", + "value": "all" + }, + { + "name": "dilates", + "param_type": "required", + "type": "listInt", + "value": "all" + }, + { + "name": "padding", + "param_type": "required", + "type": "str", + "value": "all" + } + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "NC1HWC0" + ], + "name": "x1", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "y", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ] +}""") +def CusImg2ColNC1HWC0(input_x, output, ksizes, strides, dilates, padding, kernel_name="img2col"): + """CusImg2ColNC1HWC0""" + return diff --git a/mindspore/ops/_op_impl/custom_op/matmul_cube_dense_left.py b/mindspore/ops/_op_impl/custom_op/matmul_cube_dense_left.py new file mode 100644 index 0000000000..300410eb4a --- /dev/null +++ b/mindspore/ops/_op_impl/custom_op/matmul_cube_dense_left.py @@ -0,0 +1,101 @@ +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import + +from mindspore.ops.op_info_register import op_info_register +from topi.cce import util + +# General limitation of the size for input shape: 2**31 +SHAPE_SIZE_LIMIT = 2147483648 +NoneType = type(None) + + +@op_info_register("""{ + "op_name": "CusMatMulCubeDenseLeft", + "imply_type": "TBE", + "fusion_type": "OPAQUE", + "async_flag": false, + "binfile_name": "matmulcubedenseleft.so", + "compute_cost": 10, + "kernel_name": "CusMatMulCubeDenseLeft", + "partial_flag": true, + "attr": [ + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "DefaultFormat" + ], + "name": "x1", + "need_compile": false, + "param_type": "required", + "shape": "all" + }, + { + "index": 1, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "x2", + "need_compile": false, + "param_type": "required", + "shape": "all" + }, + { + "index": 2, + "dtype": [ + "float16" + ], + "format": [ + "DefaultFormat" + ], + "name": "x3", + "need_compile": false, + "param_type": "optional", + "shape": "all" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "y", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ] +}""") +@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) +def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, + kernel_name="matmulcube"): + """CusMatMulCubeDenseLeft""" + return diff --git a/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_left_cast_impl.py b/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_left_cast_impl.py new file mode 100644 index 0000000000..3da1593dfd --- /dev/null +++ b/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_left_cast_impl.py @@ -0,0 +1,102 @@ +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import + +from mindspore.ops.op_info_register import op_info_register +from topi.cce import util + +# General limitation of the size for input shape: 2**31 +SHAPE_SIZE_LIMIT = 2147483648 +NoneType = type(None) + + +@op_info_register("""{ + "op_name": "CusMatMulCubeFraczLeftCast", + "imply_type": "TBE", + "fusion_type": "OPAQUE", + "async_flag": false, + "binfile_name": "matmulcubefraczleftcast.so", + "compute_cost": 10, + "kernel_name": "CusMatMulCubeFraczLeftCast", + "partial_flag": true, + "attr": [ + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "DefaultFormat" + ], + "name": "x1", + "need_compile": false, + "param_type": "required", + "shape": "all" + }, + { + "index": 1, + "dtype": [ + "float32" + ], + "format": [ + "FracZ" + ], + "name": "x2", + "need_compile": false, + "param_type": "required", + "shape": "all" + }, + { + "index": 2, + "dtype": [ + "float16" + ], + "format": [ + "DefaultFormat" + ], + "name": "x3", + "need_compile": false, + "param_type": "optional", + "shape": "all" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "FracZ" + ], + "name": "y", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ] +}""") +# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements +@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) +def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, + kernel_name="CusMatMulCubeFraczLeftCast"): + """CusMatMulCubeFraczLeftCast""" + return diff --git a/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_right_mul_impl.py b/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_right_mul_impl.py new file mode 100644 index 0000000000..7fc2ba35d1 --- /dev/null +++ b/mindspore/ops/_op_impl/custom_op/matmul_cube_fracz_right_mul_impl.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import + +from mindspore.ops.op_info_register import op_info_register + +# General limitation of the size for input shape: 2**31 +SHAPE_SIZE_LIMIT = 2147483648 +NoneType = type(None) + + +@op_info_register("""{ + "op_name": "CusMatMulCubeFraczRightMul", + "imply_type": "TBE", + "fusion_type": "OPAQUE", + "async_flag": false, + "binfile_name": "matmulcubefraczrightmul.so", + "compute_cost": 10, + "kernel_name": "CusMatMulCubeFraczRightMul", + "partial_flag": true, + "attr": [ + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "FracZ" + ], + "name": "x1", + "need_compile": false, + "param_type": "required", + "shape": "all" + }, + { + "index": 1, + "dtype": [ + "float16" + ], + "format": [ + "DefaultFormat" + ], + "name": "x2", + "need_compile": false, + "param_type": "required", + "shape": "all" + }, + { + "index": 2, + "dtype": [ + "float32" + ], + "format": [ + "DefaultFormat" + ], + "name": "x3", + "need_compile": false, + "param_type": "required", + "shape": "all" + }, + { + "index": 3, + "dtype": [ + "float16" + ], + "format": [ + "DefaultFormat" + ], + "name": "x4", + "need_compile": false, + "param_type": "optional", + "shape": "all" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float32" + ], + "format": [ + "FracZ" + ], + "name": "y", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ] +}""") +def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False, + kernel_name="matmulcube"): + """CusMatMulCubeFraczRightMul""" + return diff --git a/mindspore/ops/_op_impl/custom_op/matmul_cube_impl.py b/mindspore/ops/_op_impl/custom_op/matmul_cube_impl.py new file mode 100644 index 0000000000..7c2d81e1d6 --- /dev/null +++ b/mindspore/ops/_op_impl/custom_op/matmul_cube_impl.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" +copyright 2020 Huawei Technologies Co., Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License == distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +matmul +""" +from __future__ import absolute_import + +from mindspore.ops.op_info_register import op_info_register +from topi.cce import util + +# General limitation of the size for input shape: 2**31 +SHAPE_SIZE_LIMIT = 2147483648 +NoneType = type(None) + + +@op_info_register("""{ + "op_name": "CusMatMulCube", + "imply_type": "TBE", + "fusion_type": "OPAQUE", + "async_flag": false, + "binfile_name": "matmulcube.so", + "compute_cost": 10, + "kernel_name": "CusMatMulCube", + "partial_flag": true, + "attr": [ + { + "name": "transpose_a", + "param_type": "required", + "type": "bool", + "value": "all" + }, + { + "name": "transpose_b", + "param_type": "required", + "type": "bool", + "value": "all" + } + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "x1", + "need_compile": false, + "param_type": "required", + "shape": "all" + }, + { + "index": 1, + "dtype": [ + "float16" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "x2", + "need_compile": false, + "param_type": "required", + "shape": "all" + }, + { + "index": 2, + "dtype": [ + "float16" + ], + "format": [ + "DefaultFormat" + ], + "name": "x3", + "need_compile": false, + "param_type": "optional", + "shape": "all" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float32" + ], + "format": [ + "FRACTAL_NZ" + ], + "name": "y", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ] +}""") +# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements +@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str) +def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"): + """CusMatMulCube""" + return diff --git a/mindspore/ops/_op_impl/custom_op/matrix_combine_impl.py b/mindspore/ops/_op_impl/custom_op/matrix_combine_impl.py new file mode 100644 index 0000000000..32045e7ccb --- /dev/null +++ b/mindspore/ops/_op_impl/custom_op/matrix_combine_impl.py @@ -0,0 +1,63 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusMatrixCombine""" +from mindspore.ops.op_info_register import op_info_register + + +@op_info_register("""{ + "op_name": "CusMatrixCombine", + "imply_type": "TBE", + "fusion_type": "OPAQUE", + "async_flag": false, + "binfile_name": "matrixcombine.so", + "compute_cost": 10, + "kernel_name": "CusMatrixCombine", + "partial_flag": true, + "attr": [ + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float32" + ], + "format": [ + "DefaultFormat" + ], + "name": "x1", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float32" + ], + "format": [ + "DefaultFormat" + ], + "name": "y", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ] +}""") +def CusMatrixCombine(input_x, output, kernel_name="matrix_combine"): + """CusMatrixCombine""" + return diff --git a/mindspore/ops/_op_impl/custom_op/transpose02314_impl.py b/mindspore/ops/_op_impl/custom_op/transpose02314_impl.py new file mode 100644 index 0000000000..c5aebe523d --- /dev/null +++ b/mindspore/ops/_op_impl/custom_op/transpose02314_impl.py @@ -0,0 +1,63 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CusTranspose02314""" +from mindspore.ops.op_info_register import op_info_register + + +@op_info_register("""{ + "op_name": "CusTranspose02314", + "imply_type": "TBE", + "fusion_type": "OPAQUE", + "async_flag": false, + "binfile_name": "transpose02314.so", + "compute_cost": 10, + "kernel_name": "CusTranspose02314", + "partial_flag": true, + "attr": [ + ], + "inputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "NC1HWC0" + ], + "name": "x1", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ], + "outputs": [ + { + "index": 0, + "dtype": [ + "float16" + ], + "format": [ + "DefaultFormat" + ], + "name": "y", + "need_compile": false, + "param_type": "required", + "shape": "all" + } + ] +}""") +def CusTranspose02314(input_x, output, kernel_name="transpose021354"): + """CusTranspose02314""" + return diff --git a/mindspore/ops/operations/thor_ops.py b/mindspore/ops/operations/thor_ops.py new file mode 100644 index 0000000000..23593a2630 --- /dev/null +++ b/mindspore/ops/operations/thor_ops.py @@ -0,0 +1,248 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""thor_ops""" +import mindspore as ms +from mindspore.ops import prim_attr_register, PrimitiveWithInfer +from mindspore.ops.composite import multitype_ops as C + + +class CusBatchMatMul(PrimitiveWithInfer): + """CusMatMulCube definition""" + + @prim_attr_register + def __init__(self): + """init CusMatMulCube""" + self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) + + def get_bprop(self): + def bprop(x1, x2, out, dout): + return (C.zeros_like(x1), C.zeros_like(x2)) + + return bprop + + def infer_shape(self, data1_shape, data2_shape): + return data1_shape + + def infer_dtype(self, data1_dtype, data2_dtype): + return data1_dtype + + +class CusCholeskyTrsm(PrimitiveWithInfer): + """CusCholeskyTrsm definition""" + + @prim_attr_register + def __init__(self): + """init CusCholeskyTrsm""" + self.init_prim_io_names(inputs=['x1'], outputs=['y']) + + def infer_shape(self, data1_shape): + ll = [] + m, _ = data1_shape + if m >= 128: + ll = [m // 128, 128, 128] + else: + ll = [1, 64, 64] + return ll + + def infer_dtype(self, data1_dtype): + return data1_dtype + + +class CusFusedAbsMax1(PrimitiveWithInfer): + """CusCholeskyTrsm definition""" + + @prim_attr_register + def __init__(self, origin_shape=[-1, -1]): + """init CusCholeskyTrsm""" + self.init_prim_io_names(inputs=['x1'], outputs=['y']) + self.origin_shape = origin_shape + + def get_bprop(self): + def bprop(x, out, dout): + return (C.zeros_like(x),) + + return bprop + + def infer_shape(self, data1_shape): + ll = [] + if len(data1_shape) == 2: + ll = [1,] + else: + ll = [32, 64] + return ll + + def infer_dtype(self, data1_dtype): + return data1_dtype + + +class CusImg2Col(PrimitiveWithInfer): + """CusImg2Col definition""" + + @prim_attr_register + def __init__(self, ksizes, strides, dilates=(1, 1, 1, 1), mode="NC1HWC0"): + """init CusImg2Col""" + self.init_prim_io_names(inputs=['x1'], outputs=['y']) + self.ksizes = ksizes + self.strides = strides + self.dilates = dilates + self.mode = mode + + def get_bprop(self): + def bprop(x, out, dout): + return (C.zeros_like(x),) + + return bprop + + def infer_shape(self, data1_shape): + bs, c, h, w = data1_shape + _, stride_h, stride_w, _ = self.strides + _, k_w, k_h, _ = self.ksizes + # assert m == n + c0 = 16 + c1 = c // 16 + if c1 == 0: + c1 = 1 + shape = [bs * int(h // stride_h) * int(w // stride_w), k_w * k_h * c1 * c0] + return shape + + def infer_dtype(self, data1_dtype): + return data1_dtype + + +class CusMatMulCubeDenseLeft(PrimitiveWithInfer): + """CusMatMulCube definition""" + + @prim_attr_register + def __init__(self): + """init CusMatMulCube""" + self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) + + def get_bprop(self): + def bprop(x1, x2, out, dout): + return (C.zeros_like(x1), C.zeros_like(x2)) + + return bprop + + def infer_shape(self, data1_shape, data2_shape): + return data2_shape + + def infer_dtype(self, data1_dtype, data2_dtype): + return ms.common.dtype.tensor_type(getattr(ms, "float16")) + + +class CusMatMulCubeFraczRightMul(PrimitiveWithInfer): + """CusMatMulCubeFraczRightMul definition""" + + @prim_attr_register + def __init__(self): + """init CusMatMulCubeFraczRightMul""" + self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y']) + + def get_bprop(self): + def bprop(x1, x2, x3, out, dout): + return (C.zeros_like(x1), C.zeros_like(x2), C.zeros_like(x3)) + + return bprop + + def infer_shape(self, data1_shape, data2_shape, data3_shape): + return data1_shape + + def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype): + return ms.common.dtype.tensor_type(getattr(ms, "float32")) + + +class CusMatMulCube(PrimitiveWithInfer): + """CusMatMulCube definition""" + + @prim_attr_register + def __init__(self, transpose_a=False, transpose_b=False): + """init CusMatMulCube""" + self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) + self.transpose_a = transpose_a + self.transpose_b = transpose_b + + def get_bprop(self): + def bprop(x1, x2, out, dout): + return (C.zeros_like(x1), C.zeros_like(x2)) + + return bprop + + def infer_shape(self, data1_shape, data2_shape): + # shape = [1, data1_shape[1], data2_shape[2], 16, 16] + # return shape + if self.transpose_a: + k1, m = data1_shape + else: + m, k1 = data1_shape + if self.transpose_b: + n, k2 = data2_shape + else: + k2, n = data2_shape + assert k1 == k2 + shape = [m, n] + return shape + + def infer_dtype(self, data1_dtype, data2_dtype): + return ms.common.dtype.tensor_type(getattr(ms, "float32")) + + +class CusMatrixCombine(PrimitiveWithInfer): + """CusMatMulCube definition""" + + @prim_attr_register + def __init__(self): + """init CusMatMulCube""" + self.init_prim_io_names(inputs=['x'], outputs=['y']) + + def get_bprop(self): + def bprop(x, out, dout): + return (C.zeros_like(x),) + + return bprop + + def infer_shape(self, data_shape): + a, b, c = data_shape + shape = [a * b, a * c] + + return shape + + def infer_dtype(self, data_dtype): + return data_dtype + + +class CusTranspose02314(PrimitiveWithInfer): + """CusTranspose02314 definition""" + + @prim_attr_register + def __init__(self): + """init CusTranspose02314""" + self.init_prim_io_names(inputs=['x1'], outputs=['y']) + + def get_bprop(self): + def bprop(x, out, dout): + return (C.zeros_like(x),) + + return bprop + + def infer_shape(self, data1_shape): + assert len(data1_shape) == 4 + n, c, h, w = data1_shape + c0 = 16 + c1 = c // 16 + shape = (n * h * w, c1 * c0) + return shape + + def infer_dtype(self, data1_dtype): + return data1_dtype