fixed graph_kernel broadcastto bug
parent 79b88f6dad
commit c298f443ea
akg (submodule)
@@ -1 +1 @@
-Subproject commit a235ba7365d87e4c884f93c8da934d7c313b61d4
+Subproject commit 749589da461f1c339e3ce94a07b638c99b632c45
@@ -44,7 +44,7 @@ const std::set<std::string> &GetConvertInputAttrOps() {
     prim::kPrimCumSum->name(), prim::kPrimArgmin->name(), prim::kPrimArgmax->name(),
     prim::kPrimBiasAdd->name(), prim::kPrimBiasAddGrad->name(), prim::kPrimLayerNorm->name(),
     prim::kPrimLayerNormGrad->name(), prim::kPrimLogSoftmax->name(), prim::kPrimLogSoftmaxGrad->name(),
-    prim::kPrimBroadcastTo->name(), prim::kPrimAdamWeightDecay->name(), prim::kPrimStridedSlice->name(),
+    prim::kPrimStridedSlice->name(), prim::kPrimAdamWeightDecay->name(),
   };
   return convert_input_attr_ops;
 }
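Dropping kPrimBroadcastTo from GetConvertInputAttrOps() means the op's shape operand is no longer folded back into a node attribute before graph kernel lowering; it stays a real input. A minimal front-end sketch of the resulting calling convention, assuming MindSpore's mindspore.ops.broadcast_to functional API (shapes illustrative):

    import mindspore as ms
    import mindspore.ops as ops

    x = ms.Tensor([[1.0], [2.0]])    # shape (2, 1)
    y = ops.broadcast_to(x, (2, 3))  # target shape rides along as an operand, not an attr
    print(y.shape)                   # (2, 3)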
@@ -905,6 +905,11 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
   for (auto node : func_graph->GetOrderedCnodes()) {
     if (AnfUtils::IsGraphKernel(node)) {
       auto sub_graph = GetCNodeFuncGraph(node);
+      if (auto type = sub_graph->get_attr("composite_type")) {
+        if (GetValue<std::string>(type) == "inplace_assign_builder") {
+          continue;
+        }
+      }
       auto cnode = node->cast<CNodePtr>();
       AnfNodePtrList inputs = cnode->inputs();
       inner::LiteGraphPtr lg = GkUtils::AnfGraph2LiteGraph(sub_graph);
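The new guard makes ArithmeticSimplify skip sub-graphs tagged with composite_type "inplace_assign_builder", so the cleaning BroadcastTo node they contain (built in CreateCleanCompositeNode below) cannot be simplified away. A hedged Python rendering of the check; get_attr here is an illustrative accessor, not the real binding:

    def should_skip(sub_graph):
        # Sub-graphs emitted by InplaceAssignBuilder carry this tag; leave them alone.
        return sub_graph.get_attr("composite_type") == "inplace_assign_builder"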
@@ -39,6 +39,7 @@ const std::unordered_map<std::string, HashSet<size_t>> &ValueDependOpUtils::GetO
     {prim::kPrimReduceSum->name(), {1}},
     {prim::kPrimTranspose->name(), {1}},
     {prim::kPrimTile->name(), {1}},
+    {prim::kPrimBroadcastTo->name(), {1}},
     {prim::kPrimReduceMean->name(), {1}},
     {prim::kPrimSlice->name(), {1, 2}},
     {prim::kPrimStridedSlice->name(), {1, 2, 3}},
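Registering index 1 marks BroadcastTo's shape input as value-depend: graph kernel building needs the operand's constant value at compile time, not just its type and shape. An illustrative Python mirror of the table above:

    # op name -> indices of inputs whose values must be known at build time
    value_depend_inputs = {
        "ReduceSum": {1},
        "Transpose": {1},
        "Tile": {1},
        "BroadcastTo": {1},  # the new entry: input 1 is the target shape
        "ReduceMean": {1},
        "Slice": {1, 2},
        "StridedSlice": {1, 2, 3},
    }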
@@ -80,10 +80,6 @@ void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfN
 template <typename T>
 ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {
   // Create tensor value.
-  if (info.shape.size() != 1 && info.shape[0] != 1) {
-    MS_LOG(EXCEPTION) << "Only support create scalar tensor value node!!!";
-  }
-
   if (info.type == nullptr) {
     MS_LOG(EXCEPTION) << "Data type can not be nullptr when creating scalar tensor!";
   }
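With the scalar-only guard gone, CreateScalarTensorValueNode is also used to build the rank-1 int64 tensor that carries a whole shape vector (the shape_node in the next hunk). A numpy sketch of the buffer being requested; the byte-layout assumption is mine:

    import numpy as np

    device_shape = [2, 3, 16, 16]                            # illustrative output shape
    shape_tensor = np.asarray(device_shape, dtype=np.int64)  # rank-1, one element per dim
    assert shape_tensor.nbytes == len(device_shape) * 8      # device_shape.size() * sizeof(int64_t)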
@@ -154,10 +154,14 @@ CNodePtr InplaceAssignBuilder::CreateCleanCompositeNode(const InplaceAssignerInf
 
   // Create broadcast basic op.
   auto dst_shape_vec = GetShape(op_info.op_node);
-  AnfNodePtrList clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node};
+  auto device_shape = GetDeviceShape(op_info.op_node);
+  auto shape_node = CreateScalarTensorValueNode<ShapeVector>(
+    {kOpFormat_DEFAULT, {SizeToLong(device_shape.size())}, TypeIdToType(kNumberTypeInt64)}, device_shape,
+    device_shape.size() * sizeof(int64_t));
+
+  AnfNodePtrList clean_inputs = {NewValueNode(prim::kPrimBroadcastTo), broadcast_input_node, shape_node};
   auto broadcast_to_node_inner =
     CreateCNode(clean_inputs, new_sub_graph, {format, dst_shape_vec, GetType(op_info.op_node)});
-  SetNodeAttrSafely("shape", MakeValue(GetDeviceShape(op_info.op_node)), broadcast_to_node_inner);
 
   // Makeup sub-graph.
   new_sub_graph->set_output(broadcast_to_node_inner);
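The rebuilt clean node feeds the device shape to BroadcastTo as a third input (shape_node) instead of attaching it afterwards via SetNodeAttrSafely. Functionally it still broadcasts the clean scalar to the output's device shape; a hedged front-end equivalent, again assuming mindspore.ops.broadcast_to with illustrative values:

    import mindspore as ms
    import mindspore.ops as ops

    zero = ms.Tensor(0.0, ms.float32)               # stands in for broadcast_input_node
    clean = ops.broadcast_to(zero, (2, 3, 16, 16))  # shape travels as data, not as an attr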
@@ -29,8 +29,8 @@ NodePtr GraphBuilder::Reshape(const NodePtr &input, const ShapeVector &shape) co
 }
 
 NodePtr GraphBuilder::BroadcastTo(const NodePtr &input, const ShapeVector &shape) const {
-  auto shape_value = MakeValue(shape);
-  return Emit("BroadcastTo", {input}, {{"shape", shape_value}});
+  auto shape_value = Tensor(shape);
+  return Emit("BroadcastTo", {input, shape_value});
 }
 
 NodePtr GraphBuilder::Gather(const NodePtr &param, const NodePtr &indice, int64_t axis, int64_t batch_dims) const {
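GraphBuilder::BroadcastTo mirrors the change on the expander side: the shape vector is wrapped in a tensor and emitted as a second operand, and the attr map disappears. A sketch of the call shape from a caller's perspective; emit and make_tensor are illustrative stand-ins for the C++ Emit and Tensor:

    def broadcast_to(emit, make_tensor, input_node, shape):
        # before: emit("BroadcastTo", [input_node], attrs={"shape": shape})
        shape_value = make_tensor(shape)
        return emit("BroadcastTo", [input_node, shape_value])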
@@ -501,7 +501,8 @@ class BroadcastTo(OpInfer):
     """BroadcastTo op."""
 
     def supported_format(self):
-        return ["ND,ND"]
+        io_format = ["ND"] * len(self.input_desc)
+        return [",".join(io_format)]
 
     def infer_shape(self):
         """Broadcast op keeps ND format, so the output shape will not be changed"""
@@ -509,7 +510,7 @@ class BroadcastTo(OpInfer):
 
     def infer_ori_shape(self):
         shape = self.input_desc[0][ORI_SHAPE]
-        broad_shape = self.get_attr(SHAPE)
+        broad_shape = self.get_attr(SHAPE) if SHAPE in self.attr else self.input_desc[1][VALUE]
         if len(broad_shape) < len(shape):
             raise ValueError("The length of attr 'shape' must be >= the length of input shape, but got attr 'shape': "
                              "{}, input shape: {}".format(broad_shape, shape))
@@ -523,6 +524,8 @@ class BroadcastTo(OpInfer):
         self.output_desc[0][ORI_SHAPE] = out_shape
 
     def post_process(self):
+        if not isinstance(self.op_desc.get(ATTR), list):
+            return
         for item in self.op_desc[ATTR]:
             if item[NAME] == SHAPE:
                 item["ori_value"] = item[VALUE]
@@ -761,6 +764,9 @@ def update_akg_info(args, info_path, kernel_name=None):
     # Update data format to DefaultFormat
     convert_to_default_format(desc)
 
+    # GE backend must use old CCE
+    desc["backend"] = "GE"
+
     return desc
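update_akg_info pins a backend field on the kernel description before returning it, so downstream consumers can branch on GE. Sketch with a hypothetical desc dict:

    desc = {"op": "Fused_BroadcastTo", "process": "aicore"}  # illustrative fields
    desc["backend"] = "GE"  # GE backend must use old CCE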
@@ -1315,6 +1315,15 @@ def infer_value_for_BroadcastTo(x, shape):
         return isinstance(x, (tuple, list)) and None in x
     if shape is None or none_in_tuple_or_list(shape) or x is None:
         return None
+
+    if isinstance(shape, (Tensor, Tensor_)):
+        validator.check_tensor_dtype_valid("shape", mstype.TensorType(shape.dtype),
+                                           [mstype.int32, mstype.int64], "BroadcastTo")
+        shape = shape.asnumpy().tolist()
+    else:
+        validator.check_value_type("shape", shape, [tuple], "BroadcastTo")
+        shape = list(shape)
+
     np_data = np.broadcast_to(x.asnumpy(), shape)
     if 0 in shape:
         init_func = Zero()
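With the added branch, compile-time constant folding accepts the target shape either as a tuple or as an int32/int64 tensor, normalizing both to a plain list before calling np.broadcast_to. A numpy-only usage sketch (plain arrays standing in for Tensor.asnumpy()):

    import numpy as np

    x = np.array([[1.0], [2.0]])
    shape = np.array([2, 3], dtype=np.int64)  # tensor-valued shape operand
    out = np.broadcast_to(x, shape.tolist())
    print(out.shape)                          # (2, 3)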