!40777 use c++ split on Ascend

Merge pull request !40777 from looop5/split_ascend_commit
This commit is contained in:
i-robot 2022-08-29 01:58:47 +00:00 committed by Gitee
commit 63c2d7b9e3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 199 additions and 6 deletions

View File

@ -370,8 +370,7 @@ class CostModelSplitSchemer : public SplitSchemer {
};
std::shared_ptr<SplitSchemer> GraphKernelSplitterWithPy::GetSplitSchema(const std::string &processor) {
// default use c++ split model for CPU target.
if (processor != kCPUDevice) {
if (processor != kCPUDevice && processor != kAscendDevice) {
MS_LOG(DEBUG) << "use py split model";
return std::make_shared<CostModelSplitSchemer>();
} else {

View File

@ -0,0 +1,148 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_kernel/adapter/split_model_ascend.h"
#include <memory>
#include <string>
#include "utils/ms_context.h"
namespace mindspore::graphkernel::inner {
namespace ascend {
constexpr size_t kReduceFusionDepth = 10;
constexpr size_t kBroadcastFusionDepth = 6;
// Backward fusion for REDUCE-dominant areas: absorbs injective users that are
// no more complex than BROADCAST, gated by the "reduce_output_fuse" attribute.
class FuseReduceBwd : public FusePattern {
 public:
  FuseReduceBwd() : FusePattern("reduce_bwd") { direction_ = FuseDirection::BACKWARD; }
  ~FuseReduceBwd() = default;

 protected:
  // Candidate areas must be alive and reduce-patterned.
  bool Check(const AreaPtr &dom) override { return dom->IsAlive() && dom->pattern() == NodePattern::REDUCE; }
  // Collect every injective, broadcast-or-simpler user whose fusion would not
  // introduce a cycle in the area graph.
  bool Match(const AreaPtr &dom) override {
    auto attrs = dom->dom()->attrs();
    if (attrs.find("reduce_output_fuse") == attrs.end()) {
      return false;
    }
    for (auto &[user, relation] : dom->users_with_relation()) {
      const bool fusible_user = (relation == EdgeRelation::INJECTIVE) && (user->pattern() <= NodePattern::BROADCAST);
      if (fusible_user && !HasCircle(dom, user)) {
        (void)fused_areas_.emplace_back(user);
      }
    }
    return !fused_areas_.empty();
  }
};
// Backward fusion for MatMul/BatchMatMul areas with their consumers.
// Fix: direction_ was re-assigned on every Match() call; it is a fixed property
// of this pattern, so it is now set once in the constructor, consistent with
// the sibling FuseReduceBwd pattern.
class FuseMatMul : public FusePattern {
 public:
  FuseMatMul() : FusePattern("matmul_depth") { direction_ = FuseDirection::BACKWARD; }
  ~FuseMatMul() = default;

 protected:
  // Candidate areas must be alive and dominated by MatMul or BatchMatMul.
  bool Check(const AreaPtr &dom) override {
    return dom->IsAlive() && (dom->dom()->op() == kMatMulOpName || dom->dom()->op() == kBatchMatMulOpName);
  }
  // Collect fusible alive users:
  //  - MatMul: AddN / TensorAdd / Cast consumers;
  //  - BatchMatMul: any elementwise consumer.
  // A user is skipped when fusing it would create a cycle in the area graph.
  bool Match(const AreaPtr &dom) override {
    auto dom_name = dom->dom()->op();
    for (auto &a : dom->users()) {
      if (!a->IsAlive()) {
        continue;
      }
      auto user_name = a->dom()->op();
      const bool fusible =
        (dom_name == kMatMulOpName &&
         (user_name == kAddNOpName || user_name == kTensorAddOpName || user_name == kCastOpName)) ||
        (dom_name == kBatchMatMulOpName && a->pattern() == NodePattern::ELEMWISE);
      if (fusible && !HasCircle(dom, a)) {
        (void)fused_areas_.emplace_back(a);
      }
    }
    return !fused_areas_.empty();
  }
};
// Forward fusion for TransData areas: pulls supported producer areas into the
// TransData area when no cycle would be created.
class FuseTransdata : public FusePattern {
 public:
  FuseTransdata() : FusePattern("transdata") {}
  ~FuseTransdata() = default;

 protected:
  // Candidate areas must be alive and dominated by a TransData op.
  bool Check(const AreaPtr &dom) override { return dom->IsAlive() && dom->dom()->op() == kTransDataOpName; }
  // Collect every alive input area that passes Supported() and whose fusion
  // keeps the area graph acyclic.
  bool Match(const AreaPtr &dom) override {
    for (auto &a : dom->inputs()) {
      if (a->IsAlive() && Supported(dom, a) && !HasCircle(a, dom)) {
        (void)fused_areas_.emplace_back(a);
      }
    }
    return !fused_areas_.empty();
  }

 private:
  // Returns true when the transdata would require padding, i.e. when either
  // shape has rank < 2 or the two innermost dims are not both exactly 16
  // (presumably the FRAC_NZ block size on Ascend — confirm against backend docs).
  bool NeedPad(const DShape &in_shape, const DShape &out_shape) const {
    const size_t min_rank = 2;
    const int64_t block_sz = 16;
    return !(in_shape.size() >= min_rank && out_shape.size() >= min_rank &&
    in_shape[in_shape.size() - kIndex1] == block_sz && in_shape[in_shape.size() - kIndex2] == block_sz &&
    out_shape[out_shape.size() - kIndex1] == block_sz && out_shape[out_shape.size() - kIndex2] == block_sz);
  }
  // Decides whether producer area `a` may be fused into transdata area `dom`:
  //  - dom must be a single-op area with inputs and need no padding;
  //  - a MatMul producer is always accepted;
  //  - otherwise the producer must be broadcast-or-simpler, and the
  //    src/dst format attrs must describe either FRAC_NZ -> DEFAULT/NCHW, or
  //    DEFAULT/NCHW -> FRAC_NZ with `a` being a single non-output Cast.
  bool Supported(const AreaPtr &dom, const AreaPtr &a) const {
    if (dom->size() != 1 || dom->dom()->inputs().empty() || NeedPad(dom->dom()->input(0)->shape, dom->dom()->shape)) {
      return false;
    }
    if (a->dom()->op() == kMatMulOpName) {
      return true;
    }
    if (a->pattern() > NodePattern::BROADCAST) {
      return false;
    }
    auto op_attrs = dom->dom()->attrs();
    if (op_attrs.find(kAttrSrcFormat) == op_attrs.end() || op_attrs.find(kAttrDstFormat) == op_attrs.end()) {
      MS_LOG(ERROR) << "For '" << dom->dom()->op() << "', can not find the attr '" << kAttrSrcFormat << "' or '"
      << kAttrDstFormat << "'";
      return false;
    }
    auto src_format = GetValue<std::string>(op_attrs[kAttrSrcFormat]);
    auto dst_format = GetValue<std::string>(op_attrs[kAttrDstFormat]);
    if (src_format == kOpFormat_FRAC_NZ && (dst_format == kOpFormat_DEFAULT || dst_format == kOpFormat_NCHW)) {
      return true;
    }
    return (src_format == kOpFormat_DEFAULT || src_format == kOpFormat_NCHW) && dst_format == kOpFormat_FRAC_NZ &&
    a->size() == 1 && a->dom()->op() == kCastOpName && !a->is_output();
  }
};
} // namespace ascend
// Registers the Ascend fusion patterns. NOTE: registration order matters —
// patterns are tried in the order they are added, so do not reorder casually.
void SplitModelAscend::InitFusePatterns() {
  // Structural patterns first: virtual nodes and reshapes.
  AddPattern(std::make_shared<FuseVirtualNode>(), true);
  AddPattern(std::make_shared<FuseReshape>(), true);
  // Forward elementwise/broadcast fusion (depth then width matching).
  AddPattern(FuseElemwiseBroadcastFwd::CreateDepthMatcher(), true);
  AddPattern(FuseElemwiseBroadcastFwd::CreateWidthMatcher(), true);
  // Forward reduce fusion, bounded by the Ascend-specific fusion depth.
  AddPattern(FuseReduceFwd::CreateDepthMatcher(inner::ascend::kReduceFusionDepth), true);
  AddPattern(FuseReduceFwd::CreateWidthMatcher(inner::ascend::kReduceFusionDepth), true);
  // Backward elementwise/broadcast fusion, bounded by the broadcast depth.
  AddPattern(FuseElemwiseBroadcastBwd::CreateDepthMatcher(inner::ascend::kBroadcastFusionDepth), true);
  AddPattern(FuseElemwiseBroadcastBwd::CreateWidthMatcher(inner::ascend::kBroadcastFusionDepth), true);
  // Ascend-specific patterns defined in this file.
  AddPattern(std::make_shared<inner::ascend::FuseMatMul>(), true);
  AddPattern(std::make_shared<inner::ascend::FuseReduceBwd>(), true);
  AddPattern(std::make_shared<inner::ascend::FuseTransdata>(), true);
}
// Reshape ops stay in their own BASIC area; every other node starts in a
// COMPOSITE area that fusion patterns may later grow.
AreaMode SplitModelAscend::GetDefaultAreaMode(const PrimOpPtr &node) const {
  const bool is_reshape = (node != nullptr) && (node->op() == kReshapeOpName);
  return is_reshape ? AreaMode::BASIC : AreaMode::COMPOSITE;
}
SPLIT_MODEL_REGISTER(kAscendDevice, SplitModelAscend);
} // namespace mindspore::graphkernel::inner

View File

@ -0,0 +1,31 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_ADAPTER_SPLIT_MODEL_ASCEND_H_
#define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_ADAPTER_SPLIT_MODEL_ASCEND_H_
#include "common/graph_kernel/split_model/split_model_factory.h"
namespace mindspore::graphkernel::inner {
// Split model for the Ascend backend: supplies Ascend-specific fusion
// patterns and default area modes to the graph-kernel splitter.
class SplitModelAscend : public SplitModel {
 public:
  SplitModelAscend() = default;
  virtual ~SplitModelAscend() = default;

 protected:
  // Reshape nodes get a BASIC area; all others default to COMPOSITE.
  AreaMode GetDefaultAreaMode(const PrimOpPtr &node) const override;
  // Registers the ordered list of Ascend fusion patterns.
  void InitFusePatterns() override;
};
} // namespace mindspore::graphkernel::inner
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_ADAPTER_SPLIT_MODEL_ASCEND_H_

View File

@ -20,9 +20,10 @@
#include <string>
#include "common/graph_kernel/split_model/split_model.h"
#include "utils/hash_map.h"
#include "include/common/visible.h"
namespace mindspore::graphkernel::inner {
class SplitModelFactory {
class COMMON_EXPORT SplitModelFactory {
public:
static SplitModelFactory &Instance() {
static SplitModelFactory instance = SplitModelFactory();

View File

@ -230,7 +230,7 @@ void GetFusionScopeComputeNodeList(const session::KernelGraph *kernel_graph,
while (iter != buffer_fusion_infos->end()) {
if (graphkernel::GraphKernelSupported(iter->second.anf_nodes)) {
MS_LOG(DEBUG) << "Fusion id: " << iter->first << ", uses Graph Kernel Fusion";
buffer_fusion_infos->erase(iter++);
iter = buffer_fusion_infos->erase(iter);
} else {
iter++;
}

View File

@ -89,6 +89,20 @@ def create_compile_dirs(compile_dirs):
def select_best(src_dirs, dst_dir, op_name):
"""Select best compile result."""
def _copy_file(src_path, dst_path):
    """Best-effort copy of src_path to dst_path.

    Any existing destination file is removed first (removal errors are
    ignored), and a PermissionError from the copy itself is swallowed.
    """
    try:
        # Remove a stale destination so the copy does not fail on it.
        if os.path.isfile(dst_path):
            os.remove(dst_path)
    except OSError:
        pass
    try:
        shutil.copy(src_path, dst_path)
    except PermissionError:
        # If dst_path already exists and only has READ permission
        pass
max_block_dim = 1
max_block_dim_idx = -1
for i, src_dir in enumerate(src_dirs):
@ -104,8 +118,8 @@ def select_best(src_dirs, dst_dir, op_name):
if max_block_dim_idx >= 0:
o_path = os.path.join(src_dirs[max_block_dim_idx], op_name + ".o")
json_path = os.path.join(src_dirs[max_block_dim_idx], op_name + ".json")
shutil.copy(o_path, dst_dir)
shutil.copy(json_path, dst_dir)
_copy_file(o_path, os.path.join(dst_dir, op_name + ".o"))
_copy_file(json_path, os.path.join(dst_dir, op_name + ".json"))
logger.info("{}, best compile result dir: {}".format(op_name, src_dirs[max_block_dim_idx]))
return True
logger.info("{}, best compile result dir not found".format(op_name))