[GraphKernel] clean code.

chenlei_autodiff 2022-06-13 20:17:47 +08:00
parent dbec13ca4a
commit 6e4d2965db
21 changed files with 83 additions and 63 deletions

View File

@@ -57,9 +57,10 @@ class EqualCount : public OpDesc {
     auto dtype = input_x->type;
     auto eql_val = gb.Equal(input_x, input_y);
     auto cast_val = gb.Cast(eql_val, kNumberTypeFloat32);
-    std::vector<int64_t> axis;
-    for (size_t i = 0; i < input_x->shape.size(); ++i) {
-      axis.push_back(i);
+    auto shape_size = input_x->shape.size();
+    std::vector<int64_t> axis(shape_size);
+    for (size_t i = 0; i < shape_size; ++i) {
+      axis[i] = i;
     }
     auto result = gb.ReduceSum(cast_val, axis, false);
     result = gb.Reshape(result, {1});
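The change above replaces repeated push_back calls with a vector sized once up front and filled by index. A minimal standalone sketch of the same pattern (not part of this commit; std::iota expresses the fill in one call):

#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Build the axis list {0, 1, ..., shape_size - 1} with a single allocation.
std::vector<int64_t> AllAxes(size_t shape_size) {
  std::vector<int64_t> axis(shape_size);  // sized up front, no reallocation
  std::iota(axis.begin(), axis.end(), static_cast<int64_t>(0));
  return axis;
}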

View File

@@ -65,6 +65,6 @@ class OpDescRegister {
 #define UNIQUE_NAME(prefix, cnt) JOIN(prefix, cnt)
 #define EXPANDER_OP_DESC_REGISTER(name, cls) \
   const OpDescRegister UNIQUE_NAME(g_expander_opdesc_, __COUNTER__)( \
-    name, []() -> std::shared_ptr<OpDesc> { return std::make_shared<cls>(); })
+    name, []() noexcept -> std::shared_ptr<OpDesc> { return std::make_shared<cls>(); })
 } // namespace mindspore::graphkernel::expanders
 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_EXPANDER_FACTORY_H_
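For context, a hypothetical use of the macro above (the actual registration calls live in the individual expander sources and are not shown in this hunk); EqualCount is the expander class from the first file in this commit:

// Register the "EqualCount" expander; __COUNTER__ gives each global registrar a unique name,
// and the now-noexcept lambda simply constructs the OpDesc subclass.
EXPANDER_OP_DESC_REGISTER("EqualCount", EqualCount);
// which expands roughly to:
//   const OpDescRegister g_expander_opdesc_0(
//     "EqualCount", []() noexcept -> std::shared_ptr<OpDesc> { return std::make_shared<EqualCount>(); });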

View File

@@ -36,9 +36,9 @@ class ExpandDims : public OpDesc {
                      << rank << "]";
     }
     if (x >= 0) {
-      (void)new_shape.insert(new_shape.begin() + x, 1LL);
+      (void)new_shape.insert(new_shape.cbegin() + x, 1LL);
     } else {
-      (void)new_shape.insert(new_shape.begin() + (x + rank + 1), 1LL);
+      (void)new_shape.insert(new_shape.cbegin() + (x + rank + 1), 1LL);
     }
   }
   return new_shape;
@@ -75,7 +75,7 @@ class Squeeze : public OpDesc {
     for (int64_t i = 0; i < ndim; i++) {
       if (std::find(axis.begin(), axis.end(), i) == axis.end() &&
           std::find(axis.begin(), axis.end(), i - ndim) == axis.end()) {
-        (void)new_shape.emplace_back(shape[i]);
+        (void)new_shape.emplace_back(shape[LongToSize(i)]);
       }
     }
   }
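The Squeeze change stops indexing shape with a signed int64_t directly and routes it through LongToSize. A minimal sketch of what such a conversion helper guards against (illustrative only; the real helper lives in MindSpore's utility headers, not in this diff):

#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Convert a signed index to size_t, rejecting negative values instead of
// letting an implicit signed-to-unsigned conversion wrap around.
inline size_t LongToSizeSketch(int64_t value) {
  if (value < 0) {
    throw std::out_of_range("negative value cannot be used as an index");
  }
  return static_cast<size_t>(value);
}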

View File

@@ -68,7 +68,9 @@ class CheckAllFormatsSame : public Validator {
  public:
   bool Check(const OpDesc &e) override {
     const auto &inputs_info = e.InputsInfo();
-    if (inputs_info.empty()) return true;
+    if (inputs_info.empty()) {
+      return true;
+    }
     const auto &fmt_0 = inputs_info[0].format;
     for (size_t i = 1; i < inputs_info.size(); i++) {
       if (inputs_info[i].format != fmt_0) {

View File

@@ -80,7 +80,9 @@ void FloatStatusAddNFusion::ProcessFloatStatusAddN(const FuncGraphPtr &main_grap
   // Expand floatstatus to subgraph
   for (size_t i = 1; i < addn->inputs().size(); i++) {
-    if (input_not_convert.count(i) > 0) continue;
+    if (input_not_convert.count(i) > 0) {
+      continue;
+    }
     auto floatstatus = addn->input(i)->cast<CNodePtr>();
     auto expand_fg = GetCNodeFuncGraph(graphkernel::GetExpander(floatstatus, false)->Run(floatstatus));
     MS_EXCEPTION_IF_NULL(expand_fg);
@@ -98,7 +100,9 @@ void FloatStatusAddNFusion::ProcessFloatStatusAddN(const FuncGraphPtr &main_grap
   // Insert extra input(broadcast node output) to composite node, and make elemany inplace-assign to it.
   for (size_t i = 1; i < addn->inputs().size(); i++) {
-    if (input_not_convert.count(i) > 0) continue;
+    if (input_not_convert.count(i) > 0) {
+      continue;
+    }
     op_info = SubGraphSignleOutput(addn->input(i));
     ProcessOriginCNode(addn->input(i), {{op_info, broadcast_to_node}});
   }
@@ -118,7 +122,9 @@ bool FloatStatusAddNFusion::Run(const FuncGraphPtr &func_graph) {
   auto mng = func_graph->manager();
   MS_EXCEPTION_IF_NULL(mng);
   auto changed = false;
-  if (!CanConvert()) return changed;
+  if (!CanConvert()) {
+    return changed;
+  }
   auto nodes = TopoSort(func_graph->get_return());
   for (auto node : nodes) {
     if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
@@ -129,7 +135,9 @@ bool FloatStatusAddNFusion::Run(const FuncGraphPtr &func_graph) {
     bool pattern_match =
       std::all_of(cnode->inputs().begin() + 1, cnode->inputs().end(),
                   [](const AnfNodePtr &anf_node) { return IsPrimitiveCNode(anf_node, prim::kPrimFloatStatus); });
-    if (!pattern_match) continue;
+    if (!pattern_match) {
+      continue;
+    }
     ProcessFloatStatusAddN(func_graph, cnode, mng);
     changed = true;
   }

View File

@@ -35,7 +35,9 @@ void GetTopoValidNodes(const FuncGraphPtr &func_graph, CNodePtrList *topo_valid_
   MS_EXCEPTION_IF_NULL(topo_valid_nodes);
   auto nodes = TopoSort(func_graph->get_return());
   for (auto &node : nodes) {
-    if (node == nullptr || !node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) continue;
+    if (node == nullptr || !node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
+      continue;
+    }
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     topo_valid_nodes->push_back(cnode);
@@ -121,11 +123,11 @@ void SafeSplitSchemer::GroupReturnNode() {
   MS_EXCEPTION_IF_NULL(ret_node);
   auto output = func_graph_->output();
   MS_EXCEPTION_IF_NULL(output);
-  // set the make_tuple node to a new group.
   if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
+    // set the make_tuple node to a new group.
     auto group_id = split_plan_.size();
-    (void)split_plan_.emplace_back(AnfNodePtrList{output, ret_node});
     need_inline_.push_back(1);
+    (void)split_plan_.emplace_back(AnfNodePtrList{output, ret_node});
     node_group_[output] = group_id;
     node_group_[ret_node] = group_id;
   } else {
@@ -139,7 +141,6 @@ void GraphKernelBuild::Init() {
   // Init KernelMeta.
   if (bin_map_ == nullptr) {
     bin_map_ = kernel::KernelMeta::GetInstance();
-    MS_EXCEPTION_IF_NULL(bin_map_);
     if (!bin_map_->initialized()) {
       bin_map_->Initialize();
     }
@@ -175,7 +176,7 @@ bool GraphKernelBuild::Process(const FuncGraphPtr &func_graph, int iter) {
   // Update cache after compiling. Nodes that still not have compile cache means they compiled failed.
   auto remaining_nodes = CollectNotCachedNodes(need_compile_nodes);
   // Split nodes that compile failed.
-  changed = SplitNodes(func_graph, remaining_nodes);
+  changed = SplitNodes(remaining_nodes);
   return changed;
 }
@@ -201,7 +202,7 @@ kernel::JsonNodePair GraphKernelBuild::CollectNode(const AnfNodePtr &node) const
   return std::make_pair(akg_kernel_json_generator, node);
 }
-void GraphKernelBuild::CollectNodes(const FuncGraphPtr &func_graph, std::vector<kernel::JsonNodePair> *nodes) {
+void GraphKernelBuild::CollectNodes(const FuncGraphPtr &func_graph, std::vector<kernel::JsonNodePair> *nodes) const {
   if (func_graph == nullptr) {
     return;
   }
@@ -226,9 +227,13 @@ std::vector<kernel::JsonNodePair> GraphKernelBuild::CollectNotCachedNodes(
   MS_EXCEPTION_IF_NULL(kernel_builder_);
   std::vector<kernel::JsonNodePair> res;
   for (const auto &[json_generator, node] : nodes) {
-    if (node == nullptr) continue;
+    if (node == nullptr) {
+      continue;
+    }
     // Skip node that already set kernel mod(created from compile cache).
-    if (AnfAlgo::GetKernelMod(node) != nullptr) continue;
+    if (AnfAlgo::GetKernelMod(node) != nullptr) {
+      continue;
+    }
     const auto &kernel_name = json_generator.kernel_name();
     // Skip node that already has cache.
     if (kernel_pack_.find(kernel_name) != kernel_pack_.end()) {
@@ -279,7 +284,7 @@ void GraphKernelBuild::ParallelBuild(const std::vector<kernel::JsonNodePair> &no
   }
 }
-bool GraphKernelBuild::SplitNodes(const FuncGraphPtr &func_graph, const std::vector<kernel::JsonNodePair> &nodes) {
+bool GraphKernelBuild::SplitNodes(const std::vector<kernel::JsonNodePair> &nodes) {
   bool result = false;
   std::unordered_set<std::string> kernel_names;
   for (const auto &[json_generator, node] : nodes) {

View File

@@ -75,13 +75,13 @@ class GraphKernelBuild : public opt::Pass {
   bool Process(const FuncGraphPtr &func_graph, int iter);
   kernel::JsonNodePair CollectNode(const AnfNodePtr &node) const;
   // Collect graph kernel nodes in main graph.
-  void CollectNodes(const FuncGraphPtr &func_graph, std::vector<kernel::JsonNodePair> *nodes);
+  void CollectNodes(const FuncGraphPtr &func_graph, std::vector<kernel::JsonNodePair> *nodes) const;
   // Collect graph kernel nodes that do not have compile cache, which means these nodes need to be compiled.
   std::vector<kernel::JsonNodePair> CollectNotCachedNodes(const std::vector<kernel::JsonNodePair> &nodes);
   // Parallel compiling.
   void ParallelBuild(const std::vector<kernel::JsonNodePair> &nodes);
   // Split nodes that compiled failed.
-  bool SplitNodes(const FuncGraphPtr &func_graph, const std::vector<kernel::JsonNodePair> &nodes);
+  bool SplitNodes(const std::vector<kernel::JsonNodePair> &nodes);
   SafeGraphKernelSplitter splitter_; // used to split nodes that compile failed
   kernel::KernelMeta *bin_map_{nullptr};

View File

@@ -92,7 +92,7 @@ class FlagRegister {
   template <typename T>
   void AddFlag(const std::string &flag_name, T *flag_var, T default_value = T()) const {
-    auto iter = flag_map_.find(flag_name);
+    const auto iter = flag_map_.find(flag_name);
     if (iter != flag_map_.end()) {
       T var;
       bool ret = ParseValue(iter->second, &var);
@@ -216,7 +216,7 @@ void GraphKernelFlags::CheckSupport() const {
 void GraphKernelFlags::Refresh() {
   auto flag_map = ParseFlags(flags_cache_);
   RegisterFlags(&flag_map);
-  for (auto &item : flag_map) {
+  for (const auto &item : flag_map) {
     MS_LOG(WARNING) << "Unknown flag: " << item.first;
   }
   if (!flag_map.empty()) {
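A standalone sketch of the register-and-warn pattern FlagRegister implements (simplified and renamed; only the AddFlag signature above comes from this diff): each AddFlag call consumes its entry from the flag map, so whatever remains afterwards triggers the "Unknown flag" warning in Refresh().

#include <map>
#include <sstream>
#include <string>

class FlagRegisterSketch {
 public:
  explicit FlagRegisterSketch(std::map<std::string, std::string> *flag_map) : flag_map_(flag_map) {}

  template <typename T>
  void AddFlag(const std::string &flag_name, T *flag_var, T default_value = T()) const {
    const auto iter = flag_map_->find(flag_name);
    if (iter == flag_map_->end()) {
      *flag_var = default_value;
      return;
    }
    std::istringstream iss(iter->second);
    iss >> *flag_var;              // the real ParseValue also checks that the conversion succeeded
    (void)flag_map_->erase(iter);  // consumed entries are removed; leftovers are reported as unknown
  }

 private:
  std::map<std::string, std::string> *flag_map_;
};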

View File

@@ -50,6 +50,7 @@ class GraphKernelFlags {
   GraphKernelFlags(const GraphKernelFlags &flags) = delete;
   GraphKernelFlags(GraphKernelFlags &&flags) = delete;
   GraphKernelFlags &operator=(const GraphKernelFlags &flags) = delete;
+  GraphKernelFlags &operator=(GraphKernelFlags &&flags) = delete;
   ~GraphKernelFlags() = default;
   /**

View File

@@ -35,7 +35,7 @@ class OpRegister {
 #define OP_REGISTER(name, cls) \
   static_assert(std::is_base_of<PrimOp, cls>::value, " should be base of PrimOp"); \
   static const OpRegister UNIQUE_NAME(g_graphkernel_op, __COUNTER__)( \
-    name, [](const std::string &op) -> PrimOpPtr { return std::make_shared<cls>(op); })
+    name, [](const std::string &op) noexcept -> PrimOpPtr { return std::make_shared<cls>(op); })
 } // namespace
 /* All nodes supported by GraphKernel are listed below. */

View File

@@ -44,6 +44,7 @@ class OpRegistry {
   OpRegistry(const OpRegistry &) = delete;
   OpRegistry(const OpRegistry &&) = delete;
   OpRegistry &operator=(const OpRegistry &) = delete;
+  OpRegistry &operator=(const OpRegistry &&) = delete;
   mindspore::HashMap<std::string, CreatorFunc> creators;
 };
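The added line completes the set of deleted special members, so OpRegistry can be neither copied nor moved. A standalone sketch of the intent, with an illustrative name rather than the MindSpore class:

#include <functional>
#include <string>
#include <unordered_map>

// A registry used only through one instance: copying or moving it would duplicate
// the creator table, so all four copy/move operations are deleted.
class RegistrySketch {
 public:
  static RegistrySketch &Instance() {
    static RegistrySketch instance;
    return instance;
  }
  RegistrySketch(const RegistrySketch &) = delete;
  RegistrySketch(RegistrySketch &&) = delete;
  RegistrySketch &operator=(const RegistrySketch &) = delete;
  RegistrySketch &operator=(RegistrySketch &&) = delete;

 private:
  RegistrySketch() = default;
  std::unordered_map<std::string, std::function<int(const std::string &)>> creators_;
};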

View File

@@ -18,14 +18,12 @@ from mindspore._extends.graph_kernel.model import model_builder as builder
 from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
-class Expander:
+class Expander(metaclass=ABCMeta):
     """
     Expander is the base class of expanders.
     The method `_expand` should be overridden to implement the operator detail.
     """
-    __metaclass__ = ABCMeta
     def __init__(self, expand_info):
         self.name = expand_info["name"]
         self.inputs = expand_info["input_desc"]

View File

@@ -18,16 +18,6 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
 from ._utils import Expander, ExpanderInfoValidator as VLD
-M_ALIGN = 32
-N_ALIGN = 32
-K_ALIGN = 16
-K_LIMIT = 800
-MNK_LIMIT = 3 * (10 ** 10)
-N0_CHANNEL_ALIGN = 32
-N1_CHANNEL_ALIGN = 32
-C_CHANNEL_ALIGN = 16
-OUT_NHW_ALIGN = 128
 @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
 @VLD.add_format(DF.NHWC, DF.NHWC)
@@ -45,6 +35,15 @@ class Conv2D(Expander):
     C channel of inputs > 8.
     output N*H*W are multiplies of 128.
     """
+    M_ALIGN = 32
+    N_ALIGN = 32
+    K_ALIGN = 16
+    K_LIMIT = 800
+    MNK_LIMIT = 3 * (10 ** 10)
+    N0_CHANNEL_ALIGN = 32
+    N1_CHANNEL_ALIGN = 32
+    C_CHANNEL_ALIGN = 16
+    OUT_NHW_ALIGN = 128
     def __init__(self, expand_info):
         super().__init__(expand_info)

View File

@@ -53,6 +53,7 @@ class MatMul(Expander):
         self.transpose_b = self.attrs['transpose_b']
         self.left_format = self.attrs['left_format']
         self.right_format = self.attrs['right_format']
+
         def transpose(shape):
             trans_shape = list(shape)
             trans_shape[-2] = shape[-1]

View File

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ===========================================================================
 """Cost model for parallel fusion"""
+from __future__ import division
 from .model import PrimLib

View File

@@ -801,7 +801,7 @@ class GraphSplitByPattern:
             while stack:
                 op = stack.pop()
                 if len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST or len(ops) > max_weight:
-                    return None
+                    return []
                 ops.append(op)
                 for t in op.inputs:
                     if t.op in area.ops:
@@ -1153,7 +1153,7 @@ class GraphSplitGpu(GraphSplitByPattern):
         def _broadcast_onehot(dom, fwd=True):
             """Fuse rule for OneHot."""
             if dom.dom_op().prim != "OneHot":
-                return None
+                return []
             fused = []
             neighbours = dom.in_relations.items() if fwd else dom.out_relations.items()
@@ -1168,7 +1168,7 @@ class GraphSplitGpu(GraphSplitByPattern):
         def _elemwise_elemany(dom):
             """Fuse rule for elemany."""
             if dom.dom_op().prim != "ElemAny":
-                return None
+                return []
             fused = []
             for a, r in dom.in_relations.items():

View File

@@ -120,6 +120,23 @@ class CompositeGraph:
         self.desc = None
         self.tensors = {} # name : Tensor
+
+    @staticmethod
+    def add_stitch_info(subgraph, desc):
+        """add stitch info to desc"""
+        if subgraph.stitch_info and subgraph.stitch_info.stitch_ops:
+            buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)}
+            if subgraph.stitch_info.stitch_atomic_ops:
+                buffer_stitch['stitch_atomic_op'] = list(subgraph.stitch_info.stitch_atomic_ops)
+            desc['buffer_stitch'] = buffer_stitch
+        return desc
+
+    @staticmethod
+    def add_recompute_ops(subgraph, desc):
+        """add recompute ops to desc"""
+        if subgraph.recompute_ops:
+            desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops]
+        return desc
     def refine(self):
         """Refine Graph"""
         AlignShape().visit_graph(self.graph)
@@ -150,7 +167,8 @@ class CompositeGraph:
             self.tensors[name] = builder.tensor(
                 shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
         for op in desc['op_desc']:
-            inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
+            inputs = [self.tensors.get(d['tensor_name'], None) for x in op['input_desc']
+                      for d in x if 'value' not in d]
             out_desc = op['output_desc']
             name, shape, dtype, data_format = out_desc[0]['tensor_name'], out_desc[
                 0]['shape'], out_desc[0]['data_type'], out_desc[0]['format']
@@ -168,21 +186,6 @@ class CompositeGraph:
         self.graph = builder.get()[0]
         self.desc = desc
-
-    def add_stitch_info(self, subgraph, desc):
-        """add stitch info to desc"""
-        if subgraph.stitch_info and subgraph.stitch_info.stitch_ops:
-            buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)}
-            if subgraph.stitch_info.stitch_atomic_ops:
-                buffer_stitch['stitch_atomic_op'] = list(subgraph.stitch_info.stitch_atomic_ops)
-            desc['buffer_stitch'] = buffer_stitch
-        return desc
-
-    def add_recompute_ops(self, subgraph, desc):
-        """add recompute ops to desc"""
-        if subgraph.recompute_ops:
-            desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops]
-        return desc
     def _pre_dump(self, outputs):
         """restore name to before load"""
         inplace_assign = {} # y_name, output_name
@@ -205,7 +208,7 @@ class CompositeGraph:
         def dump_output(t):
             if t.name in inplace_assign:
-                z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
+                z = inplace_assign_z if inplace_assign_z is not None else self.tensors.get(t.name, None)
                 return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign.get(t.name)}
             return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}

View File

@@ -55,8 +55,9 @@ def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode):
     """Dump split info as text"""
     if not flags.get("dump_as_text", False):
         return
-    utils.create_dir(utils.GRAPH_KERNEL_DUMP_PATH)
-    filename = os.path.join(utils.GRAPH_KERNEL_DUMP_PATH, "graph_kernel_split_mode.txt")
+    graph_kernel_dump_path = "graph_kernel_dump"
+    utils.create_dir(graph_kernel_dump_path)
+    filename = os.path.join(graph_kernel_dump_path, "graph_kernel_split_mode.txt")
     with os.fdopen(os.open(filename, os.O_WRONLY | os.O_CREAT), "a+") as f:
         f.write("********** main graph: {} **********\n".format(graph_desc.name))
         f.write("input json:\n{}\n".format(graph_json))

View File

@@ -15,8 +15,6 @@
 """GraphKernel utils"""
 import os
-GRAPH_KERNEL_DUMP_PATH = "graph_kernel_dump"
 def create_dir(pathname):
     """Try to create directory"""

View File

@@ -130,6 +130,7 @@ class AkgBuilder():
     def __init__(self, platform):
         self.platform = platform
         self.attrs = None
+        self.akg_processor = None
     def create(self, process_num, waitime):
         """ Create akg processor"""