From 48db7f8c4f2ba0407b6f3be6a0b742d46594d1a4 Mon Sep 17 00:00:00 2001
From: VectorSL
Date: Tue, 22 Sep 2020 10:13:19 +0800
Subject: [PATCH] gpu change bncast

Remove the GPU pattern passes ReplaceBNCastFusion and
ReplaceBNGradCastFusion, which rewrote the fp32 Cast nodes around
FusedBatchNormEx/FusedBatchNormGradEx back to fp16. The model_zoo
training scripts now pass keep_batchnorm_fp32=False on GPU, as they
already did on Ascend, so batchnorm stays in fp16 under amp_level="O2",
the casts these passes matched are no longer generated, and the
GPU-specific Model(...) branches become unnecessary.
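For reference, a minimal sketch of the resulting unified mixed-precision
setup (illustrative only, not part of this patch; net and opt stand for
the network and optimizer each script builds, and 1024 is an arbitrary
loss-scale value):

    from mindspore import nn
    from mindspore.train.model import Model
    from mindspore.train.loss_scale_manager import FixedLossScaleManager

    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False)
    # One code path for Ascend and GPU: O2 casts the network to fp16 and,
    # with keep_batchnorm_fp32=False, BatchNorm runs in fp16 as well.
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
                  amp_level="O2", keep_batchnorm_fp32=False,
                  loss_scale_manager=loss_scale)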
---
 .../optimizer/gpu/replace_bn_cast_fusion.cc   |  89 ---------------
 .../optimizer/gpu/replace_bn_cast_fusion.h    |  58 ----------
 .../gpu/replace_bn_grad_cast_fusion.cc        | 108 ------------------
 .../gpu/replace_bn_grad_cast_fusion.h         |  56 ---------
 .../ccsrc/backend/session/gpu_session.cc      |   4 -
 model_zoo/official/cv/googlenet/README.md     |   2 +-
 model_zoo/official/cv/googlenet/train.py      |   8 +-
 model_zoo/official/cv/resnet/train.py         |   2 +-
 model_zoo/official/cv/resnet_thor/train.py    |   8 +-
 .../official/cv/yolov3_darknet53/train.py     |   2 +-
 10 files changed, 7 insertions(+), 330 deletions(-)
 delete mode 100644 mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
 delete mode 100644 mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.h
 delete mode 100644 mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc
 delete mode 100644 mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h

diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
deleted file mode 100644
index d67d20adf26..00000000000
--- a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.cc
+++ /dev/null
@@ -1,89 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "backend/optimizer/gpu/replace_bn_cast_fusion.h"
-
-#include <memory>
-#include <vector>
-#include <string>
-
-#include "backend/session/anf_runtime_algorithm.h"
-#include "ir/primitive.h"
-#include "utils/utils.h"
-#include "backend/optimizer/common/helper.h"
-
-namespace mindspore {
-namespace opt {
-const BaseRef ReplaceBNCastFusion::DefinePattern() const {
-  VectorRef in_cast = VectorRef({prim::kPrimCast, x_});
-  VectorRef fbn2 = VectorRef({prim::kPrimFusedBatchNormEx, in_cast, scale_, bias_, mean_, var_});
-  VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2, index_});
-  return tupleget;
-}
-
-const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
-                                              const EquivPtr &equiv) const {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
-  auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
-  auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0);
-  auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0);
-  MS_EXCEPTION_IF_NULL(fbn2);
-  MS_EXCEPTION_IF_NULL(x_after);
-  MS_EXCEPTION_IF_NULL(x_before);
-  // only deal with x_after in fp32: x 16->32->bn->16->32
-  if (AnfAlgo::GetOutputInferDataType(x_after, 0) == kNumberTypeFloat16) {
-    return nullptr;
-  }
-  std::vector<TypeId> outputs_type;
-  std::vector<std::vector<size_t>> outputs_shape;
-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  auto outlist = GetRealNodeUsedList(graph, fbn2);
-  bool changed = false;
-  for (size_t i = 0; i < outlist->size(); i++) {
-    auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(i).first), 1);
-    auto value_node = index_node->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(value_node);
-    int item_idx = GetValue<int>(value_node->value());
-    if (item_idx == 0) {
-      auto cast = GetRealNodeUsedList(graph, outlist->at(i).first);
-      if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
-        continue;
-      }
-      manager->Replace(utils::cast<AnfNodePtr>(cast->at(0).first), utils::cast<AnfNodePtr>(outlist->at(i).first));
-      outputs_type.push_back(kNumberTypeFloat16);
-      outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0));
-      AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get());
-      changed = true;
-    }
-  }
-  if (!changed) {
-    return nullptr;
-  }
-  manager->Replace(utils::cast<AnfNodePtr>(x_after), utils::cast<AnfNodePtr>(x_before));
-  outputs_type.clear();
-  outputs_shape.clear();
-  auto output_num = AnfAlgo::GetOutputTensorNum(fbn2);
-  for (size_t i = 0; i < output_num; i++) {
-    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2, i));
-    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(fbn2, i));
-  }
-  outputs_type[0] = kNumberTypeFloat16;
-  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2.get());
-  return node;
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.h
deleted file mode 100644
index 6b1e2ad7b12..00000000000
--- a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_cast_fusion.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_CAST_FUSION_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_CAST_FUSION_H_
-
-#include <memory>
-#include "backend/optimizer/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-class ReplaceBNCastFusion : public PatternProcessPass {
- public:
-  explicit ReplaceBNCastFusion(bool multigraph = true) : PatternProcessPass("replace_bn_cast", multigraph) {
-    x_ = std::make_shared<Var>();
-    scale_ = std::make_shared<Var>();
-    bias_ = std::make_shared<Var>();
-    mean_ = std::make_shared<Var>();
-    var_ = std::make_shared<Var>();
-    y_ = std::make_shared<Var>();
-    running_mean_ = std::make_shared<Var>();
-    running_var_ = std::make_shared<Var>();
-    save_mean_ = std::make_shared<Var>();
-    save_var_ = std::make_shared<Var>();
-    index_ = std::make_shared<Var>();
-  }
-  ~ReplaceBNCastFusion() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-
- private:
-  VarPtr x_;
-  VarPtr scale_;
-  VarPtr bias_;
-  VarPtr mean_;
-  VarPtr var_;
-  VarPtr y_;
-  VarPtr running_mean_;
-  VarPtr running_var_;
-  VarPtr save_mean_;
-  VarPtr save_var_;
-  VarPtr index_;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_CAST_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc
deleted file mode 100644
index 01c915142bf..00000000000
--- a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.cc
+++ /dev/null
@@ -1,108 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
-
-#include <memory>
-#include <vector>
-#include <string>
-
-#include "backend/session/anf_runtime_algorithm.h"
-#include "ir/primitive.h"
-#include "utils/utils.h"
-#include "backend/optimizer/common/helper.h"
-
-namespace mindspore {
-namespace opt {
-const BaseRef ReplaceBNGradCastFusion::DefinePattern() const {
-  VectorRef dy_cast = VectorRef({prim::kPrimCast, dy_});
-  VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGradEx, dy_cast, x_, scale_, mean_, var_, reserve_});
-  VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2g, index_});
-  return tupleget;
-}
-
-const void HandleOutput(const FuncGraphPtr &graph, const mindspore::CNodePtr &kernel) {
-  auto outlist = GetRealNodeUsedList(graph, kernel);
-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  for (size_t j = 0; j < outlist->size(); j++) {
-    std::vector<TypeId> outputs_type;
-    std::vector<std::vector<size_t>> outputs_shape;
-    auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(j).first), 1);
-    auto value_node = index_node->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(value_node);
-    int item_idx = GetValue<int>(value_node->value());
-    if (item_idx == 0) {
-      auto cast = GetRealNodeUsedList(graph, outlist->at(j).first);
-      if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
-        continue;
-      }
-      manager->Replace(utils::cast<AnfNodePtr>(cast->at(0).first), utils::cast<AnfNodePtr>(outlist->at(j).first));
-      outputs_type.push_back(kNumberTypeFloat16);
-      outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(j).first, 0));
-      AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(j).first.get());
-    }
-  }
-}
-
-const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
-                                                  const EquivPtr &equiv) const {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(equiv);
-
-  auto fbn2g = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
-  auto dy_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 0);
-  auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0);
-  auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1);
-  MS_EXCEPTION_IF_NULL(x_);
-  // if x is fp32 the cast is necessary; only deal with dy_after in fp32: dy 16->32->bng->16->32.
-  if (AnfAlgo::GetOutputInferDataType(x_, 0) == kNumberTypeFloat32 ||
-      AnfAlgo::GetOutputInferDataType(dy_after, 0) == kNumberTypeFloat16) {
-    return nullptr;
-  }
-  MS_EXCEPTION_IF_NULL(fbn2g);
-  MS_EXCEPTION_IF_NULL(dy_after);
-  MS_EXCEPTION_IF_NULL(dy_before);
-  std::vector<TypeId> outputs_type;
-  std::vector<std::vector<size_t>> outputs_shape;
-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-
-  // 1. get all of the fusedbatchnormgrad nodes connected after dy_after.
-  auto fbn2g_all = GetRealNodeUsedList(graph, dy_after);
-  for (size_t i = 0; i < fbn2g_all->size(); i++) {
-    outputs_type.clear();
-    outputs_shape.clear();
-    auto kernel = utils::cast<CNodePtr>(fbn2g_all->at(i).first);
-    auto kernel_name = AnfAlgo::GetCNodeName(kernel);
-    // 2. deal with all of the fusedbatchnormgrad nodes, change the data type.
-    if (kernel_name == AnfAlgo::GetCNodeName(utils::cast<CNodePtr>(fbn2g))) {
-      auto output_num = AnfAlgo::GetOutputTensorNum(kernel);
-      for (size_t j = 0; j < output_num; j++) {
-        outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel, j));
-        outputs_shape.push_back(AnfAlgo::GetOutputInferShape(kernel, j));
-      }
-      outputs_type[0] = kNumberTypeFloat16;
-      AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, kernel.get());
-    }
-    // 3. handle the output of fusedbatchnormgrad: tuplegetitem
-    HandleOutput(graph, kernel);
-  }
-  manager->Replace(utils::cast<AnfNodePtr>(dy_after), utils::cast<AnfNodePtr>(dy_before));
-  return node;
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h
deleted file mode 100644
index 968ed528486..00000000000
--- a/mindspore/ccsrc/backend/optimizer/gpu/replace_bn_grad_cast_fusion.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_GRAD_CAST_FUSION_H_
-#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_GRAD_CAST_FUSION_H_
-
-#include <memory>
-#include "backend/optimizer/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-class ReplaceBNGradCastFusion : public PatternProcessPass {
- public:
-  explicit ReplaceBNGradCastFusion(bool multigraph = true) : PatternProcessPass("replace_bn_grad_cast", multigraph) {
-    dy_ = std::make_shared<Var>();
-    x_ = std::make_shared<Var>();
-    scale_ = std::make_shared<Var>();
-    mean_ = std::make_shared<Var>();
-    var_ = std::make_shared<Var>();
-    dx_ = std::make_shared<Var>();
-    bn_scale_ = std::make_shared<Var>();
-    bn_bias_ = std::make_shared<Var>();
-    index_ = std::make_shared<Var>();
-    reserve_ = std::make_shared<Var>();
-  }
-  ~ReplaceBNGradCastFusion() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-
- private:
-  VarPtr dy_;
-  VarPtr x_;
-  VarPtr scale_;
-  VarPtr mean_;
-  VarPtr var_;
-  VarPtr dx_;
-  VarPtr bn_scale_;
-  VarPtr bn_bias_;
-  VarPtr index_;
-  VarPtr reserve_;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_BN_GRAD_CAST_FUSION_H_
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index 294a97e4f97..a1d3a00724d 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -28,8 +28,6 @@
 #include "backend/optimizer/gpu/adam_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
-#include "backend/optimizer/gpu/replace_bn_cast_fusion.h"
-#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
 #include "backend/optimizer/gpu/batch_norm_relu_fusion.h"
 #include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h"
 #include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h"
@@ -82,8 +80,6 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   auto pm = std::make_shared<opt::PassManager>();
   pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
   pm->AddPass(std::make_shared<opt::AdamFusion>());
-  pm->AddPass(std::make_shared<opt::ReplaceBNCastFusion>());
-  pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
   pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
   pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
   optimizer->AddPassManager(pm);
diff --git a/model_zoo/official/cv/googlenet/README.md b/model_zoo/official/cv/googlenet/README.md
index 422ebd676af..36782ed6cef 100644
--- a/model_zoo/official/cv/googlenet/README.md
+++ b/model_zoo/official/cv/googlenet/README.md
@@ -447,7 +447,7 @@ If you need to use the trained model to perform inference on multiple hardware platforms
                    Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
     model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
-                  amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=None)
+                  amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
 
     # Set callbacks
     config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
diff --git a/model_zoo/official/cv/googlenet/train.py b/model_zoo/official/cv/googlenet/train.py
index cee5a8693d7..ed20f995964 100644
--- a/model_zoo/official/cv/googlenet/train.py
+++ b/model_zoo/official/cv/googlenet/train.py
@@ -197,12 +197,8 @@ if __name__ == '__main__':
     else:
         loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
 
-    if device_target == "Ascend":
-        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
-                      amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager)
-    else:  # GPU
-        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
-                      amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=loss_scale_manager)
+    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
+                  amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager)
 
     config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
     time_cb = TimeMonitor(data_size=batch_num)
diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py
index 3c29241214a..6e6d221c72a 100755
--- a/model_zoo/official/cv/resnet/train.py
+++ b/model_zoo/official/cv/resnet/train.py
@@ -168,7 +168,7 @@ if __name__ == '__main__':
         loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
         # Mixed precision
         model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
-                      amp_level="O2", keep_batchnorm_fp32=True)
+                      amp_level="O2", keep_batchnorm_fp32=False)
     else:
         ## fp32 training
         opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)
diff --git a/model_zoo/official/cv/resnet_thor/train.py b/model_zoo/official/cv/resnet_thor/train.py
index 6e31bf87a9f..ec292e7eed8 100644
--- a/model_zoo/official/cv/resnet_thor/train.py
+++ b/model_zoo/official/cv/resnet_thor/train.py
@@ -124,12 +124,8 @@ if __name__ == '__main__':
                                   filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
                                   config.weight_decay, config.loss_scale)
     loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
-    if target == "Ascend":
-        model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale,
-                      keep_batchnorm_fp32=False, metrics={'acc'}, frequency=config.frequency)
-    else:
-        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
-                      amp_level="O2", keep_batchnorm_fp32=True, frequency=config.frequency)
+    model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale,
+                  keep_batchnorm_fp32=False, metrics={'acc'}, frequency=config.frequency)
 
     # define callbacks
     time_cb = TimeMonitor(data_size=step_size)
diff --git a/model_zoo/official/cv/yolov3_darknet53/train.py b/model_zoo/official/cv/yolov3_darknet53/train.py
index bbeb0c14657..309efb5e411 100644
--- a/model_zoo/official/cv/yolov3_darknet53/train.py
+++ b/model_zoo/official/cv/yolov3_darknet53/train.py
@@ -215,7 +215,7 @@ def train():
             loss_scale_value = 1.0
             loss_scale = FixedLossScaleManager(loss_scale_value, drop_overflow_update=False)
             network = amp.build_train_network(network, optimizer=opt, loss_scale_manager=loss_scale,
-                                              level="O2", keep_batchnorm_fp32=True)
+                                              level="O2", keep_batchnorm_fp32=False)
             keep_loss_fp32(network)
         else:
             network = TrainingWrapper(network, opt)