forked from mindspore-Ecosystem/mindspore
refactor dynamic& static op select
This commit is contained in:
parent
6ec539f79e
commit
eabbdb39f3
File diff suppressed because one or more lines are too long
|
@ -787,6 +787,7 @@ constexpr auto kOpFormat_ND = "ND";
|
|||
constexpr auto kOpFormat_NCHW = "NCHW";
|
||||
constexpr auto kOpFormat_NHWC = "NHWC";
|
||||
constexpr auto kOpFormat_HWCN = "HWCN";
|
||||
constexpr auto kOpFormat_CHWN = "CHWN";
|
||||
constexpr auto kOpFormat_NC1HWC0 = "NC1HWC0";
|
||||
constexpr auto kOpFormat_FRAC_Z = "FRACTAL_Z";
|
||||
constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z";
|
||||
|
@ -820,6 +821,7 @@ COMMON_EXPORT bool IsOneOfDynamicShapeConstInputToAttrGPU(const std::string &nam
|
|||
COMMON_EXPORT bool IsOneOfComputeDepend(const std::string &name);
|
||||
COMMON_EXPORT bool IsOneOfHWSpecialFormat(const std::string &format);
|
||||
COMMON_EXPORT bool IsOneOfFormat(const std::string &format);
|
||||
COMMON_EXPORT bool IsOneOfServerFormatC04(const std::string &format);
|
||||
|
||||
// The map between kernel's output and input ref relationship.
|
||||
// Key is the output index while the value is input index which will be used as the reference of output.
|
||||
|
|
|
@ -680,7 +680,7 @@ std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
|
|||
auto axis_attr = primitive->GetAttr(kAxis);
|
||||
if (axis_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "This node doesn't have axis attr. Node info [" << cnode->DebugString() << "]";
|
||||
return std::vector<int64_t>();
|
||||
return {};
|
||||
}
|
||||
std::vector<int64_t> axis_list;
|
||||
if (axis_attr->isa<Int64Imm>()) {
|
||||
|
|
|
@ -99,12 +99,7 @@ enum FusionType {
|
|||
UNKNOWN_FUSION_TYPE = -1,
|
||||
};
|
||||
|
||||
enum OpPattern {
|
||||
kCommonPattern = 0,
|
||||
kFormatAgnosticPattern = 1,
|
||||
kBroadcastPattern = 2,
|
||||
kReducePattern = 3,
|
||||
};
|
||||
enum OpPattern { kCommonPattern = 0, kFormatAgnosticPattern, kBroadcastPattern, kReducePattern, kDynamicFormatPattern };
|
||||
|
||||
// Backend processor
|
||||
enum Processor {
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OP_INFO_CONSTANTS_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_OPLIB_OP_INFO_CONSTANTS_H_
|
||||
|
||||
namespace mindspore::kernel::refactor {
|
||||
// op info common
|
||||
// dynamicShapeSupport
|
||||
constexpr char kDynamicShapeSupport[] = "dynamicShapeSupport";
|
||||
constexpr char kFlag[] = "flag";
|
||||
constexpr char kTrue[] = "true";
|
||||
constexpr char kFalse[] = "false";
|
||||
// precision_reduce
|
||||
constexpr char kPrecisionReduce[] = "precision_reduce";
|
||||
// opFile
|
||||
constexpr char kOpFile[] = "opFile";
|
||||
constexpr char kValue[] = "value";
|
||||
// opInterface
|
||||
constexpr char kOpInterface[] = "opInterface";
|
||||
|
||||
// attr key
|
||||
constexpr char kAttr[] = "attr";
|
||||
constexpr char kList[] = "list";
|
||||
// attr item key
|
||||
constexpr char kDefaultValue[] = "defaultValue";
|
||||
constexpr char kParamType[] = "paramType";
|
||||
constexpr char kType[] = "type";
|
||||
constexpr char kValue[] = "value";
|
||||
// attr item value
|
||||
constexpr char kInt[] = "int";
|
||||
constexpr char kFloat[] = "float";
|
||||
constexpr char kBool[] = "bool";
|
||||
constexpr char kStr[] = "str";
|
||||
constexpr char kListInt[] = "listInt";
|
||||
constexpr char kListFloat[] = "listFloat";
|
||||
constexpr char kListBool[] = "listBool";
|
||||
constexpr char kListListInt[] = "listListInt";
|
||||
constexpr char kRequired[] = "required";
|
||||
|
||||
// input/output
|
||||
constexpr char kInputSuffix[] = "input";
|
||||
constexpr char kDType[] = "dtype";
|
||||
constexpr char kName[] = "name";
|
||||
constexpr char kUnknownShapeFormat[] = "unknownshape_format";
|
||||
constexpr char kFormat[] = "format";
|
||||
} // namespace mindspore::kernel::refactor
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_OPLIB_OP_INFO_CONSTANTS_H_
|
|
@ -68,6 +68,7 @@ class OpIOInfo {
|
|||
const std::string &shape() const { return shape_; }
|
||||
const std::vector<std::string> &dtypes() const { return dtypes_; }
|
||||
const std::vector<std::string> &formats() const { return formats_; }
|
||||
const std::vector<std::string> &unknown_shape_formats() const { return unknown_shape_formats_; }
|
||||
const std::string &value_depend() const { return value_depend_; }
|
||||
|
||||
void set_index(const int index) { index_ = index; }
|
||||
|
@ -78,6 +79,9 @@ class OpIOInfo {
|
|||
void set_shape(const std::string &shape) { shape_ = shape; }
|
||||
void set_dtypes(const std::vector<std::string> &dtype) { dtypes_ = dtype; }
|
||||
void set_formats(const std::vector<std::string> &formats) { formats_ = formats; }
|
||||
void set_unknown_shape_formats(const std::vector<std::string> &unknown_shape_formats) {
|
||||
unknown_shape_formats_ = unknown_shape_formats;
|
||||
}
|
||||
void set_value_depend(const std::string &value_depend) { value_depend_ = value_depend; }
|
||||
|
||||
private:
|
||||
|
@ -89,6 +93,7 @@ class OpIOInfo {
|
|||
std::string shape_;
|
||||
std::vector<std::string> dtypes_;
|
||||
std::vector<std::string> formats_;
|
||||
std::vector<std::string> unknown_shape_formats_;
|
||||
std::string value_depend_ = kIgnored;
|
||||
};
|
||||
|
||||
|
@ -117,6 +122,7 @@ class OpInfo {
|
|||
input_to_attr_index_ = opinfo.input_to_attr_index_;
|
||||
real_input_index_ = opinfo.real_input_index_;
|
||||
need_check_supported_ = opinfo.need_check_supported();
|
||||
dynamic_rank_support_ = opinfo.dynamic_rank_support();
|
||||
is_dynamic_format_ = opinfo.is_dynamic_format();
|
||||
for (const auto &attr : opinfo.attrs_ptr()) {
|
||||
attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr));
|
||||
|
@ -141,6 +147,7 @@ class OpInfo {
|
|||
bool dynamic_compile_static() const { return dynamic_compile_static_; }
|
||||
std::string processor() const { return processor_; }
|
||||
bool need_check_supported() const { return need_check_supported_; }
|
||||
bool dynamic_rank_support() const { return dynamic_rank_support_; }
|
||||
bool is_dynamic_format() const { return is_dynamic_format_; }
|
||||
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
|
||||
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
|
||||
|
@ -162,6 +169,7 @@ class OpInfo {
|
|||
void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; }
|
||||
void set_processor(const std::string &processor) { processor_ = processor; }
|
||||
void set_need_check_supported(bool need_check_supported) { need_check_supported_ = need_check_supported; }
|
||||
void set_dynamic_rank_support(bool dynamic_rank_support) { dynamic_rank_support_ = dynamic_rank_support; }
|
||||
void set_is_dynamic_format(bool is_dynamic_format) { is_dynamic_format_ = is_dynamic_format; }
|
||||
void set_input_to_attr_index(const std::vector<size_t> &input_to_attr_index) {
|
||||
input_to_attr_index_ = input_to_attr_index;
|
||||
|
@ -197,6 +205,7 @@ class OpInfo {
|
|||
bool dynamic_shape_ = false;
|
||||
bool dynamic_compile_static_ = false;
|
||||
bool need_check_supported_ = false;
|
||||
bool dynamic_rank_support_ = false;
|
||||
bool is_dynamic_format_ = false;
|
||||
OpPattern op_pattern_ = kCommonPattern;
|
||||
std::string processor_;
|
||||
|
@ -211,5 +220,6 @@ class OpInfo {
|
|||
using OpAttrPtr = std::shared_ptr<OpAttr>;
|
||||
using OpIOInfoPtr = std::shared_ptr<OpIOInfo>;
|
||||
using OpInfoPtr = std::shared_ptr<OpInfo>;
|
||||
extern std::vector<std::string> SplitStrToVec(const std::string &input);
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_
|
||||
|
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/oplib/oplib.h"
|
||||
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <fstream>
|
||||
|
@ -39,11 +39,13 @@ constexpr auto kIsDynamicFormat = "is_dynamic_format";
|
|||
constexpr auto kDynamicFormat = "dynamicFormat";
|
||||
constexpr auto kFormatAgnostic = "formatAgnostic";
|
||||
constexpr auto kNeedCheckSupported = "need_check_supported";
|
||||
constexpr auto kDynamicRankSupport = "dynamic_rank_support";
|
||||
constexpr auto kBroadcast = "broadcast";
|
||||
constexpr auto kReduce = "reduce";
|
||||
constexpr auto kDynamicShape = "dynamic_shape";
|
||||
constexpr auto kDynamicCompileStatic = "dynamic_compile_static";
|
||||
constexpr auto kDtypeFormat = "dtype_format";
|
||||
constexpr auto kUnknownShapeFormat = "unknown_shape_format";
|
||||
constexpr auto kInputToAttrIndex = "input_to_attr_index";
|
||||
constexpr auto kRealInputIndex = "real_input_index";
|
||||
constexpr auto kAttr = "attr";
|
||||
|
@ -68,6 +70,9 @@ constexpr auto kNeedCompile = "need_compile";
|
|||
constexpr auto kShape = "shape";
|
||||
constexpr auto kProcessor = "processor";
|
||||
|
||||
static const std::map<std::string, OpImplyType> OpImplyTypeMap = {
|
||||
{kTbe, kTBE}, {kAkg, kAKG}, {kAiCPU, kAICPU}, {kCpu, kCPU}, {kGpu, kGPU}};
|
||||
|
||||
static std::string ImplTypeToStr(OpImplyType impl_type) {
|
||||
switch (impl_type) {
|
||||
case kTBE:
|
||||
|
@ -84,37 +89,52 @@ static std::string ImplTypeToStr(OpImplyType impl_type) {
|
|||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> SplitStrToVec(const std::string &input) {
|
||||
static const std::map<std::string, std::string> kSpecFormat = {{kOpFormat_NCHW, kOpFormat_DEFAULT},
|
||||
{kOpFormat_ND, kOpFormat_DEFAULT}};
|
||||
if (input.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Op select ret item is null.";
|
||||
}
|
||||
// remove blank elem
|
||||
std::string input_tmp = input;
|
||||
(void)input_tmp.erase(remove(input_tmp.begin(), input_tmp.end(), ' '), input_tmp.end());
|
||||
(void)input_tmp.append(",");
|
||||
// split
|
||||
const char sep = ',';
|
||||
std::vector<std::string> result = {};
|
||||
auto begin = 0U;
|
||||
auto end = input_tmp.find(sep);
|
||||
while (end != std::string::npos) {
|
||||
auto format = input_tmp.substr(begin, end - begin);
|
||||
auto find_iter = kSpecFormat.find(format);
|
||||
if (find_iter != kSpecFormat.end()) {
|
||||
format = find_iter->second;
|
||||
}
|
||||
(void)result.emplace_back(format);
|
||||
begin = end + 1;
|
||||
end = input_tmp.find(sep, begin);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) {
|
||||
bool ret = false;
|
||||
try {
|
||||
auto op_json = nlohmann::json::parse(json_string);
|
||||
std::string imply_type_string = op_json.at(kImplyType);
|
||||
std::string op_name = op_json.at(kOpName);
|
||||
if (imply_type_string == kTbe) {
|
||||
OpImplyType imply_type = kTBE;
|
||||
ret = DecodeOpInfo(op_json, imply_type, impl_path);
|
||||
} else if (imply_type_string == kAkg) {
|
||||
OpImplyType imply_type = kAKG;
|
||||
ret = DecodeOpInfo(op_json, imply_type, impl_path);
|
||||
} else if (imply_type_string == kAiCPU) {
|
||||
OpImplyType imply_type = kAICPU;
|
||||
ret = DecodeOpInfo(op_json, imply_type, impl_path);
|
||||
} else if (imply_type_string == kCpu) {
|
||||
OpImplyType imply_type = kCPU;
|
||||
ret = DecodeOpInfo(op_json, imply_type, impl_path);
|
||||
} else if (imply_type_string == kGpu) {
|
||||
OpImplyType imply_type = kGPU;
|
||||
ret = DecodeOpInfo(op_json, imply_type, impl_path);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Not support imply_type";
|
||||
auto find_iter = OpImplyTypeMap.find(imply_type_string);
|
||||
if (find_iter == OpImplyTypeMap.end()) {
|
||||
MS_LOG(ERROR) << "Not support imply_type, " << imply_type_string;
|
||||
return false;
|
||||
}
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "RegOp failed: op_name: " << op_name << " imply_type " << imply_type_string;
|
||||
if (!DecodeOpInfo(op_json, find_iter->second, impl_path)) {
|
||||
MS_LOG(ERROR) << "RegOp failed: op_name: " << op_json.at(kOpName) << " imply_type " << imply_type_string;
|
||||
return false;
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "get op json elements failed: " << e.what();
|
||||
}
|
||||
return ret;
|
||||
return true;
|
||||
}
|
||||
|
||||
void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
|
||||
|
@ -127,6 +147,9 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
|
|||
op_info->set_kernel_name(obj.at(kKernelName));
|
||||
op_info->set_partial_flag(obj.at(kPartialFlag));
|
||||
op_info->set_need_check_supported(obj.at(kNeedCheckSupported));
|
||||
if (obj.find(kDynamicRankSupport) != obj.end()) {
|
||||
op_info->set_dynamic_rank_support(obj.at(kDynamicRankSupport));
|
||||
}
|
||||
|
||||
if (obj.find(kDynamicShape) != obj.end()) {
|
||||
op_info->set_dynamic_shape(obj.at(kDynamicShape));
|
||||
|
@ -136,8 +159,13 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
|
|||
op_info->set_dynamic_compile_static_(obj.at(kDynamicCompileStatic));
|
||||
}
|
||||
|
||||
if (obj.find(kIsDynamicFormat) != obj.end()) {
|
||||
op_info->set_is_dynamic_format(obj.at(kIsDynamicFormat));
|
||||
auto dynamic_iter = obj.find(kIsDynamicFormat);
|
||||
if (dynamic_iter != obj.end()) {
|
||||
bool is_dynamic_format = dynamic_iter->get<bool>();
|
||||
if (is_dynamic_format) {
|
||||
op_info->set_op_pattern(kDynamicFormatPattern);
|
||||
}
|
||||
op_info->set_is_dynamic_format(is_dynamic_format);
|
||||
}
|
||||
|
||||
if (obj.find(kInputToAttrIndex) != obj.end()) {
|
||||
|
@ -158,13 +186,12 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
|
|||
if (obj.find(kOpPattern) != obj.end()) {
|
||||
std::string op_pattern = obj.at(kOpPattern);
|
||||
auto find_iter = kOpPatternMap.find(op_pattern);
|
||||
if (find_iter == kOpPatternMap.end()) {
|
||||
if (!op_pattern.empty()) {
|
||||
MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern;
|
||||
}
|
||||
op_info->set_op_pattern(kCommonPattern);
|
||||
} else {
|
||||
if (find_iter != kOpPatternMap.end()) {
|
||||
op_info->set_op_pattern(find_iter->second);
|
||||
} else {
|
||||
if (!op_pattern.empty()) {
|
||||
MS_LOG(WARNING) << "The error pattern: " << op_pattern;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -265,6 +292,24 @@ bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpI
|
|||
return false;
|
||||
}
|
||||
}
|
||||
if (obj.find(kUnknownShapeFormat) != obj.end()) {
|
||||
auto unknown_shape_formats_obj = obj.at(kUnknownShapeFormat);
|
||||
if (unknown_shape_formats_obj.size() != op_info->inputs_ptr().size() + op_info->outputs_ptr().size()) {
|
||||
MS_LOG(ERROR) << "If unknown shape exist, the size should be equal (input size + output size).";
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < op_info->inputs_ptr().size(); ++i) {
|
||||
auto unknown_shape_formats_str = unknown_shape_formats_obj.at(i);
|
||||
auto unknown_shape_formats = SplitStrToVec(unknown_shape_formats_str);
|
||||
op_info->inputs_ptr().at(i)->set_unknown_shape_formats(unknown_shape_formats);
|
||||
}
|
||||
for (size_t i = 0; i < op_info->outputs_ptr().size(); ++i) {
|
||||
auto index = i + op_info->inputs_ptr().size();
|
||||
auto unknown_shape_formats_str = unknown_shape_formats_obj.at(index);
|
||||
auto unknown_shape_formats = SplitStrToVec(unknown_shape_formats_str);
|
||||
op_info->outputs_ptr().at(i)->set_unknown_shape_formats(unknown_shape_formats);
|
||||
}
|
||||
}
|
||||
if (CheckRepetition(op_info)) {
|
||||
MS_LOG(WARNING) << "This op info has been already registered. op name: " << op_info->op_name()
|
||||
<< ", impl type: " << ImplTypeToStr(op_info->imply_type())
|
||||
|
|
|
@ -102,6 +102,12 @@ std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name,
|
|||
return op_info;
|
||||
}
|
||||
|
||||
std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto op_name = common::AnfAlgo::GetCNodeName(cnode);
|
||||
return FindOp(op_name, cnode);
|
||||
}
|
||||
|
||||
inline std::string GetPrimitiveName(const AnfNodePtr &node) {
|
||||
if (!node->isa<CNode>()) {
|
||||
return "";
|
||||
|
|
|
@ -36,6 +36,7 @@ class TbeDynamicShapeUtil {
|
|||
static bool GetDynamicShapeAttr(const AnfNodePtr &anf_node);
|
||||
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const AnfNodePtr &anf_node);
|
||||
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const CNodePtr &cnode);
|
||||
static std::shared_ptr<OpInfo> FindOp(const CNodePtr &cnode);
|
||||
static RangePair GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index, const std::string &def_format,
|
||||
const TypeId &type);
|
||||
static RangePair GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index, const std::string &def_format,
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/agnostic_selector/tbe_kernel_agnostic_selector.h"
|
||||
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
void TbeKernelAgnosticSelector::GetSupportedFormatDType(SupportFormatDType *support_format_dtype) {
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr_);
|
||||
SupportFormat support_format;
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
|
||||
auto output_num = common::AnfAlgo::GetOutputTensorNum(cnode_ptr_);
|
||||
if (input_num != 1 || output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Agnostic only support one input. input_num: " << input_num << ", output num: " << output_num
|
||||
<< ", full_name:" << cnode_ptr_->fullname_with_scope();
|
||||
}
|
||||
auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0);
|
||||
if (!IsOneOfFormat(format)) {
|
||||
MS_LOG(ERROR) << "Got the unknown format " << format;
|
||||
return;
|
||||
}
|
||||
GenerateSupportFormat(format, input_num, format, output_num, &support_format);
|
||||
GenerateSupportFormatDType(cnode_ptr_, support_format, support_format_dtype);
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_AGNOSTIC_SELECTOR_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_AGNOSTIC_SELECTOR_H_
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ir/anf.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class TbeKernelAgnosticSelector {
|
||||
public:
|
||||
explicit TbeKernelAgnosticSelector(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
|
||||
~TbeKernelAgnosticSelector() = default;
|
||||
void GetSupportedFormatDType(SupportFormatDType *support_format_dtype);
|
||||
|
||||
private:
|
||||
CNodePtr cnode_ptr_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_AGNOSTIC_SELECTOR_H_
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_selector/tbe_kernel_common_selector.h"
|
||||
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
void TbeKernelCommonSelector::GetSupportedFormatDType(SupportFormatDType *support_format_dtype) {
|
||||
MS_EXCEPTION_IF_NULL(support_format_dtype);
|
||||
auto op_info = tbe::TbeDynamicShapeUtil::FindOp(cnode_ptr_);
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
auto is_dynamic_shape = common::AnfAlgo::IsDynamicShape(cnode_ptr_);
|
||||
// TODO(jjfeing) usr format as dynamic format
|
||||
for (const auto &input : op_info->inputs_ptr()) {
|
||||
(void)support_format_dtype->input_dtypes.emplace_back(input->dtypes());
|
||||
if (is_dynamic_shape) {
|
||||
// TODO(jjfeing) (void)support_format_dtype->input_formats.emplace_back(input->unknown_shape_formats());
|
||||
(void)support_format_dtype->input_formats.emplace_back(input->formats());
|
||||
} else {
|
||||
(void)support_format_dtype->input_formats.emplace_back(input->formats());
|
||||
}
|
||||
}
|
||||
for (const auto &output : op_info->outputs_ptr()) {
|
||||
(void)support_format_dtype->output_dtypes.emplace_back(output->dtypes());
|
||||
if (is_dynamic_shape) {
|
||||
// TODO(jjfeing) (void)support_format_dtype->output_formats.emplace_back(output->unknown_shape_formats());
|
||||
(void)support_format_dtype->output_formats.emplace_back(output->formats());
|
||||
} else {
|
||||
(void)support_format_dtype->output_formats.emplace_back(output->formats());
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_COMMON_SELECTOR_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_COMMON_SELECTOR_H_
|
||||
#include <utility>
|
||||
#include "ir/anf.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class TbeKernelCommonSelector {
|
||||
public:
|
||||
explicit TbeKernelCommonSelector(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
|
||||
~TbeKernelCommonSelector() = default;
|
||||
void GetSupportedFormatDType(SupportFormatDType *support_format_dtype);
|
||||
|
||||
private:
|
||||
CNodePtr cnode_ptr_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_COMMON_SELECTOR_H_
|
|
@ -1,131 +0,0 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h"
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "base/base.h"
|
||||
#include "runtime/device/ms_device_shape_transfer.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kNcdhwShapeSize = 5;
|
||||
|
||||
bool CheckValidInputAndHiddenSize(const AnfNodePtr &node) {
|
||||
if (node->isa<Parameter>()) {
|
||||
auto param = node->cast<ParameterPtr>();
|
||||
return param->input_size() > 0 && param->hidden_size() > 0;
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
return common::AnfAlgo::HasNodeAttr(kAttrInputSize, cnode) && common::AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
|
||||
size_t real_input_num = common::AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t i = 0; i < real_input_num; i++) {
|
||||
auto format = AnfAlgo::GetInputFormat(node, i);
|
||||
if (!CheckValidInOutDeviceShape(node, i, false, format)) {
|
||||
MS_LOG(WARNING) << "TBE Host check input device shape failed, node:" << node->fullname_with_scope()
|
||||
<< ", format:" << format;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
size_t real_output_num = common::AnfAlgo::GetOutputTensorNum(node);
|
||||
for (size_t i = 0; i < real_output_num; i++) {
|
||||
auto format = AnfAlgo::GetOutputFormat(node, i);
|
||||
if (!CheckValidInOutDeviceShape(node, i, true, format)) {
|
||||
MS_LOG(WARNING) << "TBE Host check output device shape failed, node:" << node->fullname_with_scope()
|
||||
<< ", format:" << format;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
|
||||
const std::string &format) {
|
||||
auto shape = is_output ? common::AnfAlgo::GetOutputDetailShape(node, index)
|
||||
: common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index);
|
||||
std::vector<int64_t> infer_shape;
|
||||
if (shape->isa<abstract::Shape>()) {
|
||||
auto shape_ptr = shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
infer_shape = shape_ptr->shape();
|
||||
}
|
||||
if (infer_shape.empty()) {
|
||||
return infer_shape;
|
||||
}
|
||||
|
||||
if (trans::IsNeedPadding(format, infer_shape.size())) {
|
||||
auto reshape_type =
|
||||
is_output ? AnfAlgo::GetOutputReshapeType(node, index) : AnfAlgo::GetInputReshapeType(node, index);
|
||||
infer_shape = trans::PaddingShape(infer_shape, format, reshape_type, node);
|
||||
}
|
||||
|
||||
auto temp_shape = infer_shape;
|
||||
if (!IsOneOfNoPaddingFormat(format) && format != kOpFormat_FRACTAL_ZN_LSTM && infer_shape.size() < kShape4dDims &&
|
||||
!IsOneOf3DFormat(format)) {
|
||||
MS_LOG(DEBUG) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
|
||||
temp_shape = trans::PaddingShapeTo4dDefault(infer_shape, node);
|
||||
}
|
||||
if (infer_shape.size() != kNcdhwShapeSize && IsOneOf3DFormat(format)) {
|
||||
temp_shape = trans::PaddingShapeTo5dDefault(infer_shape, node);
|
||||
}
|
||||
return temp_shape;
|
||||
}
|
||||
|
||||
bool HostCheck::CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
|
||||
const std::string &format) {
|
||||
auto infer_shape = GetFinalInferShape(node, index, is_output, format);
|
||||
if (infer_shape.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::set<std::string> check_4D_format = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_FRAC_Z,
|
||||
kOpFormat_NC1HWC0, kOpFormat_C1HWNCoC0, kOpFormat_FRACTAL_Z_C04,
|
||||
kOpFormat_NC1HWC0_C04};
|
||||
std::set<std::string> check_5D_format = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
|
||||
if (check_4D_format.find(format) != check_4D_format.end()) {
|
||||
return infer_shape.size() == kShape4dDims;
|
||||
}
|
||||
if (check_5D_format.find(format) != check_5D_format.end()) {
|
||||
return infer_shape.size() == kShape5dDims;
|
||||
}
|
||||
|
||||
if (format == kOpFormat_FRAC_NZ) {
|
||||
return infer_shape.size() >= kShape2dDims ||
|
||||
(infer_shape.size() == 1 && (infer_shape[0] == 1 || (infer_shape[0] % SizeToLong(kCubeSize) == 0)));
|
||||
}
|
||||
|
||||
if (format == kOpFormat_FRACTAL_ZN_RNN) {
|
||||
return infer_shape.size() >= kShape2dDims && CheckValidInputAndHiddenSize(node);
|
||||
}
|
||||
|
||||
if (format == kOpFormat_ND_RNN_BIAS) {
|
||||
return infer_shape.size() > 0 && CheckValidInputAndHiddenSize(node);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -1,47 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_SELECT_COMMON_UTILS_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_SELECT_COMMON_UTILS_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Candidate device formats for one kernel: each inner vector is one candidate
// combination, holding a format string per input / per output tensor.
struct SupportFormat {
  std::vector<std::vector<std::string>> input_format;   // per-candidate input formats
  std::vector<std::vector<std::string>> output_format;  // per-candidate output formats
};
|
||||
using SupportFormatItem = std::vector<std::string>;
|
||||
|
||||
// Static helper that validates whether a node's inferred shapes are legal for
// the device formats chosen on its inputs/outputs.
class HostCheck {
 public:
  HostCheck() = default;
  ~HostCheck() = default;
  // Returns true if every input and output device shape of `node` is valid.
  static bool CheckValidDeviceShape(const AnfNodePtr &node);

 private:
  // Validates one input (is_output == false) or output (is_output == true) at `index`
  // against `format`.
  static bool CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
                                         const std::string &format);
  // Returns the infer shape to validate, possibly adjusted for `format`.
  static std::vector<int64_t> GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
                                                 const std::string &format);
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_SELECT_COMMON_UTILS_H_
|
|
@ -0,0 +1,82 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/dynamic_selector/tbe_kernel_dynamic_selector.h"
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/utils/json_operation_utils.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_compile.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
constexpr auto kDType1 = "dtype";
|
||||
constexpr auto kFormat1 = "format";
|
||||
constexpr auto kUnknownSahpeFormat1 = "unknownshape_format";
|
||||
constexpr auto kPrefixInput1 = "input";
|
||||
constexpr auto kPrefixOutput1 = "output";
|
||||
namespace {
|
||||
void ParseOpSelectJson(bool is_dynamic_impl, const std::string &op_select_json,
|
||||
SupportFormatDType *support_format_dtype) {
|
||||
MS_EXCEPTION_IF_NULL(support_format_dtype);
|
||||
nlohmann::json json_obj;
|
||||
if (!ParseJson(op_select_json, &json_obj)) {
|
||||
MS_LOG(EXCEPTION) << "Parse op_select_json error.";
|
||||
}
|
||||
if (!json_obj.is_object()) {
|
||||
MS_LOG(EXCEPTION) << "JsonStr is not an object, the json is:" << op_select_json;
|
||||
}
|
||||
for (const auto &item : json_obj.items()) {
|
||||
const std::string &item_name = item.key();
|
||||
bool is_input = (item_name.find(kPrefixInput1) != std::string::npos);
|
||||
bool is_output = (item_name.find(kPrefixOutput1) != std::string::npos);
|
||||
if (!is_input && !is_output) {
|
||||
MS_LOG(EXCEPTION) << "Op select ret json is error, the json is:" << op_select_json;
|
||||
}
|
||||
auto dtypes = SplitStrToVec(item.value().at(kDType1));
|
||||
std::string formats_str = item.value().at(kFormat1);
|
||||
if (is_dynamic_impl && item.value().find(kUnknownSahpeFormat1) != item.value().end()) {
|
||||
formats_str = item.value().at(kUnknownSahpeFormat1);
|
||||
}
|
||||
auto formats = SplitStrToVec(formats_str);
|
||||
if (is_input) {
|
||||
(void)support_format_dtype->input_dtypes.emplace_back(dtypes);
|
||||
(void)support_format_dtype->input_formats.emplace_back(formats);
|
||||
} else {
|
||||
(void)support_format_dtype->output_dtypes.emplace_back(dtypes);
|
||||
(void)support_format_dtype->output_formats.emplace_back(formats);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Queries TBE's op-select-format interface for the node and fills
// `support_format_dtype` with the supported dtype/format combinations.
// Raises if the compile manager returns an empty result.
void TbeKernelDynamicSelector::GetSupportedFormatDType(SupportFormatDType *support_format_dtype) {
  auto &build_manager = kernel::ascend::TbeKernelCompileManager::GetInstance();
  std::string op_format_dtype_str = build_manager.TbeOpSelectFormat(cnode_ptr_);
  if (op_format_dtype_str.empty()) {
    MS_LOG(EXCEPTION) << "Op select format failed, "
                      << "node name: " << cnode_ptr_->fullname_with_scope();
  }
  auto op_info = tbe::TbeDynamicShapeUtil::FindOp(cnode_ptr_);
  MS_EXCEPTION_IF_NULL(op_info);
  // Dynamic impl applies when the op compiles static shapes with its dynamic
  // implementation, or the node itself is dynamic-shaped.
  bool is_dynamic_impl = op_info->dynamic_compile_static() || common::AnfAlgo::IsDynamicShape(cnode_ptr_);
  ParseOpSelectJson(is_dynamic_impl, op_format_dtype_str, support_format_dtype);
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_DYNAMIC_SELECTOR_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_DYNAMIC_SELECTOR_H_
|
||||
#include <utility>
|
||||
#include "ir/anf.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Selects supported formats/dtypes for a node by querying TBE's dynamic
// op-select-format interface (as opposed to the static pattern selectors).
class TbeKernelDynamicSelector {
 public:
  explicit TbeKernelDynamicSelector(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelDynamicSelector() = default;
  // Fills `support_format_dtype` with the dtype/format combinations TBE reports.
  void GetSupportedFormatDType(SupportFormatDType *support_format_dtype);

 private:
  CNodePtr cnode_ptr_;  // node whose kernel is being selected
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_DYNAMIC_SELECTOR_H_
|
|
@ -0,0 +1,393 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/pattern_selector/tbe_kernel_broadcast_selector.h"
|
||||
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/trace_base.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
constexpr size_t kChannelN = 0;
|
||||
constexpr size_t kChannelC = 1;
|
||||
constexpr int64_t kAlignmented16 = 16;
|
||||
constexpr int64_t kDynamicInputSize = 2;
|
||||
constexpr int64_t kLastIndex = 1;
|
||||
constexpr int64_t kLastButOneIndex = 2;
|
||||
|
||||
// Builds all candidate format combinations for a broadcast-pattern node and
// converts them to dtype/format pairs.
// Note: each GetBroadcastSupport* call appends one candidate to `support_format`,
// so the call order below fixes the ordering of candidates.
void TbeKernelBroadcastSelector::GetSupportedFormatDType(SupportFormatDType *support_format_dtype) {
  SupportFormat support_format;
  // Baseline: the node's original (default) format always works.
  GetSupportOriFormat(cnode_ptr_, &support_format);
  // Cache shapes/counts, then the derived flags the checks below rely on.
  GetBroadCastNodeInfo();
  GetCheckInfo();
  GetBroadcastSupport5HD(&support_format);
  GetBroadcastSupportFracZ(&support_format);
  GetBroadcastSupportC1HWNCoC0(&support_format);
  GetBroadcastSupportFracNZ(&support_format);
  GetBroadcastSupportNDC1HWC0(&support_format);
  GetBroadcastSupportFracZ3D(&support_format);
  GenerateSupportFormatDType(cnode_ptr_, support_format, support_format_dtype);
}
|
||||
|
||||
void TbeKernelBroadcastSelector::GetBroadCastNodeInfo() {
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode_ptr_)) {
|
||||
auto dynamic_size_vec = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode_ptr_, kAttrDynInputSizes);
|
||||
if (dynamic_size_vec.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Node [" << common::AnfAlgo::GetCNodeName(cnode_ptr_)
|
||||
<< "]'s attr [dyn_input_sizes] is empty" << trace::DumpSourceLines(cnode_ptr_);
|
||||
}
|
||||
if (dynamic_size_vec[0] < kDynamicInputSize) {
|
||||
MS_LOG(EXCEPTION) << "Node [" << common::AnfAlgo::GetCNodeName(cnode_ptr_)
|
||||
<< "]'s attr [dyn_input_sizes] value less than " << kDynamicInputSize
|
||||
<< trace::DumpSourceLines(cnode_ptr_);
|
||||
}
|
||||
auto dynamic_input_shape0_ = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kIndex0);
|
||||
PadScalarShape(&dynamic_input_shape0_);
|
||||
(void)input_shapes_.emplace_back(dynamic_input_shape0_);
|
||||
input_num_ = 1;
|
||||
} else {
|
||||
input_num_ = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
|
||||
for (size_t i = 0; i < input_num_; ++i) {
|
||||
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
|
||||
PadScalarShape(&input_shape);
|
||||
(void)input_shapes_.emplace_back(input_shape);
|
||||
}
|
||||
}
|
||||
|
||||
output_num_ = common::AnfAlgo::GetOutputTensorNum(cnode_ptr_);
|
||||
for (size_t i = 0; i < output_num_; ++i) {
|
||||
auto output = common::AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
|
||||
PadScalarShape(&output);
|
||||
(void)output_shapes_.emplace_back(output);
|
||||
}
|
||||
}
|
||||
|
||||
// Caches the two properties every GetBroadcastSupport* method branches on, so
// they are computed once per node instead of once per candidate format.
void TbeKernelBroadcastSelector::GetCheckInfo() {
  is_same_shape_ = IsSameShape();
  has_scalar_input_ = HasScalarInput();
}
|
||||
|
||||
// Appends an NC1HWC0 (5HD) candidate if the broadcast node can run in that format.
// Bails out (adds nothing) as soon as any input disqualifies the format.
void TbeKernelBroadcastSelector::GetBroadcastSupport5HD(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // case1: same shape & no scalar input -> every tensor can be 5HD.
  if (is_same_shape_ && !has_scalar_input_) {
    GenerateSupportFormat(kOpFormat_NC1HWC0, input_num_, kOpFormat_NC1HWC0, output_num_, support_format);
    return;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (has_scalar_input_) {
    // case2: part no scalar input.
    // scalar input: default | no scalar input(4D, channel_c % 16 == 0): 5hd
    // scalar output: default | no scalar input: 5hd
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        (void)input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is4DShape(shape)) {
          return;
        }
        if (shape[kChannelC] % kAlignmented16 != 0) {
          return;
        }
        (void)input_support_format.emplace_back(kOpFormat_NC1HWC0);
      }
    }
  } else {
    // case3: has no scalar inputs.
    // Every input must be 4D and the C channel must not be broadcast.
    for (const auto &shape : input_shapes_) {
      if (!Is4DShape(shape)) {
        return;
      }
    }
    // Hoisted out of the loop (was executed per iteration): this check depends on
    // the whole input set, not on a single shape. Same structure as the
    // C1HWNCoC0 variant.
    if (IsInputBroadcastByChannel(kChannelC)) {
      MS_LOG(INFO) << "This node broadcast c channel.";
      return;
    }
    input_support_format.assign(input_num_, kOpFormat_NC1HWC0);
  }
  GenOutputSupportFormat(kOpFormat_NC1HWC0, &output_support_format);
  (void)support_format->input_format.emplace_back(input_support_format);
  (void)support_format->output_format.emplace_back(output_support_format);
}
|
||||
|
||||
// Appends a FRACTAL_Z candidate if the broadcast node can run in that format.
// FRACTAL_Z tiles both N and C, so both channels must be 16-aligned (scalar mix)
// or non-broadcast (no scalars).
void TbeKernelBroadcastSelector::GetBroadcastSupportFracZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_same_shape_ && !has_scalar_input_) {
    GenerateSupportFormat(kOpFormat_FRAC_Z, input_num_, kOpFormat_FRAC_Z, output_num_, support_format);
    return;
  }

  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  // Fixed: used the cached has_scalar_input_ flag instead of re-running
  // HasScalarInput() — consistent with every sibling method and avoids an O(n) rescan.
  if (has_scalar_input_) {
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        (void)input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is4DShape(shape)) {
          return;
        }
        // Both N and C must align to the 16-element cube for FRACTAL_Z.
        if (shape[kChannelN] % kAlignmented16 != 0 || shape[kChannelC] % kAlignmented16 != 0) {
          return;
        }
        (void)input_support_format.emplace_back(kOpFormat_FRAC_Z);
      }
    }
  } else {
    // case3: has no scalar inputs.
    // Every input must be 4D; neither N nor C may be broadcast.
    for (const auto &shape : input_shapes_) {
      if (!Is4DShape(shape)) {
        return;
      }
    }
    // Hoisted out of the loop (was executed per iteration): depends on the whole
    // input set, not on a single shape.
    if (IsInputBroadcastByChannel(kChannelC) || IsInputBroadcastByChannel(kChannelN)) {
      // Fixed log text: the condition checks both N and C (old message said only "c").
      MS_LOG(INFO) << "This node broadcast n || c channel.";
      return;
    }
    input_support_format.assign(input_num_, kOpFormat_FRAC_Z);
  }
  GenOutputSupportFormat(kOpFormat_FRAC_Z, &output_support_format);
  (void)support_format->input_format.emplace_back(input_support_format);
  (void)support_format->output_format.emplace_back(output_support_format);
}
|
||||
|
||||
// Appends a C1HWNCoC0 candidate if the broadcast node can run in that format.
// Bails out (adds nothing) as soon as any input disqualifies the format.
void TbeKernelBroadcastSelector::GetBroadcastSupportC1HWNCoC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // Same non-scalar shapes: every tensor can take the format.
  if (is_same_shape_ && !has_scalar_input_) {
    GenerateSupportFormat(kOpFormat_C1HWNCoC0, input_num_, kOpFormat_C1HWNCoC0, output_num_, support_format);
    return;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (has_scalar_input_) {
    // Scalar inputs stay in default format; non-scalar inputs must be 4D with
    // a 16-aligned N channel.
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        (void)input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is4DShape(shape)) {
          return;
        }
        if (shape[kChannelN] % kAlignmented16 != 0) {
          return;
        }
        (void)input_support_format.emplace_back(kOpFormat_C1HWNCoC0);
      }
    }
  } else {
    // No scalars: all inputs must be 4D and neither N nor C may be broadcast.
    for (const auto &shape : input_shapes_) {
      if (!Is4DShape(shape)) {
        return;
      }
    }
    if (IsInputBroadcastByChannel(kChannelC) || IsInputBroadcastByChannel(kChannelN)) {
      MS_LOG(INFO) << "This node broadcast n || c channel.";
      return;
    }
    input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0);
  }
  GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format);
  (void)support_format->input_format.emplace_back(input_support_format);
  (void)support_format->output_format.emplace_back(output_support_format);
}
|
||||
|
||||
// Appends a FRACTAL_NZ candidate if the broadcast node can run in that format.
// FRACTAL_NZ tiles the last two axes, so the checks below operate on the last
// and last-but-one dimensions rather than fixed channel positions.
void TbeKernelBroadcastSelector::GetBroadcastSupportFracNZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_same_shape_ && !has_scalar_input_) {
    GenerateSupportFormat(kOpFormat_FRAC_NZ, input_num_, kOpFormat_FRAC_NZ, output_num_, support_format);
    return;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (has_scalar_input_) {
    // Scalar inputs stay default; non-scalar inputs need rank >= 2 with the last
    // two dims 16-aligned.
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        (void)input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (shape.size() < kShape2dDims) {
          return;
        }
        if (shape[shape.size() - kLastIndex] % kAlignmented16 != 0 ||
            shape[shape.size() - kLastButOneIndex] % kAlignmented16 != 0) {
          return;
        }
        input_support_format.emplace_back(kOpFormat_FRAC_NZ);
      }
    }
  } else {
    // No scalars: every input needs rank >= 2 and neither of the last two axes
    // may be broadcast.
    auto less_2dims = std::any_of(input_shapes_.begin(), input_shapes_.end(),
                                  [](const ShapeVector &elem) { return elem.size() < kShape2dDims; });
    if (less_2dims) {
      MS_LOG(INFO) << "This node dim less 2.";
      return;
    }

    // Axis positions are taken from the first input's rank; by this point all
    // inputs have rank >= 2.
    if (IsInputBroadcastByChannel(input_shapes_.front().size() - kLastIndex) ||
        IsInputBroadcastByChannel(input_shapes_.front().size() - kLastButOneIndex)) {
      MS_LOG(INFO) << "This node broadcast last channel.";
      return;
    }

    input_support_format.assign(input_num_, kOpFormat_FRAC_NZ);
  }
  GenOutputSupportFormat(kOpFormat_FRAC_NZ, &output_support_format);
  (void)support_format->input_format.emplace_back(input_support_format);
  (void)support_format->output_format.emplace_back(output_support_format);
}
|
||||
|
||||
// Appends an NDC1HWC0 (5D, depthwise) candidate if the broadcast node can run
// in that format. All participating shapes must be rank 5.
void TbeKernelBroadcastSelector::GetBroadcastSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_same_shape_ && !has_scalar_input_) {
    // All shapes are identical, so checking the first one's rank covers them all.
    if (input_shapes_.front().size() < kShape5dDims) {
      return;
    }
    GenerateSupportFormat(kOpFormat_NDC1HWC0, input_num_, kOpFormat_NDC1HWC0, output_num_, support_format);
    return;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (has_scalar_input_) {
    // Scalar inputs stay default; non-scalar inputs must be 5D with 16-aligned C.
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is5DShape(shape)) {
          return;
        }
        if (shape[kChannelC] % kAlignmented16 != 0) {
          return;
        }
        (void)input_support_format.emplace_back(kOpFormat_NDC1HWC0);
      }
    }
  } else {
    // No scalars: all inputs must be 5D and C must not be broadcast.
    for (const auto &shape : input_shapes_) {
      if (!Is5DShape(shape)) {
        return;
      }
    }
    if (IsInputBroadcastByChannel(kChannelC)) {
      MS_LOG(INFO) << "This node broadcast c channel.";
      return;
    }
    input_support_format.assign(input_num_, kOpFormat_NDC1HWC0);
  }
  GenOutputSupportFormat(kOpFormat_NDC1HWC0, &output_support_format);
  (void)support_format->input_format.emplace_back(input_support_format);
  (void)support_format->output_format.emplace_back(output_support_format);
}
|
||||
|
||||
// Appends a FRACTAL_Z_3D candidate if the broadcast node can run in that format.
// Like FRACTAL_Z but for rank-5 shapes; requires 16-aligned N and C (scalar mix)
// or non-broadcast N and C (no scalars).
void TbeKernelBroadcastSelector::GetBroadcastSupportFracZ3D(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_same_shape_ && !has_scalar_input_) {
    // All shapes identical, so the first one's rank covers them all.
    if (input_shapes_.front().size() < kShape5dDims) {
      return;
    }
    GenerateSupportFormat(kOpFormat_FRACTAL_Z_3D, input_num_, kOpFormat_FRACTAL_Z_3D, output_num_, support_format);
    return;
  }
  SupportFormatItem input_support_format;
  SupportFormatItem output_support_format;
  if (has_scalar_input_) {
    for (const auto &shape : input_shapes_) {
      if (IsScalarShape(shape)) {
        input_support_format.emplace_back(kOpFormat_DEFAULT);
      } else {
        if (!Is5DShape(shape)) {
          return;
        }
        // Both C and N must align to the 16-element cube.
        if (shape[kChannelC] % kAlignmented16 != 0) {
          return;
        }
        if (shape[kChannelN] % kAlignmented16 != 0) {
          return;
        }
        (void)input_support_format.emplace_back(kOpFormat_FRACTAL_Z_3D);
      }
    }
  } else {
    // No scalars: all inputs must be 5D and neither C nor N may be broadcast.
    for (const auto &shape : input_shapes_) {
      if (!Is5DShape(shape)) {
        return;
      }
    }
    if (IsInputBroadcastByChannel(kChannelC) || IsInputBroadcastByChannel(kChannelN)) {
      MS_LOG(INFO) << "This node broadcast c channel.";
      return;
    }
    input_support_format.assign(input_num_, kOpFormat_FRACTAL_Z_3D);
  }
  GenOutputSupportFormat(kOpFormat_FRACTAL_Z_3D, &output_support_format);
  (void)support_format->input_format.emplace_back(input_support_format);
  (void)support_format->output_format.emplace_back(output_support_format);
}
|
||||
|
||||
// True only for rank-4 (NCHW-like) shapes.
bool TbeKernelBroadcastSelector::Is4DShape(const ShapeVector &shape) {
  return shape.size() == kShape4dDims;
}
|
||||
|
||||
// True only for rank-5 (NCDHW-like) shapes.
bool TbeKernelBroadcastSelector::Is5DShape(const ShapeVector &shape) {
  return shape.size() == kShape5dDims;
}
|
||||
|
||||
// Returns true if every input shape is identical to the first input shape.
// Dynamic dimensions are special-cased: the second clause of the condition fires
// exactly when both dims are -1, so two unknown dims are deliberately treated as
// NOT the same (their runtime values may differ).
bool TbeKernelBroadcastSelector::IsSameShape() const {
  auto shape = input_shapes_.begin();
  for (const auto &item : input_shapes_) {
    if (shape->size() != item.size()) {
      return false;
    }
    for (size_t i = 0; i < shape->size(); ++i) {
      if (shape->at(i) != item.at(i) || (shape->at(i) == -1 && item.at(i) == -1)) {
        return false;
      }
    }
  }
  // Vacuously true when input_shapes_ is empty (loop body never runs).
  return true;
}
|
||||
|
||||
bool TbeKernelBroadcastSelector::IsScalarShape(const ShapeVector &shape) {
|
||||
return (shape.size() == 1 && shape[0] == 1);
|
||||
}
|
||||
|
||||
bool TbeKernelBroadcastSelector::HasScalarInput() const {
|
||||
return std::any_of(input_shapes_.begin(), input_shapes_.end(),
|
||||
[&](const auto &shape) { return this->IsScalarShape(shape); });
|
||||
}
|
||||
|
||||
// Returns true if dimension `channel` differs between the first input and any
// other input (i.e. that axis is broadcast), or if that dim is -1 (dynamic) in
// both shapes — two unknown dims are conservatively treated as broadcast.
// A shape too short to have `channel` contributes 0 for the comparison.
// NOTE(review): assumes input_shapes_ holds at least one entry (front() and
// begin() + 1 are used unconditionally) — callers populate it in
// GetBroadCastNodeInfo first; confirm no path reaches here with zero inputs.
bool TbeKernelBroadcastSelector::IsInputBroadcastByChannel(size_t channel) const {
  ShapeValueDType left = 0;
  if (channel < input_shapes_.front().size()) {
    left = input_shapes_.front()[channel];
  }
  return std::any_of(input_shapes_.begin() + 1, input_shapes_.end(), [&left, &channel](const ShapeVector &elem) {
    ShapeValueDType right = 0;
    if (channel < elem.size()) {
      right = elem[channel];
    }
    return (left != right || (left == -1 && right == -1));
  });
}
|
||||
|
||||
void TbeKernelBroadcastSelector::GenOutputSupportFormat(const std::string &support_format,
|
||||
SupportFormatItem *output_support_item) const {
|
||||
MS_EXCEPTION_IF_NULL(output_support_item);
|
||||
for (const auto &shape : output_shapes_) {
|
||||
if (IsScalarShape(shape)) {
|
||||
(void)output_support_item->emplace_back(kOpFormat_DEFAULT);
|
||||
} else {
|
||||
(void)output_support_item->emplace_back(support_format);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_PATTERN_BROADCAST_SELECTOR_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_PATTERN_BROADCAST_SELECTOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "ir/anf.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Static pattern selector for broadcast-style ops: enumerates the hardware
// formats (5HD, FRACTAL_Z, C1HWNCoC0, FRACTAL_NZ, NDC1HWC0, FRACTAL_Z_3D) the
// node's input/output shapes permit.
class TbeKernelBroadcastSelector {
 public:
  explicit TbeKernelBroadcastSelector(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelBroadcastSelector() = default;
  // Entry point: fills `support_format_dtype` with every supported combination.
  void GetSupportedFormatDType(SupportFormatDType *support_format_dtype);

 private:
  // Caches shapes and input/output counts from the node.
  void GetBroadCastNodeInfo();
  // Caches is_same_shape_ / has_scalar_input_.
  void GetCheckInfo();
  // Each method appends at most one candidate format combination.
  void GetBroadcastSupport5HD(SupportFormat *support_format) const;
  void GetBroadcastSupportFracZ(SupportFormat *support_format) const;
  void GetBroadcastSupportC1HWNCoC0(SupportFormat *support_format) const;
  void GetBroadcastSupportFracNZ(SupportFormat *support_format) const;
  void GetBroadcastSupportNDC1HWC0(SupportFormat *support_format) const;
  void GetBroadcastSupportFracZ3D(SupportFormat *support_format) const;
  [[nodiscard]] inline bool IsSameShape() const;
  [[nodiscard]] static inline bool Is4DShape(const ShapeVector &shape);
  [[nodiscard]] static inline bool Is5DShape(const ShapeVector &shape);
  [[nodiscard]] static inline bool IsScalarShape(const ShapeVector &shape);
  [[nodiscard]] inline bool HasScalarInput() const;
  [[nodiscard]] inline bool IsInputBroadcastByChannel(size_t channel) const;
  // Scalar outputs get the default format; others get `support_format`.
  void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const;
  // broadcast
  CNodePtr cnode_ptr_ = nullptr;
  size_t input_num_ = 0;
  size_t output_num_ = 0;
  std::vector<ShapeVector> input_shapes_{};
  std::vector<ShapeVector> output_shapes_{};
  // check info
  bool is_same_shape_ = false;
  bool has_scalar_input_ = false;
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_PATTERN_BROADCAST_SELECTOR_H_
|
|
@ -0,0 +1,295 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/pattern_selector/tbe_kernel_reduce_selector.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
#include "plugin/device/ascend/hal/hardware/ascend_utils.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "common/util/platform_info.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
constexpr int64_t kChannelN = 0;
|
||||
constexpr int64_t kChannelC = 1;
|
||||
constexpr size_t kReduceNZMinDim = 2;
|
||||
const int64_t kCubeSize = 16;
|
||||
|
||||
// Builds all candidate format combinations for a reduce-pattern node, converts
// them to dtype/format pairs, then drops combinations invalid on this platform.
// Each GetReduceSupport* call appends at most one candidate, so the call order
// fixes candidate ordering.
void TbeKernelReduceSelector::GetSupportedFormatDType(SupportFormatDType *support_format_dtype) {
  SupportFormat support_format;
  // step1: set ori support
  GetSupportOriFormat(cnode_ptr_, &support_format);
  // step2: get reduce node info (shapes, keep_dims, normalized axes, cached flags)
  GetReduceNodeInfo();
  // step3: generate support format
  GetReduceSupport5HD(&support_format);
  GetReduceSupportNDC1HWC0(&support_format);
  GetReduceSupportFracZ(&support_format);
  GetReduceSupportC1HWNCoC0(&support_format);
  GetReduceSupportFracZ3D(&support_format);
  GetReduceSupportFracNZ(&support_format);
  GenerateSupportFormatDType(cnode_ptr_, support_format, support_format_dtype);
  FilterInvalidFormatDType(support_format_dtype);
}
|
||||
|
||||
// Collects everything the reduce format checks need: padded input/output infer
// shapes, the keep_dims attr, the reduce axes normalized to non-negative
// indices, and the cached rank/axis flags.
void TbeKernelReduceSelector::GetReduceNodeInfo() {
  auto input_num = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
  auto output_num = common::AnfAlgo::GetOutputTensorNum(cnode_ptr_);
  // Unusual arity is tolerated (warning only); downstream checks use the lists as-is.
  if (input_num != 1 || output_num != 1) {
    MS_LOG(WARNING) << "Reduce operator input/output is not 1, input num: " << input_num
                    << ", output num: " << output_num << ", node info: " << cnode_ptr_->DebugString();
  }
  // get input/output shape
  for (size_t i = 0; i < input_num; ++i) {
    auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
    PadScalarShape(&shape);
    (void)input_shape_.emplace_back(shape);
  }
  for (size_t i = 0; i < output_num; ++i) {
    auto shape = common::AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
    PadScalarShape(&shape);
    (void)output_shape_.emplace_back(shape);
  }
  // get keep dim attr
  GetReduceAttrKeepDim();
  // get axis attr, then map negative axes to their positive equivalents using
  // the first input's rank.
  axis_ = GetReduceAttrAxis(cnode_ptr_);
  std::transform(axis_.begin(), axis_.end(), axis_.begin(), [&](int64_t elem) {
    if (elem < 0) {
      elem += SizeToLong(input_shape_.at(kIndex0).size());
    }
    return elem;
  });
  GetCheckInfo();
}
|
||||
|
||||
// Pre-computes the rank and reduce-axis properties consumed by the
// GetReduceSupport* checks, so each is evaluated once per node.
void TbeKernelReduceSelector::GetCheckInfo() {
  is_shape_4_dims_ = CheckOriginInputShapeDimEqual(kShape4dDims);
  is_shape_5_dims_ = CheckOriginInputShapeDimEqual(kShape5dDims);
  is_shape_less_2_dims_ = CheckOriginInputShapeDimLess(kShape2dDims);
  is_shape_less_4_dims_ = CheckOriginInputShapeDimLess(kShape4dDims);
  is_shape_less_5_dims_ = CheckOriginInputShapeDimLess(kShape5dDims);
  is_reduce_c_channel_ = CheckReduceContainChanel(kChannelC);
  is_reduce_n_channel_ = CheckReduceContainChanel(kChannelN);
}
|
||||
|
||||
// Appends an NC1HWC0 (5HD) candidate for 4D inputs. If the reduce touches the C
// channel the packed C0 axis cannot survive, so the output falls back to the
// default format while the input stays 5HD.
void TbeKernelReduceSelector::GetReduceSupport5HD(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // Note: if input and output size == 2, the last index only support 5hd(float16) --> default
  if (!is_shape_4_dims_) {
    return;
  }
  auto support_output_format = is_reduce_c_channel_ ? kOpFormat_DEFAULT : kOpFormat_NC1HWC0;
  GenerateSupportFormat(kOpFormat_NC1HWC0, input_shape_.size(), support_output_format, output_shape_.size(),
                        support_format);
}
|
||||
|
||||
// Appends an NDC1HWC0 candidate: the input must be exactly 5D and the reduce
// must not touch the packed C channel.
void TbeKernelReduceSelector::GetReduceSupportNDC1HWC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // Both disqualifiers folded into a single early return.
  if (!is_shape_5_dims_ || is_reduce_c_channel_) {
    return;
  }
  GenerateSupportFormat(kOpFormat_NDC1HWC0, input_shape_.size(), kOpFormat_NDC1HWC0, output_shape_.size(),
                        support_format);
}
|
||||
|
||||
// Appends a FRACTAL_Z candidate: needs a kept-dim reduce on a rank >= 4 shape
// that touches neither the N nor the C channel (both are tiled by the format).
void TbeKernelReduceSelector::GetReduceSupportFracZ(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // All four disqualifiers folded into a single early return.
  if (is_shape_less_4_dims_ || !keep_dims_ || is_reduce_c_channel_ || is_reduce_n_channel_) {
    return;
  }
  GenerateSupportFormat(kOpFormat_FRAC_Z, input_shape_.size(), kOpFormat_FRAC_Z, output_shape_.size(), support_format);
}
|
||||
|
||||
// Appends a C1HWNCoC0 candidate under the same constraints as FRACTAL_Z:
// kept-dim reduce, rank >= 4, and neither N nor C reduced.
void TbeKernelReduceSelector::GetReduceSupportC1HWNCoC0(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // All four disqualifiers folded into a single early return.
  if (is_shape_less_4_dims_ || !keep_dims_ || is_reduce_c_channel_ || is_reduce_n_channel_) {
    return;
  }
  GenerateSupportFormat(kOpFormat_C1HWNCoC0, input_shape_.size(), kOpFormat_C1HWNCoC0, output_shape_.size(),
                        support_format);
}
|
||||
|
||||
// Appends a FRACTAL_Z_3D candidate: kept-dim reduce on a rank >= 5 shape that
// touches neither the N nor the C channel.
void TbeKernelReduceSelector::GetReduceSupportFracZ3D(SupportFormat *support_format) const {
  MS_EXCEPTION_IF_NULL(support_format);
  // All four disqualifiers folded into a single early return.
  if (is_shape_less_5_dims_ || !keep_dims_ || is_reduce_c_channel_ || is_reduce_n_channel_) {
    return;
  }
  GenerateSupportFormat(kOpFormat_FRACTAL_Z_3D, input_shape_.size(), kOpFormat_FRACTAL_Z_3D, output_shape_.size(),
                        support_format);
}
|
||||
|
||||
void TbeKernelReduceSelector::GetReduceSupportFracNZ(SupportFormat *support_format) const {
  // Register FRAC_NZ as a candidate format for this reduce node.
  // Preconditions: origin shape has at least 2 dims and keep_dims is enabled.
  MS_EXCEPTION_IF_NULL(support_format);
  if (is_shape_less_2_dims_) {
    return;
  }
  if (!keep_dims_) {
    return;
  }
  // FRAC_NZ tiles the last two axes, so whether they take part in the
  // reduction decides the output format.
  int64_t last_channel = SizeToLong(input_shape_.at(0).size()) - 1;
  bool is_reduce_last = CheckReduceContainChanel(last_channel);
  int64_t last_channel_but_one = SizeToLong(input_shape_.at(0).size()) - 2;
  bool is_reduce_last_but_one = CheckReduceContainChanel(last_channel_but_one);
  // Neither of the last two axes is reduced: both input and output stay FRAC_NZ.
  if (!is_reduce_last && !is_reduce_last_but_one) {
    GenerateSupportFormat(kOpFormat_FRAC_NZ, input_shape_.size(), kOpFormat_FRAC_NZ, output_shape_.size(),
                          support_format);
    return;
  }
  // A reduced trailing axis must be cube-aligned (divisible by kCubeSize);
  // otherwise FRAC_NZ cannot be offered at all.
  if (last_channel >= 0 && LongToSize(last_channel) < input_shape_.at(kIndex0).size() &&
      input_shape_.at(kIndex0).at(last_channel) % kCubeSize != 0) {
    return;
  }
  if (last_channel_but_one >= 0 && LongToSize(last_channel_but_one) < input_shape_.at(kIndex0).size() &&
      input_shape_.at(kIndex0).at(last_channel_but_one) % kCubeSize != 0) {
    return;
  }
  // Reducing over a tiled axis collapses it, so the output falls back to the
  // default format while the input stays FRAC_NZ.
  GenerateSupportFormat(kOpFormat_FRAC_NZ, input_shape_.size(), kOpFormat_DEFAULT, output_shape_.size(),
                        support_format);
}
|
||||
|
||||
void TbeKernelReduceSelector::GetReduceAttrKeepDim() {
|
||||
if (!common::AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode_ptr_)) {
|
||||
MS_LOG(WARNING) << "This node doesn't have keep_attr.";
|
||||
keep_dims_ = false;
|
||||
return;
|
||||
}
|
||||
keep_dims_ = common::AnfAlgo::GetNodeAttr<bool>(cnode_ptr_, kAttrKeepDims);
|
||||
}
|
||||
|
||||
void TbeKernelReduceSelector::FilterInvalidFormatDType(SupportFormatDType *support_format_dtype) {
  // Drops (format, dtype) candidates the reduce kernel cannot run with.
  // Only the single-input / single-output case is handled; anything else is
  // left untouched.
  MS_EXCEPTION_IF_NULL(support_format_dtype);
  if (support_format_dtype->input_dtypes.size() != 1 || support_format_dtype->output_dtypes.size() != 1) {
    MS_LOG(WARNING) << "The reduce node input or output num is not 1.";
    return;
  }

  SupportDTypeItem input_dtypes = support_format_dtype->input_dtypes.at(0);
  SupportFormatItem input_formats = support_format_dtype->input_formats.at(0);
  SupportDTypeItem output_dtypes = support_format_dtype->output_dtypes.at(0);
  SupportFormatItem output_formats = support_format_dtype->output_formats.at(0);
  // The four lists are parallel arrays indexed together below; bail out on
  // inconsistent lengths instead of letting .at() throw std::out_of_range.
  if (input_formats.size() != input_dtypes.size() || output_dtypes.size() != input_dtypes.size() ||
      output_formats.size() != input_dtypes.size()) {
    MS_LOG(WARNING) << "The reduce node format/dtype lists have inconsistent sizes.";
    return;
  }

  SupportDTypeItem input_dtypes_new;
  SupportFormatItem input_formats_new;
  SupportDTypeItem output_dtypes_new;
  SupportFormatItem output_formats_new;
  for (size_t i = 0; i < input_dtypes.size(); ++i) {
    auto input_dtype = input_dtypes.at(i);
    auto input_format = input_formats.at(i);
    auto output_dtype = output_dtypes.at(i);
    auto output_format = output_formats.at(i);
    // 5HD fp16 input with a default-format output is not a supported pairing.
    if (input_format == kOpFormat_NC1HWC0 && output_format == kOpFormat_DEFAULT && input_dtype == "float16") {
      MS_LOG(WARNING) << "Input 5hd, input type fp16 and output default not supported.";
      continue;
    }
    // TODO(jjfeing)
    if (input_format == kOpFormat_NC1HWC0 && !CheckUBSizeEnable(input_format, input_dtype)) {
      MS_LOG(WARNING) << "Input 5hd, total size need to greater than block size.";
      continue;
    }
    (void)input_dtypes_new.emplace_back(input_dtype);
    (void)input_formats_new.emplace_back(input_format);
    (void)output_dtypes_new.emplace_back(output_dtype);
    (void)output_formats_new.emplace_back(output_format);
  }
  support_format_dtype->input_dtypes = {input_dtypes_new};
  support_format_dtype->input_formats = {input_formats_new};
  support_format_dtype->output_dtypes = {output_dtypes_new};
  support_format_dtype->output_formats = {output_formats_new};
}
|
||||
|
||||
bool TbeKernelReduceSelector::CheckUBSizeEnable(const std::string &input_format, const std::string &input_dtype) {
  // Intended to verify that the reduce fits the AI-core unified buffer (UB)
  // for the given format/dtype. The actual size check is not implemented yet
  // (see TODO below), so currently this only fails when the platform query
  // itself fails.
  // Empty axis means reduce-all; no per-axis UB concern applies here.
  if (axis_.empty()) {
    return true;
  }
  if (keep_dims_) {
    return true;
  }
  auto soc_version = device::ascend::GetSocVersion();
  fe::PlatformInfo platform_info;
  fe::OptionalInfo opti_compilation_info;
  fe::PlatformInfoManager &inst = fe::PlatformInfoManager::Instance();
  // Non-zero return indicates the platform info lookup failed.
  if (inst.GetPlatformInfo(soc_version, platform_info, opti_compilation_info) != 0) {
    MS_LOG(ERROR) << "GetPlatformInfo failed.";
    return false;
  }
  // TODO(jjfeing)
  // auto ub_size = platform_info.ai_core_spec.ub_size;
  // auto data_size = GetDtypeNbyte(input_dtype);
  // auto ori_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, 0);
  return true;
}
|
||||
|
||||
bool TbeKernelReduceSelector::CheckOriginInputShapeDimEqual(size_t support_dim_size) const {
  // Returns true when at least one origin input shape does NOT have exactly
  // `support_dim_size` dimensions.
  // Note: identify format not check
  for (const auto &shape : input_shape_) {
    if (shape.size() != support_dim_size) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
bool TbeKernelReduceSelector::CheckOriginInputShapeDimLess(size_t support_min_dim_size) const {
  // Returns true when at least one origin input shape has fewer than
  // `support_min_dim_size` dimensions.
  // Note: identify format not check
  for (const auto &shape : input_shape_) {
    if (shape.size() < support_min_dim_size) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
bool TbeKernelReduceSelector::CheckReduceContainChanel(int64_t channel_index) const {
  // Reports whether the reduction covers the axis at `channel_index`.
  // An empty axis list means "reduce every axis", which always covers it.
  if (axis_.empty()) {
    return true;
  }
  // An index outside the first input's rank is conservatively treated as
  // covered by the reduction.
  if (channel_index < 0 || input_shape_.at(kIndex0).size() <= LongToSize(channel_index)) {
    return true;
  }
  // Otherwise the axis is covered iff it appears in the reduce axis list.
  for (const auto &reduce_axis : axis_) {
    if (reduce_axis == channel_index) {
      return true;
    }
  }
  return false;
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_PATTERN_REDUCE_SELECTOR_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_PATTERN_REDUCE_SELECTOR_H_
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ir/anf.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Selects the (format, dtype) combinations a TBE reduce kernel supports for a
// given CNode, based on its origin shapes, reduce axes and keep_dims attr.
class TbeKernelReduceSelector {
 public:
  explicit TbeKernelReduceSelector(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
  ~TbeKernelReduceSelector() = default;
  // Entry point: fills `support_format_dtype` with every supported
  // (format, dtype) pair for the wrapped reduce node.
  void GetSupportedFormatDType(SupportFormatDType *support_format_dtype);

 private:
  // Extracts shapes, axes and attrs from the node into the members below.
  void GetReduceNodeInfo();
  // Derives the is_shape_* / is_reduce_* flags from the node info.
  void GetCheckInfo();
  // Each GetReduceSupport* method appends one candidate format to
  // `support_format` when its format-specific preconditions hold.
  void GetReduceSupport5HD(SupportFormat *support_format) const;
  void GetReduceSupportNDC1HWC0(SupportFormat *support_format) const;
  void GetReduceSupportFracZ(SupportFormat *support_format) const;
  void GetReduceSupportC1HWNCoC0(SupportFormat *support_format) const;
  void GetReduceSupportFracZ3D(SupportFormat *support_format) const;
  void GetReduceSupportFracNZ(SupportFormat *support_format) const;
  // Reads the keep_dims attribute (defaults to false when absent).
  void GetReduceAttrKeepDim();
  // Drops candidate pairs the kernel cannot actually run with.
  void FilterInvalidFormatDType(SupportFormatDType *support_format_dtype);
  // Placeholder UB-size check; see the .cc TODO for the pending size logic.
  bool CheckUBSizeEnable(const std::string &input_format, const std::string &input_dtype);
  // True if any origin input shape rank differs from `support_dim_size`.
  [[nodiscard]] inline bool CheckOriginInputShapeDimEqual(size_t support_dim_size) const;
  // True if any origin input shape rank is below `support_min_dim_size`.
  [[nodiscard]] inline bool CheckOriginInputShapeDimLess(size_t support_min_dim_size) const;
  // True if the reduction covers the axis at `channel_index`.
  [[nodiscard]] inline bool CheckReduceContainChanel(int64_t channel_index) const;

  CNodePtr cnode_ptr_;
  std::vector<ShapeVector> input_shape_{};
  std::vector<ShapeVector> output_shape_{};
  std::vector<int64_t> axis_{};
  bool keep_dims_ = false;
  // check info
  bool is_shape_4_dims_ = false;
  bool is_shape_5_dims_ = false;
  bool is_shape_less_2_dims_ = false;
  bool is_shape_less_4_dims_ = false;
  bool is_shape_less_5_dims_ = false;
  bool is_reduce_n_channel_ = false;
  bool is_reduce_c_channel_ = false;
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_PATTERN_REDUCE_SELECTOR_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,11 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/trace_base.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -56,7 +57,7 @@ bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) {
|
|||
for (size_t i = 0; i < input_num_; ++i) {
|
||||
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
|
||||
PadScalarShape(&input_shape);
|
||||
input_shapes_.emplace_back(input_shape);
|
||||
(void)input_shapes_.emplace_back(input_shape);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -64,7 +65,7 @@ bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) {
|
|||
for (size_t i = 0; i < output_num_; ++i) {
|
||||
auto output = common::AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
|
||||
PadScalarShape(&output);
|
||||
output_shapes_.emplace_back(output);
|
||||
(void)output_shapes_.emplace_back(output);
|
||||
}
|
||||
AssignSupportFormat(kOpFormat_DEFAULT, support_format);
|
||||
return true;
|
||||
|
@ -85,7 +86,7 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_fo
|
|||
if (HasScalarInput()) {
|
||||
for (const auto &shape : input_shapes_) {
|
||||
if (IsScalarShape(shape)) {
|
||||
input_support_format.emplace_back(kOpFormat_DEFAULT);
|
||||
(void)input_support_format.emplace_back(kOpFormat_DEFAULT);
|
||||
} else {
|
||||
if (!Is4DShape(shape)) {
|
||||
return false;
|
||||
|
@ -93,7 +94,7 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_fo
|
|||
if (shape[kChannelC] % kAlignmented16 != 0) {
|
||||
return false;
|
||||
}
|
||||
input_support_format.emplace_back(kOpFormat_NC1HWC0);
|
||||
(void)input_support_format.emplace_back(kOpFormat_NC1HWC0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -167,7 +168,7 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *supp
|
|||
if (HasScalarInput()) {
|
||||
for (const auto &shape : input_shapes_) {
|
||||
if (IsScalarShape(shape)) {
|
||||
input_support_format.emplace_back(kOpFormat_DEFAULT);
|
||||
(void)input_support_format.emplace_back(kOpFormat_DEFAULT);
|
||||
} else {
|
||||
if (!Is4DShape(shape)) {
|
||||
return false;
|
||||
|
@ -175,7 +176,7 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *supp
|
|||
if (shape[kChannelN] % kAlignmented16 != 0) {
|
||||
return false;
|
||||
}
|
||||
input_support_format.emplace_back(kOpFormat_C1HWNCoC0);
|
||||
(void)input_support_format.emplace_back(kOpFormat_C1HWNCoC0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -196,8 +197,8 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *supp
|
|||
input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0);
|
||||
}
|
||||
GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format);
|
||||
support_format->input_format.emplace_back(input_support_format);
|
||||
support_format->output_format.emplace_back(output_support_format);
|
||||
(void)support_format->input_format.emplace_back(input_support_format);
|
||||
(void)support_format->output_format.emplace_back(output_support_format);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -216,7 +217,7 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support
|
|||
if (HasScalarInput()) {
|
||||
for (const auto &shape : input_shapes_) {
|
||||
if (IsScalarShape(shape)) {
|
||||
input_support_format.emplace_back(kOpFormat_DEFAULT);
|
||||
(void)input_support_format.emplace_back(kOpFormat_DEFAULT);
|
||||
} else {
|
||||
if (shape.size() < kShape2dDims) {
|
||||
return false;
|
||||
|
@ -225,7 +226,7 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support
|
|||
if (shape[shape.size() - 1] % kAlignmented16 != 0 || shape[shape.size() - SECOND_LAST] % kAlignmented16 != 0) {
|
||||
return false;
|
||||
}
|
||||
input_support_format.emplace_back(kOpFormat_FRAC_NZ);
|
||||
(void)input_support_format.emplace_back(kOpFormat_FRAC_NZ);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -269,13 +270,13 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *suppo
|
|||
if (HasScalarInput()) {
|
||||
for (const auto &shape : input_shapes_) {
|
||||
if (IsScalarShape(shape)) {
|
||||
input_support_format.emplace_back(kOpFormat_NCDHW);
|
||||
(void)input_support_format.emplace_back(kOpFormat_NCDHW);
|
||||
} else if (!Is5DShape(shape)) {
|
||||
return false;
|
||||
} else if (shape[kChannelC] % kAlignmented16 != 0) {
|
||||
return false;
|
||||
} else {
|
||||
input_support_format.emplace_back(kOpFormat_NDC1HWC0);
|
||||
(void)input_support_format.emplace_back(kOpFormat_NDC1HWC0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -322,7 +323,7 @@ bool TbeKernelBroadCastSelecter::IsSameShape() const {
|
|||
void TbeKernelBroadCastSelecter::PadScalarShape(ShapeVector *shape) const {
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
if (shape->empty()) {
|
||||
shape->emplace_back(1);
|
||||
(void)shape->emplace_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -346,9 +347,9 @@ void TbeKernelBroadCastSelecter::GenOutputSupportFormat(const std::string &suppo
|
|||
MS_EXCEPTION_IF_NULL(output_support_item);
|
||||
for (const auto &shape : output_shapes_) {
|
||||
if (IsScalarShape(shape)) {
|
||||
output_support_item->emplace_back(kOpFormat_DEFAULT);
|
||||
(void)output_support_item->emplace_back(kOpFormat_DEFAULT);
|
||||
} else {
|
||||
output_support_item->emplace_back(support_format);
|
||||
(void)output_support_item->emplace_back(support_format);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,14 +14,14 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_BROADCAST_SELECTER_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_BROADCAST_SELECTER_H_
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_BROADCAST_SELECTER_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_BROADCAST_SELECTER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "ir/anf.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,12 +15,14 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,13 +14,13 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_REDUCE_SELECTER_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_KERNEL_REDUCE_SELECTER_H_
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_REDUCE_SELECTER_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_REDUCE_SELECTER_H_
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ir/anf.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class TbeKernelReduceSelecter {
|
||||
|
|
|
@ -20,6 +20,9 @@
|
|||
#include <memory>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <iterator>
|
||||
#include <algorithm>
|
||||
#include "kernel/common_utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_convert_utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
|
||||
|
@ -28,6 +31,7 @@
|
|||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_property_checker.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_selector_creator.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
|
||||
#include "include/common/utils/json_operation_utils.h"
|
||||
|
@ -41,6 +45,7 @@ constexpr auto kPrefixOutput = "output";
|
|||
constexpr char kParamTypeDynamic[] = "dynamic";
|
||||
constexpr char kParamTypeRequre[] = "required";
|
||||
constexpr char kParamTypeOptional[] = "optional";
|
||||
constexpr int64_t kDynamicInvalidNum = -1;
|
||||
|
||||
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
|
||||
auto tbe_selector = TbeKernelSelect(kernel_node, kernel_info_list);
|
||||
|
@ -152,56 +157,83 @@ void TbeKernelSelect::GetKernelHashName() {
|
|||
MS_LOG(EXCEPTION) << "Gen node hash failed. [" << cnode_ptr_->fullname_with_scope() << "]";
|
||||
}
|
||||
kernel_hash_name = json_creator->GetJsonName();
|
||||
return;
|
||||
}
|
||||
|
||||
void TbeKernelSelect::TbeMetadataInfoEx() {
|
||||
if (!check_cnode) {
|
||||
return;
|
||||
}
|
||||
|
||||
GetKernelHashName();
|
||||
node_name_ = common::AnfAlgo::GetCNodeName(cnode_ptr_);
|
||||
full_name_ = cnode_ptr_->fullname_with_scope();
|
||||
auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(node_name_, cnode_ptr_);
|
||||
MS_EXCEPTION_IF_NULL(op_info_ptr);
|
||||
// if op pattern is FormatAgnostic, can't use select cache
|
||||
bool skip_cache = (op_info_ptr->op_pattern() == kFormatAgnosticPattern);
|
||||
|
||||
auto iter = select_cache_.find(kernel_hash_name);
|
||||
if (iter != select_cache_.end() && !skip_cache) {
|
||||
for (auto &cache_info : iter->second) {
|
||||
auto builder = KernelBuildInfo::KernelBuildInfoBuilder(cache_info);
|
||||
kernel_info_list_->emplace_back(builder.Build());
|
||||
std::string new_select_process = common::GetEnv("XXX");
|
||||
if (new_select_process.empty()) {
|
||||
// Step1: Initialize, get node info
|
||||
if (!Initialize()) {
|
||||
MS_LOG(DEBUG) << "Initialize failed, node name: " << full_name_;
|
||||
return;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Select kernel cache hit " << kernel_hash_name << " for node "
|
||||
<< cnode_ptr_->fullname_with_scope();
|
||||
return;
|
||||
}
|
||||
|
||||
if (op_info_ptr->is_dynamic_format()) {
|
||||
GetDynamicFormatPatternKernelInfo(*op_info_ptr);
|
||||
// Step2: if kernel build info in cache, use cache return
|
||||
if (GetKernelBuildInfoFromCache()) {
|
||||
return;
|
||||
}
|
||||
// Step3:
|
||||
auto get_support_format_dtype_func = GetSelectorFunc(cnode_ptr_);
|
||||
if (!get_support_format_dtype_func) {
|
||||
MS_LOG(ERROR) << "Get selector func failed, node name: " << full_name_;
|
||||
return;
|
||||
}
|
||||
SupportFormatDType support_format_dtype;
|
||||
get_support_format_dtype_func(cnode_ptr_, &support_format_dtype);
|
||||
PrintSupportedFormatDtype(support_format_dtype);
|
||||
GenerateKernelBuildInfo(support_format_dtype);
|
||||
FilterInvalidKernelInfo();
|
||||
AddKernelBuildInfoToCache();
|
||||
} else {
|
||||
OpPattern pattern = op_info_ptr->op_pattern();
|
||||
if (pattern == kCommonPattern) {
|
||||
GetCommonPatternKernelInfo(*op_info_ptr);
|
||||
} else if (pattern == kFormatAgnosticPattern) {
|
||||
GetAgnosticPatternKernelInfo(*op_info_ptr);
|
||||
} else if (pattern == kBroadcastPattern) {
|
||||
GetBroadcastPatternKernelInfo(*op_info_ptr);
|
||||
} else if (pattern == kReducePattern) {
|
||||
GetReducePatternKernelInfo(*op_info_ptr);
|
||||
} else {
|
||||
MS_LOG(INFO) << "Warning: op pattern is invailed.";
|
||||
if (!check_cnode) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
// check support
|
||||
FilterInVaildKernelInfo(*op_info_ptr);
|
||||
GetKernelHashName();
|
||||
node_name_ = common::AnfAlgo::GetCNodeName(cnode_ptr_);
|
||||
full_name_ = cnode_ptr_->fullname_with_scope();
|
||||
auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(node_name_, cnode_ptr_);
|
||||
MS_EXCEPTION_IF_NULL(op_info_ptr);
|
||||
// if op pattern is FormatAgnostic, can't use select cache
|
||||
bool skip_cache = (op_info_ptr->op_pattern() == kFormatAgnosticPattern);
|
||||
|
||||
for (auto &kernel_info : *kernel_info_list_) {
|
||||
auto builder = KernelBuildInfo::KernelBuildInfoBuilder(kernel_info);
|
||||
select_cache_[kernel_hash_name].emplace_back(builder.Build());
|
||||
auto iter = select_cache_.find(kernel_hash_name);
|
||||
if (iter != select_cache_.end() && !skip_cache) {
|
||||
for (auto &cache_info : iter->second) {
|
||||
auto builder = KernelBuildInfo::KernelBuildInfoBuilder(cache_info);
|
||||
(void)kernel_info_list_->emplace_back(builder.Build());
|
||||
}
|
||||
MS_LOG(DEBUG) << "Select kernel cache hit " << kernel_hash_name << " for node "
|
||||
<< cnode_ptr_->fullname_with_scope();
|
||||
return;
|
||||
}
|
||||
|
||||
if (op_info_ptr->is_dynamic_format()) {
|
||||
GetDynamicFormatPatternKernelInfo(*op_info_ptr);
|
||||
} else {
|
||||
OpPattern pattern = op_info_ptr->op_pattern();
|
||||
if (pattern == kCommonPattern) {
|
||||
GetCommonPatternKernelInfo(*op_info_ptr);
|
||||
} else if (pattern == kFormatAgnosticPattern) {
|
||||
GetAgnosticPatternKernelInfo(*op_info_ptr);
|
||||
} else if (pattern == kBroadcastPattern) {
|
||||
GetBroadcastPatternKernelInfo(*op_info_ptr);
|
||||
} else if (pattern == kReducePattern) {
|
||||
GetReducePatternKernelInfo(*op_info_ptr);
|
||||
} else {
|
||||
MS_LOG(INFO) << "Warning: op pattern is invailed.";
|
||||
}
|
||||
}
|
||||
// check support
|
||||
FilterInvalidKernelInfo();
|
||||
AddKernelBuildInfoToCache();
|
||||
}
|
||||
}
|
||||
|
||||
void TbeKernelSelect::AddKernelBuildInfoToCache() {
  // Stores the current selection result under the node's hash. An empty
  // result is cached too (as an empty entry), but only non-empty results are
  // logged.
  if (!kernel_info_list_->empty()) {
    select_cache_[kernel_hash_name] = *kernel_info_list_;
    MS_LOG(INFO) << "Add select kernel cache " << kernel_hash_name << " from node " << cnode_ptr_->fullname_with_scope()
                 << ", cache size: " << select_cache_.size();
    return;
  }
  select_cache_[kernel_hash_name] = {};
}
|
||||
|
@ -248,7 +280,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
|
|||
builder.SetOutputsDeviceType(outputs_device_type);
|
||||
builder.SetOutputsFormat(outputs_format);
|
||||
builder.SetOutputsReshapeType(outputs_reshape_type);
|
||||
kernel_info_list_->emplace_back(builder.Build());
|
||||
(void)kernel_info_list_->emplace_back(builder.Build());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -272,8 +304,8 @@ void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) {
|
|||
SupportFormatItem output_item;
|
||||
input_item.assign(op_info.inputs_ptr().size(), format);
|
||||
output_item.assign(op_info.outputs_ptr().size(), format);
|
||||
support_format.input_format.emplace_back(input_item);
|
||||
support_format.output_format.emplace_back(output_item);
|
||||
(void)support_format.input_format.emplace_back(input_item);
|
||||
(void)support_format.output_format.emplace_back(output_item);
|
||||
OpInfo op_info_new;
|
||||
CreateNewOpInfo(op_info, support_format, &op_info_new);
|
||||
GetCommonPatternKernelInfo(op_info_new);
|
||||
|
@ -306,32 +338,31 @@ void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
|
|||
GetCommonPatternKernelInfo(op_info_new);
|
||||
}
|
||||
|
||||
void TbeKernelSelect::FilterInVaildKernelInfo(const OpInfo &op_info) {
|
||||
void TbeKernelSelect::FilterInvalidKernelInfo() {
|
||||
if (kernel_info_list_->empty()) {
|
||||
MS_LOG(INFO) << "Warning: get kernel build info failed. Skip check supported. Op name: " << full_name_;
|
||||
return;
|
||||
}
|
||||
std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
|
||||
auto dynamic_inputs = GetNodeDynamicInputs();
|
||||
auto need_check_supported = op_info.need_check_supported();
|
||||
for (auto iter = kernel_info_list_->begin(); iter != kernel_info_list_->end(); ++iter) {
|
||||
if (!FilterInVaildShape(iter, !dynamic_inputs.empty())) {
|
||||
for (const auto &kernel_build_info : *kernel_info_list_) {
|
||||
if (!FilterInvalidShape(kernel_build_info, !dynamic_inputs.empty())) {
|
||||
continue;
|
||||
}
|
||||
if (need_check_supported && !TbeCheckSupported(iter)) {
|
||||
if (!TbeCheckSupported(kernel_build_info)) {
|
||||
continue;
|
||||
}
|
||||
kernel_info_list.emplace_back(*iter);
|
||||
(void)kernel_info_list.emplace_back(kernel_build_info);
|
||||
}
|
||||
if (kernel_info_list.empty()) {
|
||||
MS_LOG(DEBUG) << "After tbe check supported, all valid AI CORE kernel infos were filtered out. Node:" << full_name_;
|
||||
}
|
||||
(*kernel_info_list_) = kernel_info_list;
|
||||
(*kernel_info_list_).swap(kernel_info_list);
|
||||
}
|
||||
|
||||
bool TbeKernelSelect::FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input) {
|
||||
MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
|
||||
const auto &kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
|
||||
bool TbeKernelSelect::FilterInvalidShape(const KernelBuildInfoPtr &kernel_build_info, bool is_dynamic_input) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info);
|
||||
const auto &kernel_build_info_inputs_format = kernel_build_info->GetAllInputFormats();
|
||||
// dynamic input just need to check first input, because other inputs copy from 1th input;
|
||||
auto iter_num =
|
||||
is_dynamic_input && !kernel_build_info_inputs_format.empty() ? 1 : kernel_build_info_inputs_format.size();
|
||||
|
@ -342,7 +373,7 @@ bool TbeKernelSelect::FilterInVaildShape(const KernelBuildInfoIter &kernel_build
|
|||
return false;
|
||||
}
|
||||
}
|
||||
const auto &kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats();
|
||||
const auto &kernel_build_info_outputs_format = kernel_build_info->GetAllOutputFormats();
|
||||
for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) {
|
||||
auto shape = common::AnfAlgo::GetOutputInferShape(cnode_ptr_, j);
|
||||
const auto &format = kernel_build_info_outputs_format[j];
|
||||
|
@ -354,36 +385,81 @@ bool TbeKernelSelect::FilterInVaildShape(const KernelBuildInfoIter &kernel_build
|
|||
}
|
||||
|
||||
bool TbeKernelSelect::IsShapeMatchFormat(const ShapeVector &shape, const std::string &format) {
|
||||
// if format is default, it means support all format
|
||||
if (format == kOpFormat_DEFAULT) {
|
||||
return true;
|
||||
}
|
||||
static const std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
|
||||
// if format is default, it remarkes support all format
|
||||
if (!IsOneOfFormat(format)) {
|
||||
MS_LOG(EXCEPTION) << "Got the unknown format " << format;
|
||||
}
|
||||
// server not support format with C04 suffix
|
||||
if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) !=
|
||||
kServerNotSupportFormat.end()) {
|
||||
if (IsOneOfServerFormatC04(format)) {
|
||||
MS_LOG(INFO) << "Warning: Server not support format with C04 suffix.";
|
||||
return false;
|
||||
}
|
||||
if (format == kOpFormat_FRAC_NZ && shape.size() > kShape2dDims) {
|
||||
return true;
|
||||
// kOpFormat_FRAC_NZ >=2 || %16 == 0
|
||||
if (format == kOpFormat_FRAC_NZ) {
|
||||
return (shape.size() >= kShape2dDims) ||
|
||||
(shape.size() == 1 && (shape[0] == 1 || (shape[0] % SizeToLong(kCubeSize) == 0)));
|
||||
}
|
||||
// RNN
|
||||
if (!IsShapeMatchFormatRNN(shape, format)) {
|
||||
return false;
|
||||
}
|
||||
// not support format:
|
||||
// 1 3d formats with shape size > 5
|
||||
if (IsOneOf3DFormat(format) && shape.size() > kShape5dDims) {
|
||||
// 3D formats with shape size > 5
|
||||
if (IsOneOf3DFormat(format)) {
|
||||
return shape.size() <= kShape5dDims;
|
||||
}
|
||||
// check format is valid.
|
||||
if (!IsOneOfFormat(format)) {
|
||||
MS_LOG(WARNING) << "Got the unknown format " << format;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter) {
|
||||
MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
|
||||
bool TbeKernelSelect::IsShapeMatchFormatRNN(const ShapeVector &shape, const std::string &format) {
  // Checks whether `shape` is compatible with the RNN-specific formats.
  // Any non-RNN format is accepted unconditionally (returns true).
  if (format == kOpFormat_FRACTAL_ZN_RNN || format == kOpFormat_ND_RNN_BIAS) {
    // Both RNN formats need the node to carry input_size and hidden_size attrs.
    if (!common::AnfAlgo::HasNodeAttr(kAttrInputSize, cnode_ptr_) ||
        !common::AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode_ptr_)) {
      return false;
    }
    auto input_size = common::AnfAlgo::GetNodeAttr<int64_t>(cnode_ptr_, kAttrInputSize);
    auto hidden_size = common::AnfAlgo::GetNodeAttr<int64_t>(cnode_ptr_, kAttrHiddenSize);

    // kOpFormat_FRACTAL_ZN_RNN: rank must be >= 2, and the second-to-last dim
    // must be dynamic (SHP_ANY) or equal input_size, hidden_size, or their sum.
    if (format == kOpFormat_FRACTAL_ZN_RNN) {
      if (shape.size() < kDim2) {
        return false;
      }
      auto last_but_one_dim = shape[shape.size() - kDim2];
      if ((last_but_one_dim != abstract::Shape::SHP_ANY) && (last_but_one_dim != input_size) &&
          (last_but_one_dim != hidden_size) && (last_but_one_dim != input_size + hidden_size)) {
        return false;
      }
    }
    // kOpFormat_ND_RNN_BIAS: shape must be non-empty, and the last dim must be
    // dynamic (SHP_ANY) or a multiple of hidden_size.
    if (format == kOpFormat_ND_RNN_BIAS) {
      if (shape.empty()) {
        return false;
      }
      auto last_dim = shape[shape.size() - kDim1];
      if (last_dim != abstract::Shape::SHP_ANY && last_dim % hidden_size != 0) {
        return false;
      }
    }
  }
  return true;
}
|
||||
|
||||
bool TbeKernelSelect::TbeCheckSupported(const KernelBuildInfoPtr &kernel_build_info) {
|
||||
auto op_info = tbe::TbeDynamicShapeUtil::FindOp(cnode_ptr_);
|
||||
if (!op_info->need_check_supported()) {
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info);
|
||||
// replace kernel_info with current kernel info
|
||||
auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, cnode_ptr_.get());
|
||||
auto &build_manager = kernel::ascend::TbeKernelCompileManager::GetInstance();
|
||||
auto ret =
|
||||
HostCheck::CheckValidDeviceShape(cnode_ptr_) && build_manager.TbeOpCheckSupported(cnode_ptr_, &kernel_json);
|
||||
|
@ -408,9 +484,12 @@ std::vector<int64_t> TbeKernelSelect::GetNodeDynamicInputs() {
|
|||
// get dynamic inputs
|
||||
auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode_ptr_);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
std::vector<int64_t> dyn_input_sizes;
|
||||
std::vector<int64_t> dyn_input_sizes = {};
|
||||
if (primitive->HasAttr(kAttrDynInputSizes)) {
|
||||
dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr(kAttrDynInputSizes));
|
||||
if (dyn_input_sizes.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Now dynamic input only support one in ascend.";
|
||||
}
|
||||
}
|
||||
return dyn_input_sizes;
|
||||
}
|
||||
|
@ -447,10 +526,10 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind
|
|||
}
|
||||
int64_t dynamic_input_size = dyn_input_sizes[dynamic_input_index];
|
||||
for (int64_t i = 0; i < dynamic_input_size; ++i) {
|
||||
device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
|
||||
formats->emplace_back(kernel_build_info_format);
|
||||
reshape_types->emplace_back(reshape_type);
|
||||
value_depends->emplace_back(value_depend);
|
||||
(void)device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
|
||||
(void)formats->emplace_back(kernel_build_info_format);
|
||||
(void)reshape_types->emplace_back(reshape_type);
|
||||
(void)value_depends->emplace_back(value_depend);
|
||||
}
|
||||
dynamic_input_index++;
|
||||
real_io_tensor_index = SizetAddWithOverflowCheck(real_io_tensor_index, LongToSize(dynamic_input_size));
|
||||
|
@ -459,19 +538,19 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind
|
|||
MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output.";
|
||||
}
|
||||
for (size_t i = 0; i < real_io_tensor_num; ++i) {
|
||||
device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
|
||||
formats->emplace_back(kernel_build_info_format);
|
||||
reshape_types->emplace_back(reshape_type);
|
||||
value_depends->emplace_back(value_depend);
|
||||
(void)device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
|
||||
(void)formats->emplace_back(kernel_build_info_format);
|
||||
(void)reshape_types->emplace_back(reshape_type);
|
||||
(void)value_depends->emplace_back(value_depend);
|
||||
}
|
||||
real_io_tensor_index = SizetAddWithOverflowCheck(real_io_tensor_index, real_io_tensor_num);
|
||||
}
|
||||
} else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) {
|
||||
// require or optional io
|
||||
device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
|
||||
formats->emplace_back(kernel_build_info_format);
|
||||
reshape_types->emplace_back(reshape_type);
|
||||
value_depends->emplace_back(value_depend);
|
||||
(void)device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
|
||||
(void)formats->emplace_back(kernel_build_info_format);
|
||||
(void)reshape_types->emplace_back(reshape_type);
|
||||
(void)value_depends->emplace_back(value_depend);
|
||||
real_io_tensor_index++;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type;
|
||||
|
@ -511,7 +590,7 @@ void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io
|
|||
for (const auto &formats : support_format_item) {
|
||||
auto format = formats.at(index);
|
||||
for (size_t j = 0; j < dtype.size(); ++j) {
|
||||
format_new.emplace_back(format);
|
||||
(void)format_new.emplace_back(format);
|
||||
}
|
||||
}
|
||||
op_io_info_new->set_formats(format_new);
|
||||
|
@ -537,7 +616,7 @@ std::vector<std::string> TbeKernelSelect::SplitStrToVec(const std::string &op_se
|
|||
if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) {
|
||||
obj = kDynamicFormatMap.at(obj);
|
||||
}
|
||||
ret.emplace_back(obj);
|
||||
(void)ret.emplace_back(obj);
|
||||
begin = op_select_tmp.find_first_not_of(space, sep_pos + 1);
|
||||
sep_pos = op_select_tmp.find(sep, begin);
|
||||
}
|
||||
|
@ -620,7 +699,7 @@ void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info,
|
|||
select_input.dtypes = SplitStrToVec(input_dtype_item);
|
||||
std::string input_format_item = item.value().at(kFormat);
|
||||
select_input.formats = SplitStrToVec(input_format_item);
|
||||
inputs.emplace_back(select_input);
|
||||
(void)inputs.emplace_back(select_input);
|
||||
} else {
|
||||
SelectOpIOInfo select_output;
|
||||
select_output.name = item.value().at(kName);
|
||||
|
@ -628,7 +707,7 @@ void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info,
|
|||
select_output.dtypes = SplitStrToVec(input_dtype_item);
|
||||
std::string input_format_item = item.value().at(kFormat);
|
||||
select_output.formats = SplitStrToVec(input_format_item);
|
||||
outputs.emplace_back(select_output);
|
||||
(void)outputs.emplace_back(select_output);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -671,25 +750,159 @@ void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io
|
|||
op_io_info_new->set_formats(support_format);
|
||||
}
|
||||
|
||||
void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) {
|
||||
if (support_format.input_format.size() != support_format.output_format.size()) {
|
||||
MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output("
|
||||
<< support_format.output_format.size() << ") size not match.";
|
||||
void TbeKernelSelect::PrintSupportedFormatDtype(const SupportFormatDType &support_format_dtype) {
|
||||
MS_LOG(DEBUG) << "full_name: " << full_name_;
|
||||
MS_LOG(DEBUG) << "==============input dtype=============";
|
||||
for (auto input : support_format_dtype.input_dtypes) {
|
||||
std::stringstream ss;
|
||||
copy(input.begin(), input.end(), std::ostream_iterator<std::string>(ss, ","));
|
||||
MS_LOG(DEBUG) << "[ " << ss.str() << " ]";
|
||||
}
|
||||
for (size_t i = 0; i < support_format.input_format.size(); ++i) {
|
||||
auto input_items = support_format.input_format.at(i);
|
||||
auto output_items = support_format.output_format.at(i);
|
||||
std::string print_str = "[";
|
||||
for (const auto &input : input_items) {
|
||||
(void)print_str.append(input);
|
||||
(void)print_str.append(", ");
|
||||
}
|
||||
(void)print_str.append("] -->");
|
||||
for (const auto &output : output_items) {
|
||||
(void)print_str.append(output);
|
||||
(void)print_str.append(", ");
|
||||
}
|
||||
MS_LOG(INFO) << "Support format: " << print_str;
|
||||
MS_LOG(DEBUG) << "==============input format=============";
|
||||
for (auto input : support_format_dtype.input_formats) {
|
||||
std::stringstream ss;
|
||||
copy(input.begin(), input.end(), std::ostream_iterator<std::string>(ss, ","));
|
||||
MS_LOG(DEBUG) << "[ " << ss.str() << " ]";
|
||||
}
|
||||
MS_LOG(DEBUG) << "==============output dtype=============";
|
||||
for (auto input : support_format_dtype.output_dtypes) {
|
||||
std::stringstream ss;
|
||||
copy(input.begin(), input.end(), std::ostream_iterator<std::string>(ss, ","));
|
||||
MS_LOG(DEBUG) << "[ " << ss.str() << " ]";
|
||||
}
|
||||
MS_LOG(DEBUG) << "==============output format=============";
|
||||
for (auto input : support_format_dtype.output_formats) {
|
||||
std::stringstream ss;
|
||||
copy(input.begin(), input.end(), std::ostream_iterator<std::string>(ss, ","));
|
||||
MS_LOG(DEBUG) << "[ " << ss.str() << " ]";
|
||||
}
|
||||
}
|
||||
|
||||
bool TbeKernelSelect::Initialize() {
|
||||
// Init 1.op_name, 2.full_name, 3.op_info, 4.kernel_json, 5.kernel_hash_name
|
||||
node_name_ = common::AnfAlgo::GetCNodeName(cnode_ptr_);
|
||||
full_name_ = cnode_ptr_->fullname_with_scope();
|
||||
op_info_ = tbe::TbeDynamicShapeUtil::FindOp(node_name_, cnode_ptr_);
|
||||
if (!op_info_) {
|
||||
return false;
|
||||
}
|
||||
auto json_creator = std::make_shared<SelectTbeJsonCreator>();
|
||||
MS_EXCEPTION_IF_NULL(json_creator);
|
||||
auto ret = json_creator->GenJson(cnode_ptr_, &kernel_json);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Gen node hash failed. [" << cnode_ptr_->fullname_with_scope() << "]";
|
||||
}
|
||||
kernel_hash_name = json_creator->GetJsonName();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TbeKernelSelect::GetKernelBuildInfoFromCache() {
|
||||
// Note: kFormatAgnosticPattern need select, like cast ...
|
||||
if (op_info_->op_pattern() == kFormatAgnosticPattern) {
|
||||
return false;
|
||||
}
|
||||
auto iter = select_cache_.find(kernel_hash_name);
|
||||
if (iter == select_cache_.end()) {
|
||||
return false;
|
||||
}
|
||||
for (const auto &cache_info : iter->second) {
|
||||
auto builder = KernelBuildInfo::KernelBuildInfoBuilder(cache_info);
|
||||
(void)kernel_info_list_->emplace_back(builder.Build());
|
||||
}
|
||||
MS_LOG(DEBUG) << "Select kernel cache hit " << kernel_hash_name << " for node " << cnode_ptr_->fullname_with_scope();
|
||||
return true;
|
||||
}
|
||||
|
||||
void TbeKernelSelect::GenerateKernelBuildInfo(const SupportFormatDType &support_format_dtype) {
|
||||
auto dyn_input_sizes = GetNodeDynamicInputs();
|
||||
// get real input/output num
|
||||
size_t real_input_num = common::AnfAlgo::GetInputTensorNum(cnode_ptr_);
|
||||
size_t real_output_num = common::AnfAlgo::GetOutputTensorNum(cnode_ptr_);
|
||||
auto op_info_input_num = support_format_dtype.input_dtypes.size();
|
||||
auto op_info_output_num = support_format_dtype.output_dtypes.size();
|
||||
if (op_info_output_num == 0 && op_info_input_num == 0) {
|
||||
MS_LOG(EXCEPTION) << "input and output is null, please check, " << full_name_;
|
||||
}
|
||||
|
||||
auto select_support_num = support_format_dtype.output_dtypes.at(0).size();
|
||||
auto dynamic_input_num = dyn_input_sizes.empty() ? kDynamicInvalidNum : dyn_input_sizes.at(0);
|
||||
for (size_t support_index = 0; support_index < select_support_num; ++support_index) {
|
||||
KernelBuildInfoItem input_kernel_build_info;
|
||||
KernelBuildInfoItem output_kernel_build_info;
|
||||
// size_t io_index = 0;
|
||||
// size_t real_put_index = 0;
|
||||
for (size_t io_index = 0, real_put_index = 0; real_put_index < real_input_num && io_index < op_info_input_num;) {
|
||||
auto op_io_info = op_info_->inputs_ptr().at(io_index);
|
||||
auto support_dtype = support_format_dtype.input_dtypes.at(io_index).at(support_index);
|
||||
auto support_format = support_format_dtype.input_formats.at(io_index).at(support_index);
|
||||
ConstructIOKernelBuildInfo(op_io_info, support_dtype, support_format, dynamic_input_num, &input_kernel_build_info,
|
||||
&io_index, &real_put_index);
|
||||
}
|
||||
for (size_t io_index = 0, real_put_index = 0; real_put_index < real_output_num && io_index < op_info_output_num;) {
|
||||
auto op_io_info = op_info_->outputs_ptr().at(io_index);
|
||||
auto support_dtype = support_format_dtype.output_dtypes.at(io_index).at(support_index);
|
||||
auto support_format = support_format_dtype.output_formats.at(io_index).at(support_index);
|
||||
int64_t dynamic_output_num = dynamic_input_num;
|
||||
if (op_io_info->param_type() == kParamTypeDynamic) {
|
||||
if (op_info_->outputs_ptr().size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Dynamic output num only support 1, output name: " << op_io_info->name()
|
||||
<< ", node name: " << full_name_;
|
||||
}
|
||||
dynamic_output_num = SizeToLong(real_output_num);
|
||||
}
|
||||
ConstructIOKernelBuildInfo(op_io_info, support_dtype, support_format, dynamic_output_num,
|
||||
&output_kernel_build_info, &io_index, &real_put_index);
|
||||
}
|
||||
ConstructKernelBuildInfo(input_kernel_build_info, output_kernel_build_info);
|
||||
}
|
||||
}
|
||||
|
||||
void TbeKernelSelect::ConstructKernelBuildInfo(const KernelBuildInfoItem &input_kernel_build_info,
|
||||
const KernelBuildInfoItem &output_kernel_build_info) {
|
||||
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
|
||||
builder.SetProcessor(AICORE);
|
||||
std::string fusion_name = op_info_->fusion_type();
|
||||
auto fusion_type = GetFusionTypeByName(fusion_name);
|
||||
if (fusion_type != UNKNOWN_FUSION_TYPE) {
|
||||
builder.SetFusionType(fusion_type);
|
||||
}
|
||||
builder.SetOpPattern(op_info_->op_pattern());
|
||||
builder.SetKernelType(TBE_KERNEL);
|
||||
builder.SetInputsDeviceType(input_kernel_build_info.device_types);
|
||||
builder.SetInputsFormat(input_kernel_build_info.formats);
|
||||
builder.SetInputsReshapeType(input_kernel_build_info.reshape_types);
|
||||
builder.SetInputsValueDepend(input_kernel_build_info.value_depends);
|
||||
builder.SetOutputsDeviceType(output_kernel_build_info.device_types);
|
||||
builder.SetOutputsFormat(output_kernel_build_info.formats);
|
||||
builder.SetOutputsReshapeType(output_kernel_build_info.reshape_types);
|
||||
(void)kernel_info_list_->emplace_back(builder.Build());
|
||||
}
|
||||
|
||||
void TbeKernelSelect::ConstructIOKernelBuildInfo(const OpIOInfoPtr &op_io_info, const std::string &support_dtype,
|
||||
const std::string &support_format, int64_t dynamic_num,
|
||||
KernelBuildInfoItem *kernel_build_info_item, size_t *io_index,
|
||||
size_t *real_put_index) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info_item);
|
||||
MS_EXCEPTION_IF_NULL(io_index);
|
||||
MS_EXCEPTION_IF_NULL(real_put_index);
|
||||
if (op_io_info->param_type() == kParamTypeDynamic) {
|
||||
if (dynamic_num == kDynamicInvalidNum) {
|
||||
MS_LOG(EXCEPTION) << "Get node dynamic inputs num failed, node name: " << full_name_;
|
||||
}
|
||||
for (int64_t i = 0; i < dynamic_num; ++i) {
|
||||
(void)kernel_build_info_item->formats.emplace_back(support_format);
|
||||
(void)kernel_build_info_item->device_types.emplace_back(tbe::DtypeToTypeId(support_dtype));
|
||||
(void)kernel_build_info_item->reshape_types.emplace_back(op_io_info->reshape_type());
|
||||
(void)kernel_build_info_item->value_depends.emplace_back(op_io_info->value_depend());
|
||||
}
|
||||
(*real_put_index) += LongToSize(dynamic_num);
|
||||
} else {
|
||||
(void)kernel_build_info_item->formats.emplace_back(support_format);
|
||||
(void)kernel_build_info_item->device_types.emplace_back(tbe::DtypeToTypeId(support_dtype));
|
||||
(void)kernel_build_info_item->reshape_types.emplace_back(op_io_info->reshape_type());
|
||||
(void)kernel_build_info_item->value_depends.emplace_back(op_io_info->value_depend());
|
||||
(*real_put_index) += 1;
|
||||
}
|
||||
(*io_index) += 1;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -22,17 +22,15 @@
|
|||
#include <memory>
|
||||
#include "kernel/oplib/opinfo.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace mindspore::kernel {
|
||||
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
|
||||
bool TbeCheckIsSupportedSpec(const CNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
|
||||
bool TbeCheckIsSupportedAny(const CNodePtr &kernel_node);
|
||||
|
||||
class TbeKernelSelect {
|
||||
using OpInfoPtr = std::shared_ptr<OpInfo>;
|
||||
using KernelBuildInfoIter = std::vector<std::shared_ptr<KernelBuildInfo>>::iterator;
|
||||
|
||||
public:
|
||||
TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
|
||||
|
@ -47,10 +45,11 @@ class TbeKernelSelect {
|
|||
void GetAgnosticPatternKernelInfo(const OpInfo &op_info);
|
||||
void GetBroadcastPatternKernelInfo(const OpInfo &op_info);
|
||||
void GetReducePatternKernelInfo(const OpInfo &op_info);
|
||||
void FilterInVaildKernelInfo(const OpInfo &op_info);
|
||||
bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter, bool is_dynamic_input);
|
||||
static bool IsShapeMatchFormat(const ShapeVector &shape, const std::string &format);
|
||||
bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter);
|
||||
void FilterInvalidKernelInfo();
|
||||
bool FilterInvalidShape(const KernelBuildInfoPtr &kernel_build_info, bool is_dynamic_input);
|
||||
bool IsShapeMatchFormat(const ShapeVector &shape, const std::string &format);
|
||||
bool IsShapeMatchFormatRNN(const ShapeVector &shape, const std::string &format);
|
||||
bool TbeCheckSupported(const KernelBuildInfoPtr &kernel_build_info);
|
||||
static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder);
|
||||
std::vector<int64_t> GetNodeDynamicInputs();
|
||||
bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
|
||||
|
@ -71,8 +70,6 @@ class TbeKernelSelect {
|
|||
void GetKernelHashName();
|
||||
bool CheckCNode();
|
||||
|
||||
static void PrintSupportedFormat(const SupportFormat &support_format);
|
||||
|
||||
CNodePtr cnode_ptr_;
|
||||
std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list_;
|
||||
std::string node_name_;
|
||||
|
@ -81,8 +78,23 @@ class TbeKernelSelect {
|
|||
std::string kernel_hash_name;
|
||||
bool check_cnode;
|
||||
inline static mindspore::HashMap<std::string, std::vector<std::shared_ptr<KernelBuildInfo>>> select_cache_ = {};
|
||||
|
||||
private:
|
||||
bool Initialize();
|
||||
bool GetKernelBuildInfoFromCache();
|
||||
void GenerateKernelBuildInfo(const SupportFormatDType &support_format_dtype);
|
||||
void ConstructIOKernelBuildInfo(const OpIOInfoPtr &op_io_info, const std::string &support_dtype,
|
||||
const std::string &support_format, int64_t dynamic_num,
|
||||
KernelBuildInfoItem *kernel_build_info_item, size_t *io_index,
|
||||
size_t *real_put_index) const;
|
||||
OpInfoPtr op_info_ = nullptr;
|
||||
|
||||
void ConstructKernelBuildInfo(const KernelBuildInfoItem &input_kernel_build_info,
|
||||
const KernelBuildInfoItem &output_kernel_build_info);
|
||||
void AddKernelBuildInfoToCache();
|
||||
|
||||
void PrintSupportedFormatDtype(const SupportFormatDType &support_format_dtype);
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_TBE_KERNEL_SELECT_H
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -0,0 +1,229 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "base/base.h"
|
||||
#include "runtime/device/ms_device_shape_transfer.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
namespace {
|
||||
constexpr size_t kNcdhwShapeSize = 5;
|
||||
|
||||
bool CheckValidInputAndHiddenSize(const AnfNodePtr &node) {
|
||||
if (node->isa<Parameter>()) {
|
||||
auto param = node->cast<ParameterPtr>();
|
||||
return param->input_size() > 0 && param->hidden_size() > 0;
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
return common::AnfAlgo::HasNodeAttr(kAttrInputSize, cnode) && common::AnfAlgo::HasNodeAttr(kAttrHiddenSize, cnode);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool HostCheck::CheckValidDeviceShape(const AnfNodePtr &node) {
|
||||
size_t real_input_num = common::AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t i = 0; i < real_input_num; i++) {
|
||||
auto format = AnfAlgo::GetInputFormat(node, i);
|
||||
if (!CheckValidInOutDeviceShape(node, i, false, format)) {
|
||||
MS_LOG(WARNING) << "TBE Host check input device shape failed, node:" << node->fullname_with_scope()
|
||||
<< ", format:" << format;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
size_t real_output_num = common::AnfAlgo::GetOutputTensorNum(node);
|
||||
for (size_t i = 0; i < real_output_num; i++) {
|
||||
auto format = AnfAlgo::GetOutputFormat(node, i);
|
||||
if (!CheckValidInOutDeviceShape(node, i, true, format)) {
|
||||
MS_LOG(WARNING) << "TBE Host check output device shape failed, node:" << node->fullname_with_scope()
|
||||
<< ", format:" << format;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<int64_t> HostCheck::GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
|
||||
const std::string &format) {
|
||||
auto shape = is_output ? common::AnfAlgo::GetOutputDetailShape(node, index)
|
||||
: common::AnfAlgo::GetPrevNodeOutputDetailShape(node, index);
|
||||
std::vector<int64_t> infer_shape;
|
||||
if (shape->isa<abstract::Shape>()) {
|
||||
auto shape_ptr = shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
infer_shape = shape_ptr->shape();
|
||||
}
|
||||
if (infer_shape.empty()) {
|
||||
return infer_shape;
|
||||
}
|
||||
|
||||
if (trans::IsNeedPadding(format, infer_shape.size())) {
|
||||
auto reshape_type =
|
||||
is_output ? AnfAlgo::GetOutputReshapeType(node, index) : AnfAlgo::GetInputReshapeType(node, index);
|
||||
infer_shape = trans::PaddingShape(infer_shape, format, reshape_type, node);
|
||||
}
|
||||
|
||||
auto temp_shape = infer_shape;
|
||||
if (!IsOneOfNoPaddingFormat(format) && format != kOpFormat_FRACTAL_ZN_LSTM && infer_shape.size() < kShape4dDims &&
|
||||
!IsOneOf3DFormat(format)) {
|
||||
MS_LOG(DEBUG) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
|
||||
temp_shape = trans::PaddingShapeTo4dDefault(infer_shape, node);
|
||||
}
|
||||
if (infer_shape.size() != kNcdhwShapeSize && IsOneOf3DFormat(format)) {
|
||||
temp_shape = trans::PaddingShapeTo5dDefault(infer_shape, node);
|
||||
}
|
||||
return temp_shape;
|
||||
}
|
||||
|
||||
bool HostCheck::CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
|
||||
const std::string &format) {
|
||||
auto infer_shape = GetFinalInferShape(node, index, is_output, format);
|
||||
if (infer_shape.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::set<std::string> check_4D_format = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_FRAC_Z,
|
||||
kOpFormat_NC1HWC0, kOpFormat_C1HWNCoC0, kOpFormat_FRACTAL_Z_C04,
|
||||
kOpFormat_NC1HWC0_C04};
|
||||
std::set<std::string> check_5D_format = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
|
||||
if (check_4D_format.find(format) != check_4D_format.end()) {
|
||||
return infer_shape.size() == kShape4dDims;
|
||||
}
|
||||
if (check_5D_format.find(format) != check_5D_format.end()) {
|
||||
return infer_shape.size() == kShape5dDims;
|
||||
}
|
||||
|
||||
if (format == kOpFormat_FRAC_NZ) {
|
||||
return infer_shape.size() >= kShape2dDims ||
|
||||
(infer_shape.size() == 1 && (infer_shape[0] == 1 || (infer_shape[0] % SizeToLong(kCubeSize) == 0)));
|
||||
}
|
||||
|
||||
if (format == kOpFormat_FRACTAL_ZN_RNN) {
|
||||
return infer_shape.size() >= kShape2dDims && CheckValidInputAndHiddenSize(node);
|
||||
}
|
||||
|
||||
if (format == kOpFormat_ND_RNN_BIAS) {
|
||||
return infer_shape.size() > 0 && CheckValidInputAndHiddenSize(node);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void GetSupportOriFormat(const CNodePtr &cnode, SupportFormat *support_format) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(support_format);
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||
auto output_num = common::AnfAlgo::GetOutputTensorNum(cnode);
|
||||
auto op_name = common::AnfAlgo::GetCNodeName(cnode);
|
||||
auto op_info = tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
// Note: in reduce/broadcast case: dynamic/optional input/output not care.
|
||||
if (op_info->inputs_ptr().size() != input_num || op_info->outputs_ptr().size() != output_num) {
|
||||
MS_LOG(WARNING) << "Input/output maybe have optional or dynamic io, input num:" << input_num
|
||||
<< ", output_num:" << output_num << ", op_info->inputs size: " << op_info->inputs_ptr().size()
|
||||
<< ", op_info->output size: " << op_info->outputs_ptr().size();
|
||||
}
|
||||
SupportFormatItem input_item(input_num, kOpFormat_DEFAULT);
|
||||
(void)support_format->input_format.emplace_back(input_item);
|
||||
SupportFormatItem output_item(output_num, kOpFormat_DEFAULT);
|
||||
(void)support_format->output_format.emplace_back(output_item);
|
||||
}
|
||||
|
||||
void PadScalarShape(ShapeVector *shape) {
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
if (shape->empty()) {
|
||||
(void)shape->emplace_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
void GenerateSupportFormat(const std::string &support_input_format, size_t input_num,
|
||||
const std::string &support_output_format, size_t output_num, SupportFormat *support_format) {
|
||||
SupportFormatItem input_item(input_num, support_input_format);
|
||||
(void)support_format->input_format.emplace_back(input_item);
|
||||
SupportFormatItem output_item(output_num, support_output_format);
|
||||
(void)support_format->output_format.emplace_back(output_item);
|
||||
}
|
||||
|
||||
void ConstructSupportDTypes(const std::vector<OpIOInfoPtr> &puts, size_t format_size,
|
||||
std::vector<SupportDTypeItem> *support_dtypes) {
|
||||
MS_EXCEPTION_IF_NULL(support_dtypes);
|
||||
for (const auto &put : puts) {
|
||||
MS_EXCEPTION_IF_NULL(put);
|
||||
auto dtypes = put->dtypes();
|
||||
SupportDTypeItem support_dtype_item = {};
|
||||
for (size_t i = 0; i < format_size; ++i) {
|
||||
(void)support_dtype_item.insert(support_dtype_item.cbegin(), dtypes.cbegin(), dtypes.cend());
|
||||
}
|
||||
(void)support_dtypes->emplace_back(support_dtype_item);
|
||||
}
|
||||
}
|
||||
|
||||
void ConstructSupportFormats(size_t put_size, const std::vector<SupportFormatItem> &support_format, size_t type_size,
|
||||
std::vector<SupportFormatItem> *support_formats) {
|
||||
MS_EXCEPTION_IF_NULL(support_formats);
|
||||
for (size_t i = 0; i < put_size; ++i) {
|
||||
SupportFormatItem support_format_item = {};
|
||||
for (const auto &formats : support_format) {
|
||||
for (size_t j = 0; j < type_size; ++j) {
|
||||
(void)support_format_item.emplace_back(formats.at(i));
|
||||
}
|
||||
}
|
||||
(void)support_formats->emplace_back(support_format_item);
|
||||
}
|
||||
}
|
||||
|
||||
void GenerateSupportFormatDType(const OpInfoPtr &op_info, const SupportFormat &support_format, bool is_dynamic_shape,
|
||||
SupportFormatDType *support_format_dtype) {
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
if (op_info->inputs_ptr().size() != support_format.input_format.at(0).size()) {
|
||||
MS_LOG(WARNING) << "Op name: " << op_info->op_name()
|
||||
<< " has optional input, but in the graph, this input not exist."
|
||||
<< "op info input num: " << op_info->inputs_ptr().size()
|
||||
<< "graph node input num: " << support_format.input_format.at(0).size();
|
||||
}
|
||||
// TODO(jjfeing) usr format as dynamic format
|
||||
auto type_size = op_info->outputs_ptr().at(0)->dtypes().size();
|
||||
auto format_size = support_format.input_format.size();
|
||||
auto input_size = op_info->inputs_ptr().size();
|
||||
auto output_size = op_info->outputs_ptr().size();
|
||||
ConstructSupportDTypes(op_info->inputs_ptr(), format_size, &support_format_dtype->input_dtypes);
|
||||
ConstructSupportDTypes(op_info->outputs_ptr(), format_size, &support_format_dtype->output_dtypes);
|
||||
ConstructSupportFormats(input_size, support_format.input_format, type_size, &support_format_dtype->input_formats);
|
||||
ConstructSupportFormats(output_size, support_format.output_format, type_size, &support_format_dtype->output_formats);
|
||||
if (support_format_dtype->input_dtypes.size() != op_info->inputs_ptr().size() ||
|
||||
support_format_dtype->output_dtypes.size() != op_info->outputs_ptr().size() ||
|
||||
support_format_dtype->output_dtypes.at(0).size() != support_format_dtype->output_formats.at(0).size()) {
|
||||
MS_LOG(ERROR) << "GenerateSupportFormatDType failed.";
|
||||
}
|
||||
}
|
||||
|
||||
void GenerateSupportFormatDType(const CNodePtr &cnode, const SupportFormat &support_format,
|
||||
SupportFormatDType *support_format_dtype) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(support_format_dtype);
|
||||
auto is_dynamic_shape = common::AnfAlgo::IsDynamicShape(cnode);
|
||||
auto op_info = tbe::TbeDynamicShapeUtil::FindOp(cnode);
|
||||
GenerateSupportFormatDType(op_info, support_format, is_dynamic_shape, support_format_dtype);
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_TBE_SELECT_UTILS_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_TBE_SELECT_UTILS_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "kernel/oplib/opinfo.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// The list of candidate format names (e.g. "NCHW", "NC1HWC0") for one
// input or output slot.
using SupportFormatItem = std::vector<std::string>;
// Supported formats of an operator: one SupportFormatItem per input and one
// per output.
struct SupportFormat {
  std::vector<SupportFormatItem> input_format;
  std::vector<SupportFormatItem> output_format;
};
|
||||
|
||||
// The list of candidate dtype names for one input or output slot.
using SupportDTypeItem = std::vector<std::string>;
// Combined dtype/format support table of an operator. The outer vectors are
// indexed by input/output slot; each inner vector enumerates the supported
// kernel variants for that slot (dtypes and formats are kept in parallel —
// see the size consistency checks in GenerateSupportFormatDType).
struct SupportFormatDType {
  std::vector<SupportDTypeItem> input_dtypes;
  std::vector<SupportFormatItem> input_formats;
  std::vector<SupportDTypeItem> output_dtypes;
  std::vector<SupportFormatItem> output_formats;
};
|
||||
|
||||
// Flattened kernel-build-info attributes for one side (all inputs or all
// outputs) of a kernel; the vectors are presumably parallel, one entry per
// slot — TODO(review): confirm against the code that fills this struct.
struct KernelBuildInfoItem {
  std::vector<std::string> formats;
  std::vector<TypeId> device_types;
  std::vector<std::string> reshape_types;
  std::vector<std::string> value_depends;
};
|
||||
// Static helpers that validate a node's device shapes for the selected
// formats before kernel selection.
class HostCheck {
 public:
  HostCheck() = default;
  ~HostCheck() = default;
  // Returns true when every input/output device shape of `node` is valid.
  static bool CheckValidDeviceShape(const AnfNodePtr &node);

 private:
  // Checks the shape of one input (is_output=false) or output (is_output=true)
  // at `index` under `format`.
  static bool CheckValidInOutDeviceShape(const AnfNodePtr &node, size_t index, bool is_output,
                                         const std::string &format);
  // Returns the inferred shape used for the device-shape check; presumably
  // adjusted for `format` — TODO(review): confirm in the implementation.
  static std::vector<int64_t> GetFinalInferShape(const AnfNodePtr &node, size_t index, bool is_output,
                                                 const std::string &format);
};
|
||||
|
||||
// Fills `support_format` with the node's original (host/inferred) formats.
void GetSupportOriFormat(const CNodePtr &cnode, SupportFormat *support_format);

// Pads a scalar shape so it has at least one dimension — TODO(review):
// confirm the padded rank/value in the implementation.
void PadScalarShape(ShapeVector *shape);

// Builds a SupportFormat with `input_num` copies of `support_input_format`
// and `output_num` copies of `support_output_format`.
void GenerateSupportFormat(const std::string &support_input_format, size_t input_num,
                           const std::string &support_output_format, size_t output_num, SupportFormat *support_format);

// Expands the dtypes declared in op info (`puts`) into `support_dtypes`,
// replicated `format_size` times.
void ConstructSupportDTypes(const std::vector<OpIOInfoPtr> &puts, size_t format_size,
                            std::vector<SupportDTypeItem> *support_dtypes);

// Expands `support_format` into `support_formats` for `put_size` slots,
// replicated `type_size` times.
void ConstructSupportFormats(size_t put_size, const std::vector<SupportFormatItem> &support_format, size_t type_size,
                             std::vector<SupportFormatItem> *support_formats);

// Builds the full dtype/format support table from op info; `is_dynamic_shape`
// selects the dynamic-shape variant of the op's dtype/format registration.
void GenerateSupportFormatDType(const OpInfoPtr &op_info, const SupportFormat &support_format, bool is_dynamic_shape,
                                SupportFormatDType *support_format_dtype);
// Overload that looks up the op info and dynamic-shape flag from the node.
void GenerateSupportFormatDType(const CNodePtr &cnode, const SupportFormat &support_format,
                                SupportFormatDType *support_format_dtype);
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_KERNEL_TBE_SELECT_UTILS_H_
|
|
@ -0,0 +1,76 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_selector_creator.h"
|
||||
|
||||
#include <map>
|
||||
#include "kernel/oplib/opinfo.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/agnostic_selector/tbe_kernel_agnostic_selector.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/common_selector/tbe_kernel_common_selector.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/dynamic_selector/tbe_kernel_dynamic_selector.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/pattern_selector/tbe_kernel_broadcast_selector.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/pattern_selector/tbe_kernel_reduce_selector.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
namespace {
|
||||
void GetCommonSupportedFormatDType(const CNodePtr &cnode, SupportFormatDType *support_format_dtype) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto selector = TbeKernelCommonSelector(cnode);
|
||||
selector.GetSupportedFormatDType(support_format_dtype);
|
||||
}
|
||||
|
||||
void GetAgnosticSupportedFormatDType(const CNodePtr &cnode, SupportFormatDType *support_format_dtype) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto selector = TbeKernelAgnosticSelector(cnode);
|
||||
selector.GetSupportedFormatDType(support_format_dtype);
|
||||
}
|
||||
|
||||
void GetDynamicSupportedFormatDType(const CNodePtr &cnode, SupportFormatDType *support_format_dtype) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto selector = TbeKernelDynamicSelector(cnode);
|
||||
selector.GetSupportedFormatDType(support_format_dtype);
|
||||
}
|
||||
|
||||
void GetBroadcastSupportedFormatDType(const CNodePtr &cnode, SupportFormatDType *support_format_dtype) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto selector = TbeKernelBroadcastSelector(cnode);
|
||||
selector.GetSupportedFormatDType(support_format_dtype);
|
||||
}
|
||||
|
||||
void GetReduceSupportedFormatDType(const CNodePtr &cnode, SupportFormatDType *support_format_dtype) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto selector = TbeKernelReduceSelector(cnode);
|
||||
selector.GetSupportedFormatDType(support_format_dtype);
|
||||
}
|
||||
|
||||
// Dispatch table: op pattern -> function that fills the supported
// format/dtype table for a node with that pattern.
const std::map<OpPattern, GetSupportedFormatDTypeFunc> selector_funcs = {
  {kCommonPattern, GetCommonSupportedFormatDType},
  {kFormatAgnosticPattern, GetAgnosticSupportedFormatDType},
  {kBroadcastPattern, GetBroadcastSupportedFormatDType},
  {kReducePattern, GetReduceSupportedFormatDType},
  {kDynamicFormatPattern, GetDynamicSupportedFormatDType}};
|
||||
} // namespace
|
||||
// Returns the format/dtype selector registered for the node's op pattern, or
// nullptr when the op info cannot be found or the pattern is unregistered.
GetSupportedFormatDTypeFunc GetSelectorFunc(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  const auto op_info = tbe::TbeDynamicShapeUtil::FindOp(cnode);
  if (op_info == nullptr) {
    return nullptr;
  }
  const auto func_iter = selector_funcs.find(op_info->op_pattern());
  if (func_iter == selector_funcs.end()) {
    return nullptr;
  }
  return func_iter->second;
}
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_SELECTOR_CREATER_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_SELECTOR_CREATER_
|
||||
#include <functional>
|
||||
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_select_utils.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Signature of a per-pattern selector that fills the supported
// format/dtype table for a node.
using GetSupportedFormatDTypeFunc =
  std::function<void(const CNodePtr &cnode, SupportFormatDType *support_format_dtype)>;
// Returns the selector matching the node's registered op pattern, or nullptr
// when no op info or no matching selector exists.
GetSupportedFormatDTypeFunc GetSelectorFunc(const CNodePtr &cnode);
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_SELECT_TBE_SELECTOR_CREATER_
|
|
@ -322,10 +322,6 @@ AnfNodePtr AddTransOpNodeToGraphWithFormat(const FuncGraphPtr &func_graph, const
|
|||
|
||||
trans_node = reshape_node;
|
||||
}
|
||||
if (trans_opname == prim::kPrimTransDataRNN->name()) {
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrHiddenSize, node, trans_data);
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrInputSize, node, trans_data);
|
||||
}
|
||||
if (spec_format == kOpFormat_FRAC_Z && groups != 1 &&
|
||||
!common::AnfAlgo::HasNodeAttr(kAttrFracZGroup, trans_data->cast<CNodePtr>())) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(groups), trans_data);
|
||||
|
@ -417,6 +413,9 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
|||
common::AnfAlgo::SetNodeAttr(kAttrGroups, MakeValue(fracz_group), trans_node);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrFracZGroup, MakeValue(fracz_group), trans_node);
|
||||
}
|
||||
} else if (op_name == prim::kPrimTransDataRNN->name()) {
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrHiddenSize, orig_node, trans_node);
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrInputSize, orig_node, trans_node);
|
||||
}
|
||||
if (is_dynamic_shape) {
|
||||
common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), trans_node);
|
||||
|
|
|
@ -186,29 +186,22 @@ bool IsOneOfHWSpecialFormat(const std::string &format) {
|
|||
}
|
||||
|
||||
bool IsOneOfFormat(const std::string &format) {
|
||||
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT,
|
||||
kOpFormat_NC1KHKWHWC0,
|
||||
kOpFormat_ND,
|
||||
kOpFormat_NCHW,
|
||||
kOpFormat_NHWC,
|
||||
kOpFormat_HWCN,
|
||||
kOpFormat_NC1HWC0,
|
||||
kOpFormat_FRAC_Z,
|
||||
kOpFormat_C1HWNCoC0,
|
||||
kOpFormat_FRAC_NZ,
|
||||
kOpFormat_NC1HWC0_C04,
|
||||
kOpFormat_FRACTAL_Z_C04,
|
||||
kOpFormat_NDHWC,
|
||||
kOpFormat_FRACTAL_ZN_LSTM,
|
||||
kOpFormat_FRACTAL_ZN_RNN,
|
||||
kOpFormat_ND_RNN_BIAS,
|
||||
kOpFormat_NDC1HWC0,
|
||||
kOpFormat_NCDHW,
|
||||
kOpFormat_FRACTAL_Z_3D,
|
||||
kOpFormat_DHWNC,
|
||||
kOpFormat_DHWCN};
|
||||
const std::set<std::string> kOpFormatList = {
|
||||
kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
|
||||
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
|
||||
kOpFormat_CHWN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z,
|
||||
kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04,
|
||||
kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM,
|
||||
kOpFormat_FRACTAL_ZN_RNN, kOpFormat_ND_RNN_BIAS, kOpFormat_NDC1HWC0,
|
||||
kOpFormat_NCDHW, kOpFormat_FRACTAL_Z_3D, kOpFormat_DHWNC,
|
||||
kOpFormat_DHWCN};
|
||||
|
||||
auto iter = kOpFormatList.find(format);
|
||||
return iter != kOpFormatList.end();
|
||||
}
|
||||
|
||||
bool IsOneOfServerFormatC04(const std::string &format) {
|
||||
const std::set<std::string> kServerFormatC04List = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
|
||||
return kServerFormatC04List.find(format) != kServerFormatC04List.end();
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,7 +26,6 @@ realdiv_op_info = TBERegOp("RealDiv") \
|
|||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "y", False, "required", "all") \
|
||||
.output(0, "z", False, "required", "all") \
|
||||
.op_pattern("broadcast") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
|
||||
|
|
|
@ -502,6 +502,7 @@ class TBERegOp(RegOp):
|
|||
self.kernel_name_ = ''
|
||||
self.partial_flag_ = False
|
||||
self.reshape_type_ = ''
|
||||
self.dynamic_rank_support_ = False
|
||||
self.dynamic_shape_ = False
|
||||
self.dynamic_compile_static_ = False
|
||||
self.need_check_supported_ = False
|
||||
|
@ -509,6 +510,29 @@ class TBERegOp(RegOp):
|
|||
self.op_pattern_ = ""
|
||||
self.real_input_index_ = []
|
||||
self.input_to_attr_index_ = []
|
||||
self.unknown_shape_formats_ = []
|
||||
|
||||
def unknown_shape_formats(self, unknown_shape_formats):
    """
    Register the supported data formats of the operator's inputs/outputs under
    unknown (dynamic) shape; each call appends one format combination.

    Args:
        unknown_shape_formats (list): Formats for one supported combination,
            appended to ``unknown_shape_formats_``. Default: ().
    """
    # NOTE(review): _is_list is called through the class (unbound) rather than
    # via self — confirm it is usable without an instance.
    RegOp._is_list(unknown_shape_formats)
    self.unknown_shape_formats_.append(unknown_shape_formats)
    return self
|
||||
|
||||
def dynamic_rank_support(self, dynamic_rank_support):
    """
    Register whether the TBE operator supports dynamic rank (unknown number
    of dimensions).

    Args:
        dynamic_rank_support (bool): True when dynamic rank is supported.
            Default: False.
    """
    self._is_bool(dynamic_rank_support)
    self.dynamic_rank_support_ = dynamic_rank_support
    return self
|
||||
|
||||
def real_input_index(self, real_input_index):
|
||||
"""
|
||||
|
@ -1146,3 +1170,49 @@ class DataType:
|
|||
|
||||
C64_Default = ("complex64", "DefaultFormat")
|
||||
C128_Default = ("complex128", "DefaultFormat")
|
||||
|
||||
|
||||
class Format:
    r"""
    String constants naming the unknown-shape (dynamic) data formats of
    Ascend ops, for use in ``unknown_shape_formats`` registrations.

    current support:

    .. code-block::
        C1HWNCoC0 = "C1HWNCoC0"
        ChannelLast = "ChannelLast"
        Default = "DefaultFormat"
        DHWCN = "DHWCN"
        FHD = "NC1HWC0"
        FracNZ = "FRACTAL_NZ"
        FRACTAL_Z_3D = "FRACTAL_Z_3D"
        FracZ = "FRACTAL_Z"
        FracZNLSTM = "FRACTAL_ZN_LSTM"
        FracZNRNN = "FRACTAL_ZN_RNN"
        HWCN = "HWCN"
        NCDHW = "NCDHW"
        NCHW = "NCHW"
        ND_RNNBIAS = "ND_RNN_BIAS"
        NDC1HWC0 = "NDC1HWC0"
        NDHWC = "NDHWC"
        NHWC = "NHWC"
        NULL = "NULL"
    """
    # Each attribute mirrors the device-format string used by the framework.
    C1HWNCoC0 = "C1HWNCoC0"
    ChannelLast = "ChannelLast"
    Default = "DefaultFormat"
    DHWCN = "DHWCN"
    FHD = "NC1HWC0"
    FracNZ = "FRACTAL_NZ"
    FRACTAL_Z_3D = "FRACTAL_Z_3D"
    FracZ = "FRACTAL_Z"
    FracZNLSTM = "FRACTAL_ZN_LSTM"
    FracZNRNN = "FRACTAL_ZN_RNN"
    HWCN = "HWCN"
    NCDHW = "NCDHW"
    NCHW = "NCHW"
    ND_RNNBIAS = "ND_RNN_BIAS"
    NDC1HWC0 = "NDC1HWC0"
    NDHWC = "NDHWC"
    NHWC = "NHWC"
    NULL = "NULL"
|
||||
|
|
|
@ -378,6 +378,6 @@ def test_train_feed(num_classes=65536):
|
|||
model = Model(net, loss_fn=loss, optimizer=opt)
|
||||
model.train(3, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
|
||||
loss_value = np.array(parallel_callback.loss_list)
|
||||
expect_out = [11.344119, 10.747661, 11.134097]
|
||||
expect_out = [11.344119, 10.747664, 11.132818]
|
||||
print(loss_value)
|
||||
assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
|
||||
|
|
Loading…
Reference in New Issue