Fix graph_kernel BroadcastTo bug

huoxinyou 2024-02-29 10:53:58 +08:00
parent 79b88f6dad
commit c298f443ea
9 changed files with 33 additions and 12 deletions
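In short: BroadcastTo's target shape now travels as a tensor input rather than a node attribute throughout graph_kernel. BroadcastTo leaves the GetConvertInputAttrOps set, its input 1 is registered as value-depend, InplaceAssignBuilder and the lite-graph GraphBuilder construct an int64 shape tensor input, ArithmeticSimplify skips composites tagged inplace_assign_builder, and the Python infer paths accept the shape from either the legacy attr or the second input. The akg submodule is bumped to match.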

akg (submodule, 2 lines changed)

@@ -1 +1 @@
-Subproject commit a235ba7365d87e4c884f93c8da934d7c313b61d4
+Subproject commit 749589da461f1c339e3ce94a07b638c99b632c45

View File

@@ -44,7 +44,7 @@ const std::set<std::string> &GetConvertInputAttrOps() {
     prim::kPrimCumSum->name(), prim::kPrimArgmin->name(), prim::kPrimArgmax->name(),
     prim::kPrimBiasAdd->name(), prim::kPrimBiasAddGrad->name(), prim::kPrimLayerNorm->name(),
     prim::kPrimLayerNormGrad->name(), prim::kPrimLogSoftmax->name(), prim::kPrimLogSoftmaxGrad->name(),
-    prim::kPrimBroadcastTo->name(), prim::kPrimAdamWeightDecay->name(), prim::kPrimStridedSlice->name(),
+    prim::kPrimStridedSlice->name(), prim::kPrimAdamWeightDecay->name(),
   };
   return convert_input_attr_ops;
 }
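With the shape now a real operand, BroadcastTo no longer belongs in this conversion set (StridedSlice and AdamWeightDecay stay, merely reordered). A minimal Python sketch of what membership means here, with hypothetical names:

    convert_input_attr_ops = {"CumSum", "Argmin", "Argmax", "BiasAdd", "BiasAddGrad",
                              "LayerNorm", "LayerNormGrad", "LogSoftmax", "LogSoftmaxGrad",
                              "StridedSlice", "AdamWeightDecay"}  # "BroadcastTo" removed

    def fold_const_inputs_to_attrs(op_name):
        # Ops in the set get constant inputs folded into attributes before
        # kernel building; BroadcastTo now keeps its shape as input 1.
        return op_name in convert_input_attr_ops

    assert not fold_const_inputs_to_attrs("BroadcastTo")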

View File

@@ -905,6 +905,11 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
   for (auto node : func_graph->GetOrderedCnodes()) {
     if (AnfUtils::IsGraphKernel(node)) {
       auto sub_graph = GetCNodeFuncGraph(node);
+      if (auto type = sub_graph->get_attr("composite_type")) {
+        if (GetValue<std::string>(type) == "inplace_assign_builder") {
+          continue;
+        }
+      }
       auto cnode = node->cast<CNodePtr>();
       AnfNodePtrList inputs = cnode->inputs();
       inner::LiteGraphPtr lg = GkUtils::AnfGraph2LiteGraph(sub_graph);
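The added guard makes ArithmeticSimplify skip graph-kernel subgraphs whose composite_type attribute is "inplace_assign_builder", i.e. the clean composites created further below. A rough, self-contained Python rendering of the control flow (data layout and helper names hypothetical):

    def run_arithmetic_simplify(graph_kernel_nodes, simplify):
        # graph_kernel_nodes: list of {"sub_graph_attrs": dict}; simplify: callable
        for node in graph_kernel_nodes:
            attrs = node["sub_graph_attrs"]
            # New guard: leave InplaceAssignBuilder composites untouched.
            if attrs.get("composite_type") == "inplace_assign_builder":
                continue
            simplify(node)

    run_arithmetic_simplify(
        [{"sub_graph_attrs": {"composite_type": "inplace_assign_builder"}},
         {"sub_graph_attrs": {}}],
        simplify=lambda n: print("simplified", n))  # fires only for the second node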

View File

@@ -39,6 +39,7 @@ const std::unordered_map<std::string, HashSet<size_t>> &ValueDependOpUtils::GetO
     {prim::kPrimReduceSum->name(), {1}},
     {prim::kPrimTranspose->name(), {1}},
     {prim::kPrimTile->name(), {1}},
+    {prim::kPrimBroadcastTo->name(), {1}},
     {prim::kPrimReduceMean->name(), {1}},
     {prim::kPrimSlice->name(), {1, 2}},
     {prim::kPrimStridedSlice->name(), {1, 2, 3}},
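The new entry declares input index 1 of BroadcastTo (the shape tensor) as value-depend: its constant value, not just its shape, must be known at kernel-build time. A small sketch of how such a table is consulted (simplified names):

    OP_INDEX_INFO = {
        "ReduceSum": {1}, "Transpose": {1}, "Tile": {1},
        "BroadcastTo": {1},  # new: the shape operand
        "ReduceMean": {1}, "Slice": {1, 2}, "StridedSlice": {1, 2, 3},
    }

    def value_depend_indices(op_name):
        return OP_INDEX_INFO.get(op_name, set())

    print(value_depend_indices("BroadcastTo"))  # {1}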

View File

@@ -80,10 +80,6 @@ void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfN
 template <typename T>
 ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {
   // Create tensor value.
-  if (info.shape.size() != 1 && info.shape[0] != 1) {
-    MS_LOG(EXCEPTION) << "Only support create scalar tensor value node!!!";
-  }
   if (info.type == nullptr) {
     MS_LOG(EXCEPTION) << "Data type can not be nullptr when creating scalar tensor!";
   }
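Removing the rank check widens CreateScalarTensorValueNode from scalars to the 1-D tensors needed for shape vectors. A numpy stand-in for what the helper now builds for a shape, under that assumption:

    import numpy as np

    def make_shape_tensor(device_shape):
        # 1-D tensor: shape {n}, dtype int64, data_length = n * sizeof(int64),
        # matching the caller in InplaceAssignBuilder below
        return np.array(device_shape, dtype=np.int64)

    t = make_shape_tensor([2, 3, 4])
    print(t.shape, t.dtype)  # (3,) int64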

View File

@@ -154,10 +154,14 @@ CNodePtr InplaceAssignBuilder::CreateCleanCompositeNode(const InplaceAssignerInf
   // Create broadcast basic op.
   auto dst_shape_vec = GetShape(op_info.op_node);
-  AnfNodePtrList clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node};
+  auto device_shape = GetDeviceShape(op_info.op_node);
+  auto shape_node = CreateScalarTensorValueNode<ShapeVector>(
+    {kOpFormat_DEFAULT, {SizeToLong(device_shape.size())}, TypeIdToType(kNumberTypeInt64)}, device_shape,
+    device_shape.size() * sizeof(int64_t));
+  AnfNodePtrList clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node, shape_node};
   auto broadcast_to_node_inner =
     CreateCNode(clean_inputs, new_sub_graph, {format, dst_shape_vec, GetType(op_info.op_node)});
-  SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(op_info.op_node)), broadcast_to_node_inner);
   // Makeup sub-graph.
   new_sub_graph->set_output(broadcast_to_node_inner);
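The clean composite's BroadcastTo now takes the device shape as a third input (shape_node) instead of having a "shape" attribute stamped on afterwards. Conceptually, with numpy standing in for the broadcast kernel:

    import numpy as np

    def clean_composite(broadcast_input, device_shape):
        shape = np.array(device_shape, dtype=np.int64)  # the new shape_node input
        return np.broadcast_to(broadcast_input, tuple(shape.tolist()))

    print(clean_composite(np.float32(0), [2, 3]))  # 2x3 block of zeros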

View File

@@ -29,8 +29,8 @@ NodePtr GraphBuilder::Reshape(const NodePtr &input, const ShapeVector &shape) co
 }
 NodePtr GraphBuilder::BroadcastTo(const NodePtr &input, const ShapeVector &shape) const {
-  auto shape_value = MakeValue(shape);
-  return Emit("BroadcastTo", {input}, {{"shape", shape_value}});
+  auto shape_value = Tensor(shape);
+  return Emit("BroadcastTo", {input, shape_value});
 }
 NodePtr GraphBuilder::Gather(const NodePtr &param, const NodePtr &indice, int64_t axis, int64_t batch_dims) const {
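The lite-graph builder makes the same switch: the target shape is passed to Emit as a second operand (a tensor value) rather than attached as a "shape" attribute.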

View File

@@ -501,7 +501,8 @@ class BroadcastTo(OpInfer):
     """BroadcastTo op."""
 
     def supported_format(self):
-        return ["ND,ND"]
+        io_format = ["ND"] * len(self.input_desc)
+        return [",".join(io_format)]
 
     def infer_shape(self):
         """Broadcast op keeps ND format, so the output shape will not be changed"""
@@ -509,7 +510,7 @@
 
     def infer_ori_shape(self):
         shape = self.input_desc[0][ORI_SHAPE]
-        broad_shape = self.get_attr(SHAPE)
+        broad_shape = self.get_attr(SHAPE) if SHAPE in self.attr else self.input_desc[1][VALUE]
         if len(broad_shape) < len(shape):
             raise ValueError("The length of attr 'shape' must be >= the length of input shape, but got attr 'shape': "
                              "{}, input shape: {}".format(broad_shape, shape))
@@ -523,6 +524,8 @@
         self.output_desc[0][ORI_SHAPE] = out_shape
 
     def post_process(self):
+        if not isinstance(self.op_desc.get(ATTR), list):
+            return
         for item in self.op_desc[ATTR]:
             if item[NAME] == SHAPE:
                 item["ori_value"] = item[VALUE]
@@ -761,6 +764,9 @@ def update_akg_info(args, info_path, kernel_name=None):
     # Update data format to DefaultFormat
     convert_to_default_format(desc)
+    # GE backend must use old CCE
+    desc["backend"] = "GE"
+
     return desc
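infer_ori_shape now falls back to the value of input 1 when the kernel info no longer carries a 'shape' attr, and post_process tolerates a missing attr list for the same reason. A runnable sketch of the fallback, with plain dict keys standing in for the SHAPE/VALUE constants:

    def resolve_broadcast_shape(attrs, input_desc):
        return attrs["shape"] if "shape" in attrs else input_desc[1]["value"]

    inputs = [{"ori_shape": [1, 3]}, {"value": [2, 3]}]
    assert resolve_broadcast_shape({"shape": [2, 3]}, inputs) == [2, 3]  # legacy attr
    assert resolve_broadcast_shape({}, inputs) == [2, 3]                 # shape as input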

View File

@@ -1315,6 +1315,15 @@ def infer_value_for_BroadcastTo(x, shape):
         return isinstance(x, (tuple, list)) and None in x
     if shape is None or none_in_tuple_or_list(shape) or x is None:
         return None
+    if isinstance(shape, (Tensor, Tensor_)):
+        validator.check_tensor_dtype_valid("shape", mstype.TensorType(shape.dtype),
+                                           [mstype.int32, mstype.int64], "BroadcastTo")
+        shape = shape.asnumpy().tolist()
+    else:
+        validator.check_value_type("shape", shape, [tuple], "BroadcastTo")
+        shape = list(shape)
     np_data = np.broadcast_to(x.asnumpy(), shape)
     if 0 in shape:
         init_func = Zero()
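The value-infer path gains the same dual handling: a Tensor shape is validated as int32/int64 and converted via asnumpy(), a tuple is validated and listified, and numpy performs the broadcast either way. A self-contained sketch of the dispatch, with duck typing standing in for the Tensor/Tensor_ check:

    import numpy as np

    def broadcast_const_fold(x_np, shape):
        if hasattr(shape, "asnumpy"):  # stand-in for isinstance(shape, (Tensor, Tensor_))
            shape = shape.asnumpy().tolist()
        else:
            shape = list(shape)
        return np.broadcast_to(x_np, shape)

    print(broadcast_const_fold(np.ones((1, 3)), (2, 3)).shape)  # (2, 3)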