!46061 fix bug: reduce_sum axis to input[1] in clip_by_norm_fission
Merge pull request !46061 from Yanzhi_YI/master
This commit is contained in:
commit
a685b6bca2
|
@ -33,7 +33,6 @@
|
|||
#include "backend/common/pass/add_akg_kernel_attrs.h"
|
||||
#include "backend/common/pass/inplace_assign_for_custom_op.h"
|
||||
#include "backend/common/pass/flatten_concat_fission.h"
|
||||
#include "backend/common/pass/clip_by_norm_fission.h"
|
||||
#include "backend/common/pass/add_dropout_attrs.h"
|
||||
#include "backend/common/optimizer/dynamic_shape/convert_custom_op.h"
|
||||
#include "backend/common/optimizer/dynamic_shape/link_custom_op.h"
|
||||
|
|
|
@ -105,7 +105,6 @@
|
|||
#include "backend/common/pass/getitem_tuple.h"
|
||||
#include "backend/common/pass/optimize_dependence.h"
|
||||
#include "backend/common/pass/erase_visit_attr.h"
|
||||
#include "backend/common/pass/clip_by_norm_fission.h"
|
||||
#include "plugin/device/ascend/optimizer/format_type/insert_cast.h"
|
||||
#include "plugin/device/ascend/optimizer/format_type/convert_unsupported_transnode_to_aicpu.h"
|
||||
#include "backend/common/pass/eliminate_redundant_op.h"
|
||||
|
|
|
@ -38,7 +38,7 @@
|
|||
#include "common/graph_kernel/graph_kernel_flags.h"
|
||||
#include "plugin/device/gpu/hal/profiler/gpu_profiling.h"
|
||||
#include "plugin/device/gpu/hal/profiler/gpu_profiling_utils.h"
|
||||
#include "backend/common/pass/clip_by_norm_fission.h"
|
||||
#include "plugin/device/gpu/optimizer/clip_by_norm_fission.h"
|
||||
#include "backend/common/session/kernel_graph.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
|
|
|
@ -49,6 +49,7 @@
|
|||
#include "plugin/device/gpu/optimizer/add_relu_grad_v2_fusion.h"
|
||||
#include "plugin/device/gpu/optimizer/matmul_biasadd_fusion.h"
|
||||
#include "plugin/device/gpu/optimizer/neighbor_exchange_v2_fusion.h"
|
||||
#include "plugin/device/gpu/optimizer/clip_by_norm_fission.h"
|
||||
#ifdef ENABLE_GPU_INFER
|
||||
#include "plugin/device/gpu/optimizer/trt_pass/graph_converter.h"
|
||||
#endif
|
||||
|
|
|
@ -56,6 +56,7 @@
|
|||
#include "plugin/device/gpu/optimizer/insert_cast_gpu.h"
|
||||
#include "plugin/device/gpu/optimizer/neighbor_exchange_v2_fusion.h"
|
||||
#include "plugin/device/gpu/optimizer/bias_dropout_add_fusion.h"
|
||||
#include "plugin/device/gpu/optimizer/clip_by_norm_fission.h"
|
||||
#include "backend/common/pass/insert_type_transform_op.h"
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_OPTIMIZER_H_
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/common/pass/clip_by_norm_fission.h"
|
||||
#include "plugin/device/gpu/optimizer/clip_by_norm_fission.h"
|
||||
#include <algorithm>
|
||||
#include "ir/anf.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
@ -61,6 +61,23 @@ std::vector<int64_t> InferBroadcastShape(const std::vector<int64_t> &x_shape, co
|
|||
}
|
||||
return broadcast_shape;
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertValueToTensor(const KernelGraphPtr &kernel_graph, const ValueNodePtr &input_node) {
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
auto value_node = input_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
tensor::TensorPtr tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
|
||||
MS_EXCEPTION_IF_NULL(tensor_ptr);
|
||||
auto tensor_input = std::make_shared<ValueNode>(tensor_ptr);
|
||||
MS_EXCEPTION_IF_NULL(tensor_input);
|
||||
tensor_input->set_abstract(tensor_ptr->ToAbstract());
|
||||
tensor_input = kernel_graph->NewValueNode(tensor_input);
|
||||
kernel_graph->AddValueNodeToGraph(tensor_input);
|
||||
tensor_input->set_scope(input_node->scope());
|
||||
return tensor_input;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr ClipByNormFission::CreateCNodeBase(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &inps,
|
||||
|
@ -91,15 +108,14 @@ AnfNodePtr ClipByNormFission::CreateSquareNode(const FuncGraphPtr &func_graph, c
|
|||
AnfNodePtr ClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func_graph, const AnfNodePtr &square,
|
||||
const AnfNodePtr &clip_by_norm, const ShapeVector &shape_vec,
|
||||
const TypeId &type_id) const {
|
||||
auto reduce_sum = CreateCNodeBase(func_graph, {square}, kReduceSumOpName, square);
|
||||
MS_EXCEPTION_IF_NULL(reduce_sum);
|
||||
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
// Sync the attribute of `ClipByNorm` to `ReduceSum`
|
||||
auto clip_by_norm_prim = common::AnfAlgo::GetCNodePrimitive(clip_by_norm);
|
||||
MS_EXCEPTION_IF_NULL(clip_by_norm_prim);
|
||||
auto axis_value = clip_by_norm_prim->GetAttr(kAttrAxis);
|
||||
MS_EXCEPTION_IF_NULL(axis_value);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrAxis, axis_value, reduce_sum);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), reduce_sum);
|
||||
|
||||
// Get `axis` vector
|
||||
const auto dim = shape_vec.size();
|
||||
std::vector<int64_t> axis;
|
||||
|
@ -116,7 +132,6 @@ AnfNodePtr ClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func_graph
|
|||
MS_EXCEPTION(TypeError) << "For `" << prim::kPrimClipByNorm->name()
|
||||
<< "`, the type of attribute `axis` is invalid.";
|
||||
}
|
||||
// Set abstract to `reduce_sum` op
|
||||
int64_t ddim = SizeToLong(dim);
|
||||
ShapeVector reduce_sum_output_shape = shape_vec;
|
||||
for (const auto &idx : axis) {
|
||||
|
@ -127,7 +142,14 @@ AnfNodePtr ClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func_graph
|
|||
auto positive_idx = idx < 0 ? idx + ddim : idx;
|
||||
reduce_sum_output_shape[LongToUlong(positive_idx)] = 1;
|
||||
}
|
||||
AnfNodePtr reduce_sum = nullptr;
|
||||
|
||||
auto axis_node = NewValueNode(MakeValue<std::vector<int64_t>>(axis));
|
||||
auto axis_tensor = ConvertValueToTensor(kernel_graph, axis_node);
|
||||
reduce_sum = CreateCNodeBase(func_graph, {square, axis_tensor}, kReduceSumOpName, square);
|
||||
MS_EXCEPTION_IF_NULL(reduce_sum);
|
||||
|
||||
common::AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), reduce_sum);
|
||||
auto abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), reduce_sum_output_shape);
|
||||
reduce_sum->set_abstract(abs);
|
||||
return reduce_sum;
|
Loading…
Reference in New Issue