fix bug: reduce_sum axis to input[1] in clip_by_norm_fission

This commit is contained in:
yiyanzhi_akane 2022-11-24 20:25:23 +08:00
parent 9203542060
commit 7f59f54c42
1 changed files with 37 additions and 6 deletions

View File

@ -17,6 +17,7 @@
#include "backend/common/pass/clip_by_norm_fission.h"
#include <algorithm>
#include "ir/anf.h"
#include "utils/ms_context.h"
#include "include/common/utils/anfalgo.h"
#include "backend/common/optimizer/helper.h"
@ -61,6 +62,23 @@ std::vector<int64_t> InferBroadcastShape(const std::vector<int64_t> &x_shape, co
}
return broadcast_shape;
}
AnfNodePtr ConvertValueToTensor(const KernelGraphPtr &kernel_graph, const ValueNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(input_node);
auto value_node = input_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
tensor::TensorPtr tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
MS_EXCEPTION_IF_NULL(tensor_ptr);
auto tensor_input = std::make_shared<ValueNode>(tensor_ptr);
MS_EXCEPTION_IF_NULL(tensor_input);
tensor_input->set_abstract(tensor_ptr->ToAbstract());
tensor_input = kernel_graph->NewValueNode(tensor_input);
kernel_graph->AddValueNodeToGraph(tensor_input);
tensor_input->set_scope(input_node->scope());
return tensor_input;
}
} // namespace
AnfNodePtr ClipByNormFission::CreateCNodeBase(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &inps,
@ -91,15 +109,17 @@ AnfNodePtr ClipByNormFission::CreateSquareNode(const FuncGraphPtr &func_graph, c
AnfNodePtr ClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func_graph, const AnfNodePtr &square,
const AnfNodePtr &clip_by_norm, const ShapeVector &shape_vec,
const TypeId &type_id) const {
auto reduce_sum = CreateCNodeBase(func_graph, {square}, kReduceSumOpName, square);
MS_EXCEPTION_IF_NULL(reduce_sum);
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
bool use_asecend_backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice;
// Sync the attribute of `ClipByNorm` to `ReduceSum`
auto clip_by_norm_prim = common::AnfAlgo::GetCNodePrimitive(clip_by_norm);
MS_EXCEPTION_IF_NULL(clip_by_norm_prim);
auto axis_value = clip_by_norm_prim->GetAttr(kAttrAxis);
MS_EXCEPTION_IF_NULL(axis_value);
common::AnfAlgo::SetNodeAttr(kAttrAxis, axis_value, reduce_sum);
common::AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), reduce_sum);
// Get `axis` vector
const auto dim = shape_vec.size();
std::vector<int64_t> axis;
@ -116,7 +136,6 @@ AnfNodePtr ClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func_graph
MS_EXCEPTION(TypeError) << "For `" << prim::kPrimClipByNorm->name()
<< "`, the type of attribute `axis` is invalid.";
}
// Set abstract to `reduce_sum` op
int64_t ddim = SizeToLong(dim);
ShapeVector reduce_sum_output_shape = shape_vec;
for (const auto &idx : axis) {
@ -127,7 +146,19 @@ AnfNodePtr ClipByNormFission::CreateReduceSumNode(const FuncGraphPtr &func_graph
auto positive_idx = idx < 0 ? idx + ddim : idx;
reduce_sum_output_shape[LongToUlong(positive_idx)] = 1;
}
AnfNodePtr reduce_sum = nullptr;
if (use_asecend_backend) {
reduce_sum = CreateCNodeBase(func_graph, {square}, kReduceSumOpName, square);
MS_EXCEPTION_IF_NULL(reduce_sum);
common::AnfAlgo::SetNodeAttr(kAttrAxis, axis_value, reduce_sum);
} else {
// cpu, gpu backend
auto axis_node = NewValueNode(MakeValue<std::vector<int64_t>>(axis));
auto axis_tensor = ConvertValueToTensor(kernel_graph, axis_node);
reduce_sum = CreateCNodeBase(func_graph, {square, axis_tensor}, kReduceSumOpName, square);
MS_EXCEPTION_IF_NULL(reduce_sum);
}
common::AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), reduce_sum);
auto abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), reduce_sum_output_shape);
reduce_sum->set_abstract(abs);
return reduce_sum;