diff --git a/mindspore/ccsrc/backend/common/optimizer/common_backend_optimization.cc b/mindspore/ccsrc/backend/common/optimizer/common_backend_optimization.cc index 5d6b3166514..26f74a52972 100644 --- a/mindspore/ccsrc/backend/common/optimizer/common_backend_optimization.cc +++ b/mindspore/ccsrc/backend/common/optimizer/common_backend_optimization.cc @@ -33,7 +33,6 @@ #include "backend/common/pass/add_akg_kernel_attrs.h" #include "backend/common/pass/inplace_assign_for_custom_op.h" #include "backend/common/pass/flatten_concat_fission.h" -#include "backend/common/pass/clip_by_norm_fission.h" #include "backend/common/pass/add_dropout_attrs.h" #include "backend/common/optimizer/dynamic_shape/convert_custom_op.h" #include "backend/common/optimizer/dynamic_shape/link_custom_op.h" diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_backend_optimization.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_backend_optimization.cc index e5668011063..f56c794f437 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_backend_optimization.cc @@ -105,7 +105,6 @@ #include "backend/common/pass/getitem_tuple.h" #include "backend/common/pass/optimize_dependence.h" #include "backend/common/pass/erase_visit_attr.h" -#include "backend/common/pass/clip_by_norm_fission.h" #include "plugin/device/ascend/optimizer/format_type/insert_cast.h" #include "plugin/device/ascend/optimizer/format_type/convert_unsupported_transnode_to_aicpu.h" #include "backend/common/pass/eliminate_redundant_op.h" diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc index dea872c3ef0..9c0a2b7bfd4 100644 --- a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc +++ b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc @@ -38,7 +38,7 @@ #include "common/graph_kernel/graph_kernel_flags.h" #include "plugin/device/gpu/hal/profiler/gpu_profiling.h" #include "plugin/device/gpu/hal/profiler/gpu_profiling_utils.h" -#include "backend/common/pass/clip_by_norm_fission.h" +#include "plugin/device/gpu/optimizer/clip_by_norm_fission.h" #include "backend/common/session/kernel_graph.h" #include "plugin/device/gpu/kernel/gpu_kernel.h" #include "plugin/device/gpu/kernel/gpu_kernel_factory.h" diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_session.cc b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_session.cc index 12da323299e..2305432e130 100644 --- a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_session.cc +++ b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_session.cc @@ -49,6 +49,7 @@ #include "plugin/device/gpu/optimizer/add_relu_grad_v2_fusion.h" #include "plugin/device/gpu/optimizer/matmul_biasadd_fusion.h" #include "plugin/device/gpu/optimizer/neighbor_exchange_v2_fusion.h" +#include "plugin/device/gpu/optimizer/clip_by_norm_fission.h" #ifdef ENABLE_GPU_INFER #include "plugin/device/gpu/optimizer/trt_pass/graph_converter.h" #endif diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/optimizer.h b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/optimizer.h index 70765ce8a7d..7ea21d59aac 100644 --- a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/optimizer.h +++ b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/optimizer.h @@ -56,6 +56,7 @@ #include "plugin/device/gpu/optimizer/insert_cast_gpu.h" #include "plugin/device/gpu/optimizer/neighbor_exchange_v2_fusion.h" #include "plugin/device/gpu/optimizer/bias_dropout_add_fusion.h" +#include "plugin/device/gpu/optimizer/clip_by_norm_fission.h" #include "backend/common/pass/insert_type_transform_op.h" #endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_OPTIMIZER_H_ diff --git a/mindspore/ccsrc/backend/common/pass/clip_by_norm_fission.cc b/mindspore/ccsrc/plugin/device/gpu/optimizer/clip_by_norm_fission.cc similarity index 92% rename from mindspore/ccsrc/backend/common/pass/clip_by_norm_fission.cc rename to mindspore/ccsrc/plugin/device/gpu/optimizer/clip_by_norm_fission.cc index b98f0a4dc71..6ff7f859dfb 100644 --- a/mindspore/ccsrc/backend/common/pass/clip_by_norm_fission.cc +++ b/mindspore/ccsrc/plugin/device/gpu/optimizer/clip_by_norm_fission.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "backend/common/pass/clip_by_norm_fission.h" +#include "plugin/device/gpu/optimizer/clip_by_norm_fission.h" #include #include "ir/anf.h" #include "include/common/utils/anfalgo.h" @@ -61,6 +61,23 @@ std::vector InferBroadcastShape(const std::vector &x_shape, co } return broadcast_shape; } + +AnfNodePtr ConvertValueToTensor(const KernelGraphPtr &kernel_graph, const ValueNodePtr &input_node) { + MS_EXCEPTION_IF_NULL(input_node); + auto value_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + tensor::TensorPtr tensor_ptr = CreateTupleTensor(value->cast()); + MS_EXCEPTION_IF_NULL(tensor_ptr); + auto tensor_input = std::make_shared(tensor_ptr); + MS_EXCEPTION_IF_NULL(tensor_input); + tensor_input->set_abstract(tensor_ptr->ToAbstract()); + tensor_input = kernel_graph->NewValueNode(tensor_input); + kernel_graph->AddValueNodeToGraph(tensor_input); + tensor_input->set_scope(input_node->scope()); + return tensor_input; +} } // namespace AnfNodePtr ClipByNormFission::CreateCNodeBase(const FuncGraphPtr &func_graph, const std::vector &inps, @@ -91,15 +108,14 @@ AnfNodePtr ClipByNormFission::CreateSquareNode(const FuncGraphPtr &func_graph, c AnfNodePtr ClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func_graph, const AnfNodePtr &square, const AnfNodePtr &clip_by_norm, const ShapeVector &shape_vec, const TypeId &type_id) const { - auto reduce_sum = CreateCNodeBase(func_graph, {square}, kReduceSumOpName, square); - MS_EXCEPTION_IF_NULL(reduce_sum); + auto kernel_graph = func_graph->cast>(); + MS_EXCEPTION_IF_NULL(kernel_graph); // Sync the attribute of `ClipByNorm` to `ReduceSum` auto clip_by_norm_prim = common::AnfAlgo::GetCNodePrimitive(clip_by_norm); MS_EXCEPTION_IF_NULL(clip_by_norm_prim); auto axis_value = clip_by_norm_prim->GetAttr(kAttrAxis); MS_EXCEPTION_IF_NULL(axis_value); - common::AnfAlgo::SetNodeAttr(kAttrAxis, axis_value, reduce_sum); - common::AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), reduce_sum); + // Get `axis` vector const auto dim = shape_vec.size(); std::vector axis; @@ -116,7 +132,6 @@ AnfNodePtr ClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func_graph MS_EXCEPTION(TypeError) << "For `" << prim::kPrimClipByNorm->name() << "`, the type of attribute `axis` is invalid."; } - // Set abstract to `reduce_sum` op int64_t ddim = SizeToLong(dim); ShapeVector reduce_sum_output_shape = shape_vec; for (const auto &idx : axis) { @@ -127,7 +142,14 @@ AnfNodePtr ClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func_graph auto positive_idx = idx < 0 ? idx + ddim : idx; reduce_sum_output_shape[LongToUlong(positive_idx)] = 1; } + AnfNodePtr reduce_sum = nullptr; + auto axis_node = NewValueNode(MakeValue>(axis)); + auto axis_tensor = ConvertValueToTensor(kernel_graph, axis_node); + reduce_sum = CreateCNodeBase(func_graph, {square, axis_tensor}, kReduceSumOpName, square); + MS_EXCEPTION_IF_NULL(reduce_sum); + + common::AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), reduce_sum); auto abs = std::make_shared(TypeIdToType(type_id), reduce_sum_output_shape); reduce_sum->set_abstract(abs); return reduce_sum; diff --git a/mindspore/ccsrc/backend/common/pass/clip_by_norm_fission.h b/mindspore/ccsrc/plugin/device/gpu/optimizer/clip_by_norm_fission.h similarity index 100% rename from mindspore/ccsrc/backend/common/pass/clip_by_norm_fission.h rename to mindspore/ccsrc/plugin/device/gpu/optimizer/clip_by_norm_fission.h