forked from mindspore-Ecosystem/mindspore
commit a24ff36d9c
@@ -13,6 +13,7 @@
# limitations under the License.
# ===========================================================================
"""Cost model splitter"""
import os
from functools import reduce
from .model import PrimLib, Graph, Tensor
@@ -23,12 +24,19 @@ class GraphSplitByPattern:
        MODE_BASIC = 1
        MODE_COMPOSITE = 2

        class StitchInfo:
            """StitchInfo"""
            def __init__(self):
                self.stitch_ops = set()
                self.stitch_atomic_ops = set()

        def __init__(self, init_op, is_output):
            self.pattern = PrimLib.iter_type(init_op)
            self.ops = [init_op]
            self.in_relations = dict()  # {area1: relation1, area2: relation2, ...}
            self.out_relations = dict()  # {area1: relation1, area2: relation2, ...}
            self.mode = None
            self.stitch_info = self.StitchInfo()
            self.is_output = is_output
            self.output_excluded = set()
            if self.pattern == PrimLib.REDUCE:
@@ -69,6 +77,12 @@ class GraphSplitByPattern:
            for input_area, r in self.in_relations.items():
                input_area.out_relations[self] = r

        def update_stitch_info(self, stitch_info):
            if stitch_info.stitch_ops:
                self.stitch_info.stitch_ops.update(stitch_info.stitch_ops)
            if stitch_info.stitch_atomic_ops:
                self.stitch_info.stitch_atomic_ops.update(stitch_info.stitch_atomic_ops)

        def fuse(self, area):
            """Fuse `area` to `self`"""
            def _update_relation(relations, a, r):
@@ -107,6 +121,7 @@ class GraphSplitByPattern:
                self.is_output = True
            if area.output_excluded:
                self.output_excluded.update(area.output_excluded)
            self.update_stitch_info(area.stitch_info)

        def check_circle(self, to):
            """Check circle. It returns false if circle exists"""
@@ -181,10 +196,25 @@ class GraphSplitByPattern:
        graphmodes = []
        for i, area in enumerate(self.areas):
            area.ops.sort(key=lambda op: ids[op])
-            subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops))
+            subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info))
            graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
        return subgraphs, graphmodes

    def dump_subgraphs(self, subgraphs):
        """Dump subgraphs"""
        if os.environ.get("ENABLE_SUBGRAPHS", "off") == "on":
            subgraphs_str = "subgraphs:\nlen: " + str(len(subgraphs)) + "\n"
            for i, sub in enumerate(subgraphs):
                subgraphs_str += str("============") + str(i) + "\n"
                subgraphs_str += str(sub)
            dirname = 'subgraphs'
            if not os.path.exists(dirname):
                os.makedirs(dirname)
            graphname = self.graph.name
            filename = dirname + '/' + graphname + '.log'
            with open(filename, 'w') as f:
                f.write(subgraphs_str)

    def split(self):
        """Split graph by pattern"""
        self.do_split()
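Editorial note on the dump_subgraphs helper above: it only writes anything when the ENABLE_SUBGRAPHS environment variable is set to "on", in which case every split result is written to subgraphs/<graph_name>.log under the current working directory. A minimal usage sketch follows; the splitter entry point shown is an assumption for illustration, not part of this diff.

    import os

    # Enable the dump before the cost-model splitter runs; the variable name and
    # the subgraphs/<graph_name>.log path come from dump_subgraphs() above.
    os.environ["ENABLE_SUBGRAPHS"] = "on"

    # Hypothetical call sequence -- the real entry point may differ:
    # splitter = GraphSplitGpu(graph)
    # subgraphs, graphmodes = splitter.split()  # split() calls dump_subgraphs()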
@@ -192,6 +222,7 @@ class GraphSplitByPattern:
        # Note: after this function, the input output relation is not maintained.
        self.split_output_reshapes()
        subgraphs, graphmodes = self.to_subgraphs()
        self.dump_subgraphs(subgraphs)
        return subgraphs, graphmodes

    def split_output_reshapes(self):
@@ -362,15 +393,25 @@ class GraphSplitGpu(GraphSplitByPattern):
                return reduce_size >= 1024
            return True

        def _reduce_nums(ops):
            count = 0
            for op in ops:
                if op.prim.startswith('Reduce'):
                    count += 1
            return count

        def _reduce_output(dom):
            if dom.pattern != PrimLib.REDUCE:
                return None
            if _reduce_nums(dom.ops) > 1:
                return None
            if _is_atomic_add_available(dom):
                return None
            is_all_reduce = _tensor_size(dom.ops[0].output) == 1
            # excluded large size all reduce
            if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
                return None

            fused = []
            for a, r in dom.out_relations.items():
                if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
@@ -378,6 +419,24 @@ class GraphSplitGpu(GraphSplitByPattern):
                    fused.append(a)
            return fused, False

        def _reduce_stitch(dom):
            if dom.pattern != PrimLib.REDUCE:
                return None
            if _tensor_size(dom.ops[0].output) == 1:
                return None
            if _tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
                return None

            fused = []
            for a, r in dom.out_relations.items():
                if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_circle(a):
                    if _reduce_nums(a.ops) < 2:
                        # softmax
                        if len(a.ops) > 4 and len(a.ops[0].inputs[0].shape) == 4:
                            dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
                            fused.append(a)
            return fused, False

        def _transpose(dom):
            if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose":
                return None
@@ -398,6 +457,7 @@ class GraphSplitGpu(GraphSplitByPattern):
            changed = self.fuse(_broadcast_width) or changed
            if use_poly_reduce:
                changed = self.fuse(_reduce_output) or changed
                changed = self.fuse(_reduce_stitch) or changed
        self.fuse(_transpose)


class GraphSplitAscend(GraphSplitByPattern):
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -310,11 +310,12 @@ class Operator:
class Graph:
    """Graph"""

-    def __init__(self, name, ops):
+    def __init__(self, name, ops, stitch_info=None):
        self.name = name
        self.ops = ops  # in topo order, can not use set
        self.inputs = []
        self.outputs = []
        self.stitch_info = stitch_info

    def set_processor(self, processor):
        """Set processor"""
@@ -372,6 +373,12 @@ class Graph:
        out_str = ', '.join([repr(t) for t in outputs])
        lines = []
        lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
        if self.stitch_info:
            if self.stitch_info.stitch_ops:
                lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops))
            if self.stitch_info.stitch_atomic_ops:
                lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops))

        for op in self.ops:
            lines.append(' ' + str(op))
        lines.append('}')
@@ -405,12 +412,20 @@ class Graph:
                in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
                                 'tensor_name': t.name, 'format': t.data_format}])
            out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
-                         'tensor_name': op.output.name, 'format': t.data_format}]
+                         'tensor_name': op.output.name, 'format': op.output.data_format}]
            op_desc.append({'attr': attrs, 'impl_path': '',
                            'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})

        graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
                      'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
                      'platform': 'AKG', 'process': self.processor}

        if self.stitch_info and self.stitch_info.stitch_ops:
            buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
            if self.stitch_info.stitch_atomic_ops:
                buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
            graph_desc['buffer_stitch'] = buffer_stitch

        return graph_desc
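For orientation (editorial sketch, not part of the diff): the buffer_stitch entry built in Graph.dump() above is the same structure the C++ decoder later reads back through kJsonKeyBufferStitch, kJsonKeyStitchOp and kJsonKeyStitchAtomicOp. A rough picture of its shape, with invented tensor names:

    # Hypothetical fragment of the graph_desc JSON emitted by Graph.dump();
    # the tensor names below are made up for illustration only.
    graph_desc_example = {
        'buffer_stitch': {
            'stitch_op': ['output_0_3'],         # tensors stitched through shared buffers
            'stitch_atomic_op': ['output_0_1'],  # tensors handled by atomic-add clean
        },
        # ... plus the usual 'input_desc', 'op_desc', 'output_desc',
        #     'platform' and 'process' fields shown above.
    }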
@@ -313,6 +313,14 @@ class CompositeGraph:
        self.graph = builder.get()[0]
        self.desc = desc

    def add_stitch_info(self, subgraph, desc):
        if subgraph.stitch_info and subgraph.stitch_info.stitch_ops:
            buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)}
            if subgraph.stitch_info.stitch_atomic_ops:
                buffer_stitch['stitch_atomic_op'] = list(subgraph.stitch_info.stitch_atomic_ops)
            desc['buffer_stitch'] = buffer_stitch
        return desc

    def dump(self, subgraph):
        """Dump Graph to json"""
        desc = {}
@@ -368,6 +376,8 @@ class CompositeGraph:
                desc[key] = subgraph.name
            else:
                desc[key] = self.desc[key]

        desc = self.add_stitch_info(subgraph, desc)
        return desc
@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -414,6 +414,35 @@ FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_js
  return DecodeFusedNodes(kernel_json);
}

StitchInfo AkgKernelJsonDecoder::GetStitchInfo(const nlohmann::json &kernel_json) {
  StitchInfo info;
  if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
    nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
    if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
      std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
      info.stitch_ops = stitch_ops;
    }
    if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
      std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
      info.stitch_atomic_ops = stitch_atomic_ops;
    }
  }
  return info;
}

void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) {
  std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
  if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
  std::string tensor_name = output_descs[0][kJsonKeyTensorName];
  if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
    AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
  }
  if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
      info.stitch_atomic_ops.end()) {
    AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
  }
}

bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
                                            const std::map<std::string, AnfNodePtr> &address_node_map,
                                            AnfNodePtrList *res_graphs) {
@@ -425,6 +454,7 @@ bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
    MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
    return false;
  }
  StitchInfo info = GetStitchInfo(kernel_json);
  for (const auto &op_desc : op_node_descs) {
    if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
      MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
@@ -436,7 +466,9 @@ bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
      MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
      return false;
    }
-    res_graphs->push_back(address_node_map.at(ptr_address));
+    auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
+    SetStitchAttr(op_desc, info, node);
+    res_graphs->push_back(node);
  }
  MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
  return true;
@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -26,6 +26,10 @@

namespace mindspore {
namespace kernel {
struct StitchInfo {
  std::vector<std::string> stitch_ops;
  std::vector<std::string> stitch_atomic_ops;
};
class AkgKernelJsonDecoder {
 public:
  AkgKernelJsonDecoder() { nodes_map_.clear(); }
@@ -40,6 +44,8 @@ class AkgKernelJsonDecoder {
  ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph);
  CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
  AnfNodePtr DecodeOutput(const std::vector<nlohmann::json> &output_descs, const FuncGraphPtr &func_graph);
  StitchInfo GetStitchInfo(const nlohmann::json &kernel_json);
  void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node);
  std::map<std::string, AnfNodePtr> nodes_map_;
};
}  // namespace kernel
@@ -580,6 +580,31 @@ void AkgKernelJsonGenerator::AddParalleFusionJsonInfo(const std::string &process
  (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json;
}

void AkgKernelJsonGenerator::GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes,
                                           std::map<AnfNodePtr, nlohmann::json> *node_json_map,
                                           nlohmann::json *kernel_json) {
  std::vector<std::string> stitchs;
  for (auto const &anf_node : anf_nodes) {
    if (AnfAlgo::HasNodeAttr(kAttrStitch, anf_node->cast<CNodePtr>()) &&
        AnfAlgo::GetNodeAttr<std::string>(anf_node, kAttrStitch) == "common") {
      auto name = GetTensorName((*node_json_map)[anf_node], kJsonKeyOutputDesc, {0, 0});
      if (std::find(stitchs.begin(), stitchs.end(), name) == stitchs.end()) {
        stitchs.emplace_back(name);
      }
    }
  }
  if (!stitchs.empty()) {
    std::vector<nlohmann::json> v;
    for (auto &s : stitchs) {
      std::vector<std::string> t;
      t.emplace_back(s);
      v.emplace_back(t);
    }
    nlohmann::json stitch_json;
    stitch_json[kJsonKeyStitchOp] = v;
    (*kernel_json)[kJsonKeyBufferStitch] = stitch_json;
  }
}
bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
                                              const std::vector<AnfNodePtr> &input_list,
                                              const std::vector<AnfNodePtr> &output_list, nlohmann::json *kernel_json) {
@@ -637,6 +662,8 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
  (*kernel_json)[kJsonKeyComposite] = true;
  (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id();

  GenStitchJson(anf_nodes, &node_json_map, kernel_json);

  if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
    MS_LOG(ERROR) << "Cal mem size failed.";
    return false;
@@ -54,6 +54,9 @@ constexpr auto kJsonKeyParallelFusion = "parallel_fusion";
constexpr auto kJsonKeyFusionType = "fusion_type";
constexpr auto kJsonKeySubGraph = "sub_graph";
constexpr auto kJsonKeyCoreNum = "core_num";
constexpr auto kJsonKeyBufferStitch = "buffer_stitch";
constexpr auto kJsonKeyStitchOp = "stitch_op";
constexpr auto kJsonKeyStitchAtomicOp = "stitch_atomic_op";

constexpr auto kAttrInputNames = "input_names";
@@ -98,6 +101,8 @@ class AkgKernelJsonGenerator {
  void GetAttrJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes, const OpAttrPtr &op_attr,
                   nlohmann::json *attr_json, const ValuePtr &attr_value);
  bool CreateAttrDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *attrs_json);
  void GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map,
                     nlohmann::json *kernel_json);
  bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *input_size, std::vector<size_t> *output_size);
  bool GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map);
  void UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map);
@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
#include <tuple>
#include <utility>
#include <vector>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/kernel_graph.h"
@@ -28,21 +29,24 @@ namespace mindspore {
namespace opt {
class AtomicCleanInsertter : public Pass {
 public:
-  AtomicCleanInsertter() : Pass("atomic_clean") {}
+  explicit AtomicCleanInsertter(const std::string &name = "atomic_clean") : Pass(name) {}
  ~AtomicCleanInsertter() override = default;
-  bool Run(const FuncGraphPtr &func_graph) override;
+  virtual bool Run(const FuncGraphPtr &func_graph);

 private:
  void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
                          const FuncGraphManagerPtr &mng);
  bool CanActivateAtomicAdd(const AnfNodePtr &anf_node);
 protected:
  virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
  virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
                                  const FuncGraphManagerPtr &mng);
  void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng);
  void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
                 const AnfNodePtr &user_node, int index);
  void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
  CNodePtr atomic_add_node_{nullptr};

 private:
  bool CanActivateAtomicAdd(const AnfNodePtr &anf_node);
  void CorrectAbstract(const AnfNodePtr &composite_node);
  void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
  CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type);
  void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
  void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
                              const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng);
  std::tuple<AnfNodePtr, AnfNodePtr, int> FindPatronNode(const KernelGraphPtr &main_graph);
@@ -55,7 +59,6 @@ class AtomicCleanInsertter : public Pass {
  bool IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node,
                                 const FuncGraphManagerPtr &mng);

-  CNodePtr atomic_add_node_{nullptr};
  size_t reduce_real_output_index_{0};
  size_t real_output_num_{0};
  std::vector<std::pair<AnfNodePtr, AnfNodePtr>> to_process_order_;
@@ -0,0 +1,200 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"
#include <algorithm>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <utility>
#include <set>
#include <stack>
#include <string>
#include <tuple>
#include <vector>
#include "base/core_ops.h"
#include "ir/tensor.h"
#include "utils/utils.h"
#include "utils/log_adapter.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "debug/anf_ir_dump.h"

namespace mindspore {
namespace opt {
void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input) {
  // Change kernel build info.
  auto kernel_info = static_cast<device::KernelInfo *>(composite_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  const auto &origin_kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  auto origin_inputs_format = origin_kernel_build_info->GetAllInputFormats();
  auto origin_outputs_format = origin_kernel_build_info->GetAllOutputFormats();
  auto origin_inputs_type = origin_kernel_build_info->GetAllInputDeviceTypes();
  auto origin_outputs_type = origin_kernel_build_info->GetAllOutputDeviceTypes();
  auto origin_processor = origin_kernel_build_info->processor();

  std::vector<std::string> &new_inputs_format = origin_inputs_format;
  std::vector<TypeId> &new_inputs_type = origin_inputs_type;
  std::vector<std::string> new_outputs_format;
  std::vector<TypeId> new_outputs_type;
  for (size_t i = 0; i < origin_outputs_format.size(); ++i) {
    new_outputs_format.push_back(origin_outputs_format[i]);
    new_outputs_type.push_back(origin_outputs_type[i]);
  }

  auto kernel_with_index = AnfAlgo::VisitKernel(new_input, 0);
  new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
  new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));

  kernel::KernelBuildInfo::KernelBuildInfoBuilder new_info_builder;
  new_info_builder.SetInputsFormat(new_inputs_format);
  new_info_builder.SetInputsDeviceType(new_inputs_type);
  new_info_builder.SetOutputsFormat(new_outputs_format);
  new_info_builder.SetOutputsDeviceType(new_outputs_type);
  new_info_builder.SetProcessor(origin_processor);
  new_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  new_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
  auto new_selected_info = new_info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(new_selected_info, composite_node.get());
}

CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph,
                                                                             const AnfNodePtr &new_parameter) {
  // add inplaceassign
  AnfNodePtr out_node = atomic_add_node_;  // Use result data itself, and set attr "fake_out" true.
  auto inplace_assign_node =
    CreateCNode({NewValueNode(std::make_shared<Primitive>("InplaceAssign")), new_parameter, atomic_add_node_, out_node},
                sub_graph, {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
  AnfAlgo::SetNodeAttr("fake_output", MakeValue(true), inplace_assign_node);
  AnfAlgo::EraseNodeAttr(kAttrStitch, atomic_add_node_);
  AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), inplace_assign_node);
  return inplace_assign_node;
}

void StitchAtomicCleanInsertter::ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
                                                    const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng) {
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }

  // add input
  auto inputs = composite_node->cast<CNodePtr>()->inputs();
  inputs.push_back(new_input);
  composite_node->cast<CNodePtr>()->set_inputs(inputs);

  // add parameter
  auto parameter = sub_graph->add_parameter();
  parameter->set_abstract(new_input->abstract());
  parameter->set_kernel_info(new_input->kernel_info_ptr());

  auto inplace_assign = CreateInplaceAssignNodeAndCorrectReturn(sub_graph, parameter);

  // Replace atomic ReduceSum's user with atomic clean output, and add depend op after inplaceassign to avoid
  // elimination.
  std::vector<std::pair<AnfNodePtr, int>> reduce_user_nodes = FindInnerCNodeUsers(stitch_node_, atomic_add_node_);
  bool connected = false;
  for (const auto &[user_node, index] : reduce_user_nodes) {
    auto user_cnode = user_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_cnode);
    user_cnode->set_input(index, parameter);
    if (!connected) {
      std::vector<std::pair<AnfNodePtr, int>> user_user = FindInnerCNodeUsers(stitch_node_, user_cnode);
      if (!user_user.empty()) {
        auto pair = user_user[0];
        AddDepend(sub_graph, user_cnode, inplace_assign, pair.first, pair.second);
      }
      connected = true;
    }
    CorrectKernelBuildInfo(composite_node, new_input);
  }

  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
  MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}

std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNodeUsers(const AnfNodePtr &inner_node,
                                                                                        const CNodePtr &target) {
  auto node = inner_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  auto mng_sub = sub_graph->manager();
  if (mng_sub == nullptr) {
    mng_sub = Manage(sub_graph, false);
    sub_graph->set_manager(mng_sub);
  }
  std::vector<std::pair<AnfNodePtr, int>> inner_user_nodes;
  auto users = mng_sub->node_users()[target];
  std::transform(users.cbegin(), users.cend(), std::back_inserter(inner_user_nodes),
                 [](const std::pair<AnfNodePtr, int> &pair) { return pair; });
  return inner_user_nodes;
}

bool StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
  if (!AnfAlgo::IsGraphKernel(anf_node)) return false;
  auto node = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
  AnfNodePtrList kernel_nodes;
  kernel::GetValidKernelNodes(sub_graph, &kernel_nodes);
  for (auto &n : kernel_nodes) {
    if (AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) &&
        AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" && IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
      MS_LOG(INFO) << "GOT STITCH WITH ATOMIC!!!";
      atomic_add_node_ = n->cast<CNodePtr>();
      stitch_node_ = anf_node;
      return true;
    }
  }
  return false;
}

bool StitchAtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto mng = kernel_graph->manager();
  if (mng == nullptr) {
    mng = Manage(kernel_graph, true);
    kernel_graph->set_manager(mng);
  }

  bool changed = false;
  auto topo_nodes = TopoSort(kernel_graph->get_return());
  for (const auto &node : topo_nodes) {
    // if stitch attr exists, add atomic clean op depends on the attr
    if (IsStitchWithAtomic(node)) {
      InsertAtomicClean(kernel_graph, node, mng);
      changed = true;
    }
  }

  if (changed) {
    mng->RemoveRoots();
    mng->KeepRoots({func_graph});
  }

  return changed;
}
}  // namespace opt
}  // namespace mindspore
@@ -0,0 +1,50 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_

#include <memory>
#include <tuple>
#include <utility>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h"
#include "backend/session/kernel_graph.h"

namespace mindspore {
namespace opt {
class StitchAtomicCleanInsertter : public AtomicCleanInsertter {
 public:
  StitchAtomicCleanInsertter() : AtomicCleanInsertter("stitch_atomic_clean") {}
  ~StitchAtomicCleanInsertter() override = default;
  bool Run(const FuncGraphPtr &func_graph) override;

 private:
  void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
  CNodePtr CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
  std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node, const CNodePtr &target);
  void ProcessOriginCNode(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
                          const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng);
  bool IsStitchWithAtomic(const AnfNodePtr &anf_node);

  AnfNodePtr stitch_node_{nullptr};
};
using StitchAtomicCleanInsertterPtr = std::shared_ptr<StitchAtomicCleanInsertter>;
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_
@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -717,6 +717,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
    prim::kPrimMinimumGrad,
    prim::kPrimGkDropout,
    prim::kPrimDropoutGrad,
    prim::kPrimSoftMax,
#endif
  };
  return expand_ops;
@@ -42,6 +42,7 @@
#include "backend/optimizer/gpu/add_relu_v2_fusion.h"
#include "backend/optimizer/gpu/add_relu_grad_v2_fusion.h"
#include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h"
#include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "backend/optimizer/graph_kernel/clean_all_in_once.h"
@@ -201,6 +202,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
  // will be exposed, use GetitemTuple Pass to delete them.
  pm->AddPass(std::make_shared<opt::GetitemTuple>());
  pm->AddPass(std::make_shared<opt::AtomicCleanInsertter>());
  pm->AddPass(std::make_shared<opt::StitchAtomicCleanInsertter>());
  pm->AddPass(std::make_shared<opt::DependFormater>());  // Prevent fake loop in parallel fusion.
  pm->AddPass(std::make_shared<opt::ParallelOpFusion>(kGPUDevice, opt::ParallelConfig(7)));
  pm->AddPass(std::make_shared<opt::BindValueToGraph>());
@@ -384,6 +384,7 @@ constexpr auto kAttrIsGrad = "is_grad";
constexpr auto kAttrRecompute = "recompute";
constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
constexpr auto kAttrStitch = "stitch";

// attr value
constexpr auto kValueTargetSwitch = "target_switch";
@@ -132,7 +132,7 @@ inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range");

// NN
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
-inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("SoftMax");
+inline const PrimitivePtr kPrimSoftMax = std::make_shared<Primitive>("Softmax");
inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");