forked from mindspore-Ecosystem/mindspore
!14737 add is dyanmic format op info
From: @jjfeing Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54
This commit is contained in:
commit
181c4d31c0
|
@ -53,7 +53,6 @@ enum OpPattern {
|
|||
kFormatAgnosticPattern = 1,
|
||||
kBroadcastPattern = 2,
|
||||
kReducePattern = 3,
|
||||
kDynamicFormatPattern = 4,
|
||||
};
|
||||
|
||||
// Backend processor
|
||||
|
|
|
@ -101,11 +101,11 @@ class OpInfo {
|
|||
compute_cost_ = opinfo.compute_cost_;
|
||||
kernel_name_ = opinfo.kernel_name();
|
||||
partial_flag_ = opinfo.partial_flag_;
|
||||
dynamic_format_ = opinfo.dynamic_format_;
|
||||
dynamic_shape_ = opinfo.dynamic_shape_;
|
||||
op_pattern_ = opinfo.op_pattern();
|
||||
processor_ = opinfo.processor_;
|
||||
need_check_supported_ = opinfo.need_check_supported();
|
||||
is_dynamic_format_ = opinfo.is_dynamic_format();
|
||||
for (const auto &attr : opinfo.attrs_ptr()) {
|
||||
attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr));
|
||||
}
|
||||
|
@ -127,6 +127,7 @@ class OpInfo {
|
|||
bool dynamic_shape() const { return dynamic_shape_; }
|
||||
std::string processor() const { return processor_; }
|
||||
bool need_check_supported() const { return need_check_supported_; }
|
||||
bool is_dynamic_format() const { return is_dynamic_format_; }
|
||||
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
|
||||
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
|
||||
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
|
||||
|
@ -144,7 +145,8 @@ class OpInfo {
|
|||
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
|
||||
void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; }
|
||||
void set_processor(const std::string &processor) { processor_ = processor; }
|
||||
void set_need_check_supported(const bool need_check_supported) { need_check_supported_ = need_check_supported; }
|
||||
void set_need_check_supported(bool need_check_supported) { need_check_supported_ = need_check_supported; }
|
||||
void set_is_dynamic_format(bool is_dynamic_format) { is_dynamic_format_ = is_dynamic_format; }
|
||||
void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
|
||||
void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
|
||||
void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
|
||||
|
@ -169,9 +171,9 @@ class OpInfo {
|
|||
int compute_cost_ = 0;
|
||||
std::string kernel_name_;
|
||||
bool partial_flag_ = false;
|
||||
bool dynamic_format_ = false;
|
||||
bool dynamic_shape_ = false;
|
||||
bool need_check_supported_ = false;
|
||||
bool is_dynamic_format_ = false;
|
||||
OpPattern op_pattern_ = kCommonPattern;
|
||||
std::string processor_;
|
||||
std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;
|
||||
|
|
|
@ -34,6 +34,7 @@ constexpr auto kKernelName = "kernel_name";
|
|||
constexpr auto kPartialFlag = "partial_flag";
|
||||
constexpr auto kReshapeType = "reshape_type";
|
||||
constexpr auto kOpPattern = "op_pattern";
|
||||
constexpr auto kIsDynamicFormat = "is_dynamic_format";
|
||||
constexpr auto kDynamicFormat = "dynamicFormat";
|
||||
constexpr auto kFormatAgnostic = "formatAgnostic";
|
||||
constexpr auto kNeedCheckSupported = "need_check_supported";
|
||||
|
@ -102,10 +103,8 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
|
|||
}
|
||||
|
||||
void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
|
||||
const std::map<std::string, kernel::OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern},
|
||||
{kBroadcast, kBroadcastPattern},
|
||||
{kReduce, kReducePattern},
|
||||
{kDynamicFormat, kDynamicFormatPattern}};
|
||||
const std::map<std::string, kernel::OpPattern> kOpPatternMap = {
|
||||
{kFormatAgnostic, kFormatAgnosticPattern}, {kBroadcast, kBroadcastPattern}, {kReduce, kReducePattern}};
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
op_info->set_async_flag(obj.at(kAsyncFlag));
|
||||
op_info->set_binfile_name(obj.at(kBinfileName));
|
||||
|
@ -118,6 +117,10 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
|
|||
op_info->set_dynamic_shape(obj.at(kDynamicShape));
|
||||
}
|
||||
|
||||
if (obj.find(kIsDynamicFormat) != obj.end()) {
|
||||
op_info->set_is_dynamic_format(obj.at(kIsDynamicFormat));
|
||||
}
|
||||
|
||||
if (obj.find(kOpPattern) != obj.end()) {
|
||||
std::string op_pattern = obj.at(kOpPattern);
|
||||
auto find_iter = kOpPatternMap.find(op_pattern);
|
||||
|
|
|
@ -63,11 +63,13 @@ void TbeKernelSelect::TbeMetadataInfoEx() {
|
|||
MS_LOG(INFO) << "Warning: node(" << cnode_ptr_->fullname_with_scope() << ") not support tbe aicore.";
|
||||
return;
|
||||
}
|
||||
|
||||
if (op_info_ptr->is_dynamic_format()) {
|
||||
GetDynamicFormatPatternKernelInfo(*op_info_ptr);
|
||||
} else {
|
||||
OpPattern pattern = op_info_ptr->op_pattern();
|
||||
if (pattern == kCommonPattern) {
|
||||
GetCommonPatternKernelInfo(*op_info_ptr);
|
||||
} else if (pattern == kDynamicFormatPattern) {
|
||||
GetDynamicFormatPatternKernelInfo(*op_info_ptr);
|
||||
} else if (pattern == kFormatAgnosticPattern) {
|
||||
GetAgnosticPatternKernelInfo(*op_info_ptr);
|
||||
} else if (pattern == kBroadcastPattern) {
|
||||
|
@ -77,6 +79,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() {
|
|||
} else {
|
||||
MS_LOG(INFO) << "Warning: op pattern is invailed.";
|
||||
}
|
||||
}
|
||||
// check support
|
||||
FilterInVaildKernelInfo(*op_info_ptr);
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ adam_apply_one_op_info = TBERegOp("AdamApplyOne") \
|
|||
.output(0, "output0", False, "required", "all") \
|
||||
.output(1, "output1", False, "required", "all") \
|
||||
.output(2, "output2", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
|
||||
DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
|
||||
DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
|
||||
|
|
|
@ -30,7 +30,7 @@ batch_matmul_op_info = TBERegOp("BatchMatMul") \
|
|||
.input(1, "x2", False, "required", "all") \
|
||||
.input(2, "bias", False, "optional", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None, DataType.I32_None) \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None) \
|
||||
|
|
|
@ -23,7 +23,7 @@ bce_with_logits_loss_op_info = TBERegOp("BCEWithLogitsLoss") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("sigmoid_cross_entropy_with_logits_v2") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.attr("reduction", "optional", "str", "all", "mean") \
|
||||
.input(0, "predict", False, "required", "all") \
|
||||
.input(1, "target", False, "required", "all") \
|
||||
|
|
|
@ -27,7 +27,7 @@ bias_add_grad_op_info = TBERegOp("BiasAdd") \
|
|||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "bias", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ binary_cross_entropy_op_info = TBERegOp("BinaryCrossEntropy") \
|
|||
.input(1, "y", False, "required", "all") \
|
||||
.input(2, "weight", False, "optional", "all") \
|
||||
.output(0, "output", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \
|
|||
.input(0, "x", False, "required", "all", reshape_type="NC") \
|
||||
.output(0, "sum", False, "required", "all") \
|
||||
.output(1, "square_sum", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.F16_None, DataType.F32_None, DataType.F32_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
|
||||
.get_op_info()
|
||||
|
|
|
@ -32,7 +32,7 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \
|
|||
.input(5, "batch_mean", False, "required", "all") \
|
||||
.input(6, "batch_variance", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all", reshape_type="NC") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.F16_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,
|
||||
DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,
|
||||
|
|
|
@ -30,7 +30,7 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \
|
|||
.input(3, "batch_variance", False, "required", "all") \
|
||||
.output(0, "diff_scale", False, "required", "all") \
|
||||
.output(1, "diff_offset", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F32_None, DataType.F32_None,
|
||||
DataType.F32_None, DataType.F32_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,
|
||||
|
|
|
@ -32,7 +32,7 @@ bn_training_update_v2_op_info = TBERegOp("BNTrainingUpdateV2") \
|
|||
.output(0, "y", False, "required", "all", reshape_type="NC") \
|
||||
.output(1, "batch_mean", False, "required", "all") \
|
||||
.output(2, "batch_variance", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.F16_None, DataType.F32_None, DataType.F32_None,
|
||||
DataType.F32_None, DataType.F32_None, DataType.F16_None,
|
||||
DataType.F32_None, DataType.F32_None) \
|
||||
|
|
|
@ -26,7 +26,7 @@ concat_op_info = TBERegOp("Concat") \
|
|||
.attr("axis", "required", "int", "all") \
|
||||
.input(0, "input_values", False, "dynamic", "all") \
|
||||
.output(0, "output_data", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ concat_ds_op_info = TBERegOp("Concat") \
|
|||
.attr("axis", "required", "int", "all") \
|
||||
.input(0, "input_values", False, "dynamic", "all") \
|
||||
.output(0, "output_data", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \
|
|||
.attr("transpose_first", "required", "bool", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ conv2d_op_info = TBERegOp("Conv2D") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("conv2d") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.attr("stride", "required", "listInt", "all") \
|
||||
.attr("pad_list", "required", "listInt", "all") \
|
||||
.attr("dilation", "required", "listInt", "all") \
|
||||
|
@ -35,7 +35,7 @@ conv2d_op_info = TBERegOp("Conv2D") \
|
|||
.input(2, "bias", False, "optional", "all") \
|
||||
.input(3, "offset_w", False, "optional", "all") \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.I8_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.I8_None, DataType.I8_None, DataType.I32_None, DataType.I8_None, DataType.I32_None) \
|
||||
.get_op_info()
|
||||
|
|
|
@ -27,7 +27,7 @@ drop_out_do_mask_op_info = TBERegOp("DropoutDoMask") \
|
|||
.input(1, "mask", False, "required", "all") \
|
||||
.input(2, "keep_prob", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ fused_mul_add_op_info = TBERegOp("FusedMulAdd") \
|
|||
.input(1, "x2", False, "required", "all") \
|
||||
.input(2, "x3", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ l2_loss_op_info = TBERegOp("L2Loss") \
|
|||
.partial_flag(True) \
|
||||
.input(0, "x", None, "required", None) \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ layer_norm_op_info = TBERegOp("LayerNorm") \
|
|||
.output(0, "y", False, "required", "all") \
|
||||
.output(1, "mean", False, "required", "all") \
|
||||
.output(2, "variance", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None,
|
||||
DataType.F16_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,
|
||||
|
|
|
@ -30,7 +30,7 @@ layer_norm_beta_gamma_backprop_op_info = TBERegOp("LayerNormBetaGammaBackprop")
|
|||
.input(3, "mean", False, "required", "all") \
|
||||
.output(0, "pd_gamma", False, "required", "all") \
|
||||
.output(1, "pd_beta", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None,
|
||||
DataType.F32_None, DataType.F32_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,
|
||||
|
|
|
@ -29,7 +29,7 @@ layer_norm_x_backprop_op_info = TBERegOp("LayerNormXBackprop") \
|
|||
.input(3, "mean", False, "required", "all") \
|
||||
.input(4, "gamma", False, "required", "all") \
|
||||
.output(0, "pd_x", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None,
|
||||
DataType.F16_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,
|
||||
|
|
|
@ -33,7 +33,7 @@ max_pool3d_op_info = TBERegOp("MaxPool3D") \
|
|||
.attr("format", "optional", "str", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ mul_op_info = TBERegOp("Mul") \
|
|||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "y", False, "required", "all") \
|
||||
.output(0, "output", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ mul_ds_op_info = TBERegOp("Mul") \
|
|||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "y", False, "required", "all") \
|
||||
.output(0, "output", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ prelu_op_info = TBERegOp("PReLU") \
|
|||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "weight", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ prelu_grad_op_info = TBERegOp("PReLUGrad") \
|
|||
.input(2, "weights", False, "required", "all") \
|
||||
.output(0, "dx", False, "required", "all") \
|
||||
.output(1, "da", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
|
||||
DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
|
|
@ -27,6 +27,7 @@ realdiv_op_info = TBERegOp("RealDiv") \
|
|||
.input(1, "y", False, "required", "all") \
|
||||
.output(0, "z", False, "required", "all") \
|
||||
.op_pattern("broadcast") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
|
||||
.get_op_info()
|
||||
|
|
|
@ -25,7 +25,7 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
|
|||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ reverse_v2_d_op_info = TBERegOp("ReverseV2") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("reverse_v2_d") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.attr("axis", "required", "listInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
|
|
|
@ -27,7 +27,7 @@ select_op_info = TBERegOp("Select") \
|
|||
.input(1, "x1", False, "required", "all") \
|
||||
.input(2, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ slice_op_info = TBERegOp("Slice") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("slice_d") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.attr("begin", "required", "listInt", "all") \
|
||||
.attr("size", "required", "listInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
|
|
|
@ -26,7 +26,7 @@ softmax_op_info = TBERegOp("Softmax") \
|
|||
.attr("axis", "optional", "listInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -23,14 +23,14 @@ softmax_grad_ext_op_info = TBERegOp("SoftmaxGradExt") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("softmax_grad_ext") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_format(True) \
|
||||
.is_dynamic_format(True) \
|
||||
.attr("axis", "required", "listInt", "all") \
|
||||
.attr("keepdims", "required", "bool", "all") \
|
||||
.input(0, "grad", False, "required", "all") \
|
||||
.input(1, "x1", False, "required", "all") \
|
||||
.input(2, "x2", False, "required", "all") \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None,
|
||||
DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
|
|
@ -27,7 +27,7 @@ split_d_op_info = TBERegOp("Split") \
|
|||
.attr("output_num", "required", "int", "all") \
|
||||
.input(0, "value", False, "required", "all") \
|
||||
.output(0, "output", False, "dynamic", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ split_v_op_info = TBERegOp("SplitV") \
|
|||
.attr("num_split", "required", "int", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "dynamic", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ strided_slice_d_op_info = TBERegOp("StridedSlice") \
|
|||
.attr("shrink_axis_mask", "required", "int", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ tensor_add_op_info = TBERegOp("Add") \
|
|||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
|
|
|
@ -27,7 +27,7 @@ tensor_add_op_info = TBERegOp("Add") \
|
|||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
|
|
|
@ -26,7 +26,7 @@ tile_op_info = TBERegOp("Tile") \
|
|||
.attr("multiples", "optional", "listInt", "all")\
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ unstack_op_info = TBERegOp("Unstack") \
|
|||
.attr("axis", "required", "int", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "dynamic", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSum") \
|
|||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "segment_ids", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
|
|
@ -353,9 +353,9 @@ class TBERegOp(RegOp):
|
|||
self.kernel_name_ = ''
|
||||
self.partial_flag_ = False
|
||||
self.reshape_type_ = ''
|
||||
self.dynamic_format_ = False
|
||||
self.dynamic_shape_ = False
|
||||
self.need_check_supported_ = False
|
||||
self.is_dynamic_format_ = False
|
||||
self.op_pattern_ = ""
|
||||
|
||||
def async_flag(self, async_flag):
|
||||
|
@ -424,17 +424,6 @@ class TBERegOp(RegOp):
|
|||
self.reshape_type_ = reshape_type
|
||||
return self
|
||||
|
||||
def dynamic_format(self, dynamic_format):
|
||||
"""
|
||||
Whether the operator supports dynamic selection of format and dtype or not.
|
||||
|
||||
Args:
|
||||
dynamic_format (bool): Value of dynamic format. Default: false.
|
||||
"""
|
||||
self._is_bool(dynamic_format)
|
||||
self.dynamic_format_ = dynamic_format
|
||||
return self
|
||||
|
||||
def dynamic_shape(self, dynamic_shape):
|
||||
"""
|
||||
Whether the operator supports dynamic shape.
|
||||
|
@ -451,12 +440,23 @@ class TBERegOp(RegOp):
|
|||
Whether the operator need check supports.
|
||||
|
||||
Args:
|
||||
:param need_check_supported: (bool): Value of need_check_supported. Default: false.
|
||||
need_check_supported (bool): Value of need_check_supported. Default: false.
|
||||
"""
|
||||
self._is_bool(need_check_supported)
|
||||
self.need_check_supported_ = need_check_supported
|
||||
return self
|
||||
|
||||
def is_dynamic_format(self, is_dynamic_format):
|
||||
"""
|
||||
Whether the operator need calop_select_format api.
|
||||
|
||||
Args:
|
||||
is_dynamic_format (bool): Value of is_dynamic_format_. Default: false.
|
||||
"""
|
||||
self._is_bool(is_dynamic_format)
|
||||
self.is_dynamic_format_ = is_dynamic_format
|
||||
return self
|
||||
|
||||
def op_pattern(self, pattern=None):
|
||||
"""
|
||||
The behavior type of operator, such as broadcast, reduce and so on.
|
||||
|
|
Loading…
Reference in New Issue