forked from mindspore-Ecosystem/mindspore
!40777 use c++ split on Ascend
Merge pull request !40777 from looop5/split_ascend_commit
This commit is contained in:
commit
63c2d7b9e3
|
@ -370,8 +370,7 @@ class CostModelSplitSchemer : public SplitSchemer {
|
|||
};
|
||||
|
||||
std::shared_ptr<SplitSchemer> GraphKernelSplitterWithPy::GetSplitSchema(const std::string &processor) {
|
||||
// default use c++ split model for CPU target.
|
||||
if (processor != kCPUDevice) {
|
||||
if (processor != kCPUDevice && processor != kAscendDevice) {
|
||||
MS_LOG(DEBUG) << "use py split model";
|
||||
return std::make_shared<CostModelSplitSchemer>();
|
||||
} else {
|
||||
|
|
|
@ -0,0 +1,148 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/graph_kernel/adapter/split_model_ascend.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore::graphkernel::inner {
|
||||
namespace ascend {
|
||||
// Search-depth limits handed to the FuseReduceFwd / FuseElemwiseBroadcastBwd
// matchers in SplitModelAscend::InitFusePatterns below.
constexpr size_t kReduceFusionDepth = 10;
constexpr size_t kBroadcastFusionDepth = 6;
|
||||
|
||||
class FuseReduceBwd : public FusePattern {
|
||||
public:
|
||||
FuseReduceBwd() : FusePattern("reduce_bwd") { direction_ = FuseDirection::BACKWARD; }
|
||||
~FuseReduceBwd() = default;
|
||||
|
||||
protected:
|
||||
bool Check(const AreaPtr &dom) override { return dom->IsAlive() && dom->pattern() == NodePattern::REDUCE; }
|
||||
bool Match(const AreaPtr &dom) override {
|
||||
auto op_attrs = dom->dom()->attrs();
|
||||
if (op_attrs.find("reduce_output_fuse") == op_attrs.end()) {
|
||||
return false;
|
||||
}
|
||||
for (auto &[a, r] : dom->users_with_relation()) {
|
||||
if (a->pattern() <= NodePattern::BROADCAST && r == EdgeRelation::INJECTIVE && !HasCircle(dom, a)) {
|
||||
(void)fused_areas_.emplace_back(a);
|
||||
}
|
||||
}
|
||||
return !fused_areas_.empty();
|
||||
}
|
||||
};
|
||||
|
||||
class FuseMatMul : public FusePattern {
|
||||
public:
|
||||
FuseMatMul() : FusePattern("matmul_depth") {}
|
||||
~FuseMatMul() = default;
|
||||
|
||||
protected:
|
||||
bool Check(const AreaPtr &dom) override {
|
||||
return dom->IsAlive() && (dom->dom()->op() == kMatMulOpName || dom->dom()->op() == kBatchMatMulOpName);
|
||||
}
|
||||
bool Match(const AreaPtr &dom) override {
|
||||
auto dom_name = dom->dom()->op();
|
||||
for (auto &a : dom->users()) {
|
||||
if (!a->IsAlive()) {
|
||||
continue;
|
||||
}
|
||||
auto user_name = a->dom()->op();
|
||||
if (((dom_name == kMatMulOpName &&
|
||||
(user_name == kAddNOpName || user_name == kTensorAddOpName || user_name == kCastOpName)) ||
|
||||
(dom_name == kBatchMatMulOpName && a->pattern() == NodePattern::ELEMWISE)) &&
|
||||
!HasCircle(dom, a)) {
|
||||
(void)fused_areas_.emplace_back(a);
|
||||
}
|
||||
}
|
||||
direction_ = FuseDirection::BACKWARD;
|
||||
return !fused_areas_.empty();
|
||||
}
|
||||
};
|
||||
|
||||
class FuseTransdata : public FusePattern {
|
||||
public:
|
||||
FuseTransdata() : FusePattern("transdata") {}
|
||||
~FuseTransdata() = default;
|
||||
|
||||
protected:
|
||||
bool Check(const AreaPtr &dom) override { return dom->IsAlive() && dom->dom()->op() == kTransDataOpName; }
|
||||
bool Match(const AreaPtr &dom) override {
|
||||
for (auto &a : dom->inputs()) {
|
||||
if (a->IsAlive() && Supported(dom, a) && !HasCircle(a, dom)) {
|
||||
(void)fused_areas_.emplace_back(a);
|
||||
}
|
||||
}
|
||||
return !fused_areas_.empty();
|
||||
}
|
||||
|
||||
private:
|
||||
bool NeedPad(const DShape &in_shape, const DShape &out_shape) const {
|
||||
const size_t min_rank = 2;
|
||||
const int64_t block_sz = 16;
|
||||
return !(in_shape.size() >= min_rank && out_shape.size() >= min_rank &&
|
||||
in_shape[in_shape.size() - kIndex1] == block_sz && in_shape[in_shape.size() - kIndex2] == block_sz &&
|
||||
out_shape[out_shape.size() - kIndex1] == block_sz && out_shape[out_shape.size() - kIndex2] == block_sz);
|
||||
}
|
||||
bool Supported(const AreaPtr &dom, const AreaPtr &a) const {
|
||||
if (dom->size() != 1 || dom->dom()->inputs().empty() || NeedPad(dom->dom()->input(0)->shape, dom->dom()->shape)) {
|
||||
return false;
|
||||
}
|
||||
if (a->dom()->op() == kMatMulOpName) {
|
||||
return true;
|
||||
}
|
||||
if (a->pattern() > NodePattern::BROADCAST) {
|
||||
return false;
|
||||
}
|
||||
auto op_attrs = dom->dom()->attrs();
|
||||
if (op_attrs.find(kAttrSrcFormat) == op_attrs.end() || op_attrs.find(kAttrDstFormat) == op_attrs.end()) {
|
||||
MS_LOG(ERROR) << "For '" << dom->dom()->op() << "', can not find the attr '" << kAttrSrcFormat << "' or '"
|
||||
<< kAttrDstFormat << "'";
|
||||
return false;
|
||||
}
|
||||
auto src_format = GetValue<std::string>(op_attrs[kAttrSrcFormat]);
|
||||
auto dst_format = GetValue<std::string>(op_attrs[kAttrDstFormat]);
|
||||
if (src_format == kOpFormat_FRAC_NZ && (dst_format == kOpFormat_DEFAULT || dst_format == kOpFormat_NCHW)) {
|
||||
return true;
|
||||
}
|
||||
return (src_format == kOpFormat_DEFAULT || src_format == kOpFormat_NCHW) && dst_format == kOpFormat_FRAC_NZ &&
|
||||
a->size() == 1 && a->dom()->op() == kCastOpName && !a->is_output();
|
||||
}
|
||||
};
|
||||
} // namespace ascend
|
||||
|
||||
// Register the Ascend fuse patterns. Patterns are tried in registration
// order, so the ordering below is part of the behavior: virtual-node and
// reshape fusion first, then elemwise/broadcast and reduce matchers (with
// the depth limits declared above), and the Ascend-specific matmul, reduce
// backward and transdata patterns last.
void SplitModelAscend::InitFusePatterns() {
  AddPattern(std::make_shared<FuseVirtualNode>(), true);
  AddPattern(std::make_shared<FuseReshape>(), true);
  AddPattern(FuseElemwiseBroadcastFwd::CreateDepthMatcher(), true);
  AddPattern(FuseElemwiseBroadcastFwd::CreateWidthMatcher(), true);
  AddPattern(FuseReduceFwd::CreateDepthMatcher(inner::ascend::kReduceFusionDepth), true);
  AddPattern(FuseReduceFwd::CreateWidthMatcher(inner::ascend::kReduceFusionDepth), true);
  AddPattern(FuseElemwiseBroadcastBwd::CreateDepthMatcher(inner::ascend::kBroadcastFusionDepth), true);
  AddPattern(FuseElemwiseBroadcastBwd::CreateWidthMatcher(inner::ascend::kBroadcastFusionDepth), true);
  AddPattern(std::make_shared<inner::ascend::FuseMatMul>(), true);
  AddPattern(std::make_shared<inner::ascend::FuseReduceBwd>(), true);
  AddPattern(std::make_shared<inner::ascend::FuseTransdata>(), true);
}
|
||||
|
||||
// Reshape nodes stay as basic areas; every other node defaults to composite.
AreaMode SplitModelAscend::GetDefaultAreaMode(const PrimOpPtr &node) const {
  const bool is_reshape = node != nullptr && node->op() == kReshapeOpName;
  return is_reshape ? AreaMode::BASIC : AreaMode::COMPOSITE;
}
|
||||
|
||||
SPLIT_MODEL_REGISTER(kAscendDevice, SplitModelAscend);
|
||||
} // namespace mindspore::graphkernel::inner
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_ADAPTER_SPLIT_MODEL_ASCEND_H_
|
||||
#define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_ADAPTER_SPLIT_MODEL_ASCEND_H_
|
||||
|
||||
#include "common/graph_kernel/split_model/split_model_factory.h"
|
||||
namespace mindspore::graphkernel::inner {
|
||||
// Ascend-specific split model: supplies the fuse-pattern set and default
// area mode used when splitting graph kernels on the Ascend backend.
// Registered via SPLIT_MODEL_REGISTER(kAscendDevice, SplitModelAscend)
// in the corresponding .cc file.
class SplitModelAscend : public SplitModel {
 public:
  SplitModelAscend() = default;
  virtual ~SplitModelAscend() = default;

 protected:
  // Returns BASIC for Reshape nodes, COMPOSITE otherwise (see .cc).
  AreaMode GetDefaultAreaMode(const PrimOpPtr &node) const override;
  // Registers the Ascend fuse patterns in priority order (see .cc).
  void InitFusePatterns() override;
};
|
||||
} // namespace mindspore::graphkernel::inner
|
||||
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_ADAPTER_SPLIT_MODEL_ASCEND_H_
|
|
@ -20,9 +20,10 @@
|
|||
#include <string>
|
||||
#include "common/graph_kernel/split_model/split_model.h"
|
||||
#include "utils/hash_map.h"
|
||||
#include "include/common/visible.h"
|
||||
|
||||
namespace mindspore::graphkernel::inner {
|
||||
class SplitModelFactory {
|
||||
class COMMON_EXPORT SplitModelFactory {
|
||||
public:
|
||||
static SplitModelFactory &Instance() {
|
||||
static SplitModelFactory instance = SplitModelFactory();
|
||||
|
|
|
@ -230,7 +230,7 @@ void GetFusionScopeComputeNodeList(const session::KernelGraph *kernel_graph,
|
|||
while (iter != buffer_fusion_infos->end()) {
|
||||
if (graphkernel::GraphKernelSupported(iter->second.anf_nodes)) {
|
||||
MS_LOG(DEBUG) << "Fusion id: " << iter->first << ", uses Graph Kernel Fusion";
|
||||
buffer_fusion_infos->erase(iter++);
|
||||
iter = buffer_fusion_infos->erase(iter);
|
||||
} else {
|
||||
iter++;
|
||||
}
|
||||
|
|
|
@ -89,6 +89,20 @@ def create_compile_dirs(compile_dirs):
|
|||
|
||||
def select_best(src_dirs, dst_dir, op_name):
|
||||
"""Select best compile result."""
|
||||
|
||||
def _copy_file(src_path, dst_path):
|
||||
try:
|
||||
if os.path.isfile(dst_path):
|
||||
os.remove(dst_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
try:
|
||||
shutil.copy(src_path, dst_path)
|
||||
except PermissionError:
|
||||
# If dst_path already exits and only has READ permission
|
||||
pass
|
||||
|
||||
max_block_dim = 1
|
||||
max_block_dim_idx = -1
|
||||
for i, src_dir in enumerate(src_dirs):
|
||||
|
@ -104,8 +118,8 @@ def select_best(src_dirs, dst_dir, op_name):
|
|||
if max_block_dim_idx >= 0:
|
||||
o_path = os.path.join(src_dirs[max_block_dim_idx], op_name + ".o")
|
||||
json_path = os.path.join(src_dirs[max_block_dim_idx], op_name + ".json")
|
||||
shutil.copy(o_path, dst_dir)
|
||||
shutil.copy(json_path, dst_dir)
|
||||
_copy_file(o_path, os.path.join(dst_dir, op_name + ".o"))
|
||||
_copy_file(json_path, os.path.join(dst_dir, op_name + ".json"))
|
||||
logger.info("{}, best compile result dir: {}".format(op_name, src_dirs[max_block_dim_idx]))
|
||||
return True
|
||||
logger.info("{}, best compile result dir not found".format(op_name))
|
||||
|
|
Loading…
Reference in New Issue