!14737 add is dynamic format op info

From: @jjfeing
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
This commit is contained in:
mindspore-ci-bot 2021-04-23 15:03:59 +08:00 committed by Gitee
commit 181c4d31c0
44 changed files with 80 additions and 72 deletions

View File

@ -53,7 +53,6 @@ enum OpPattern {
kFormatAgnosticPattern = 1,
kBroadcastPattern = 2,
kReducePattern = 3,
kDynamicFormatPattern = 4,
};
// Backend processor

View File

@ -101,11 +101,11 @@ class OpInfo {
compute_cost_ = opinfo.compute_cost_;
kernel_name_ = opinfo.kernel_name();
partial_flag_ = opinfo.partial_flag_;
dynamic_format_ = opinfo.dynamic_format_;
dynamic_shape_ = opinfo.dynamic_shape_;
op_pattern_ = opinfo.op_pattern();
processor_ = opinfo.processor_;
need_check_supported_ = opinfo.need_check_supported();
is_dynamic_format_ = opinfo.is_dynamic_format();
for (const auto &attr : opinfo.attrs_ptr()) {
attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr));
}
@ -127,6 +127,7 @@ class OpInfo {
bool dynamic_shape() const { return dynamic_shape_; }
std::string processor() const { return processor_; }
bool need_check_supported() const { return need_check_supported_; }
bool is_dynamic_format() const { return is_dynamic_format_; }
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
@ -144,7 +145,8 @@ class OpInfo {
void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; }
void set_processor(const std::string &processor) { processor_ = processor; }
void set_need_check_supported(const bool need_check_supported) { need_check_supported_ = need_check_supported; }
void set_need_check_supported(bool need_check_supported) { need_check_supported_ = need_check_supported; }
void set_is_dynamic_format(bool is_dynamic_format) { is_dynamic_format_ = is_dynamic_format; }
void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
@ -169,9 +171,9 @@ class OpInfo {
int compute_cost_ = 0;
std::string kernel_name_;
bool partial_flag_ = false;
bool dynamic_format_ = false;
bool dynamic_shape_ = false;
bool need_check_supported_ = false;
bool is_dynamic_format_ = false;
OpPattern op_pattern_ = kCommonPattern;
std::string processor_;
std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;

View File

@ -34,6 +34,7 @@ constexpr auto kKernelName = "kernel_name";
constexpr auto kPartialFlag = "partial_flag";
constexpr auto kReshapeType = "reshape_type";
constexpr auto kOpPattern = "op_pattern";
constexpr auto kIsDynamicFormat = "is_dynamic_format";
constexpr auto kDynamicFormat = "dynamicFormat";
constexpr auto kFormatAgnostic = "formatAgnostic";
constexpr auto kNeedCheckSupported = "need_check_supported";
@ -102,10 +103,8 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
}
void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
const std::map<std::string, kernel::OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern},
{kBroadcast, kBroadcastPattern},
{kReduce, kReducePattern},
{kDynamicFormat, kDynamicFormatPattern}};
const std::map<std::string, kernel::OpPattern> kOpPatternMap = {
{kFormatAgnostic, kFormatAgnosticPattern}, {kBroadcast, kBroadcastPattern}, {kReduce, kReducePattern}};
MS_EXCEPTION_IF_NULL(op_info);
op_info->set_async_flag(obj.at(kAsyncFlag));
op_info->set_binfile_name(obj.at(kBinfileName));
@ -118,6 +117,10 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
op_info->set_dynamic_shape(obj.at(kDynamicShape));
}
if (obj.find(kIsDynamicFormat) != obj.end()) {
op_info->set_is_dynamic_format(obj.at(kIsDynamicFormat));
}
if (obj.find(kOpPattern) != obj.end()) {
std::string op_pattern = obj.at(kOpPattern);
auto find_iter = kOpPatternMap.find(op_pattern);

View File

@ -63,11 +63,13 @@ void TbeKernelSelect::TbeMetadataInfoEx() {
MS_LOG(INFO) << "Warning: node(" << cnode_ptr_->fullname_with_scope() << ") not support tbe aicore.";
return;
}
if (op_info_ptr->is_dynamic_format()) {
GetDynamicFormatPatternKernelInfo(*op_info_ptr);
} else {
OpPattern pattern = op_info_ptr->op_pattern();
if (pattern == kCommonPattern) {
GetCommonPatternKernelInfo(*op_info_ptr);
} else if (pattern == kDynamicFormatPattern) {
GetDynamicFormatPatternKernelInfo(*op_info_ptr);
} else if (pattern == kFormatAgnosticPattern) {
GetAgnosticPatternKernelInfo(*op_info_ptr);
} else if (pattern == kBroadcastPattern) {
@ -77,6 +79,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() {
} else {
MS_LOG(INFO) << "Warning: op pattern is invailed.";
}
}
// check support
FilterInVaildKernelInfo(*op_info_ptr);
}

View File

@ -36,7 +36,7 @@ adam_apply_one_op_info = TBERegOp("AdamApplyOne") \
.output(0, "output0", False, "required", "all") \
.output(1, "output1", False, "required", "all") \
.output(2, "output2", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,

View File

@ -30,7 +30,7 @@ batch_matmul_op_info = TBERegOp("BatchMatMul") \
.input(1, "x2", False, "required", "all") \
.input(2, "bias", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None, DataType.I32_None) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None) \

View File

@ -23,7 +23,7 @@ bce_with_logits_loss_op_info = TBERegOp("BCEWithLogitsLoss") \
.compute_cost(10) \
.kernel_name("sigmoid_cross_entropy_with_logits_v2") \
.partial_flag(True) \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.attr("reduction", "optional", "str", "all", "mean") \
.input(0, "predict", False, "required", "all") \
.input(1, "target", False, "required", "all") \

View File

@ -27,7 +27,7 @@ bias_add_grad_op_info = TBERegOp("BiasAdd") \
.input(0, "x", False, "required", "all") \
.input(1, "bias", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -28,7 +28,7 @@ binary_cross_entropy_op_info = TBERegOp("BinaryCrossEntropy") \
.input(1, "y", False, "required", "all") \
.input(2, "weight", False, "optional", "all") \
.output(0, "output", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -26,7 +26,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \
.input(0, "x", False, "required", "all", reshape_type="NC") \
.output(0, "sum", False, "required", "all") \
.output(1, "square_sum", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.F16_None, DataType.F32_None, DataType.F32_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
.get_op_info()

View File

@ -32,7 +32,7 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \
.input(5, "batch_mean", False, "required", "all") \
.input(6, "batch_variance", False, "required", "all") \
.output(0, "y", False, "required", "all", reshape_type="NC") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.F16_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,
DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,

View File

@ -30,7 +30,7 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \
.input(3, "batch_variance", False, "required", "all") \
.output(0, "diff_scale", False, "required", "all") \
.output(1, "diff_offset", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F32_None, DataType.F32_None,
DataType.F32_None, DataType.F32_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,

View File

@ -32,7 +32,7 @@ bn_training_update_v2_op_info = TBERegOp("BNTrainingUpdateV2") \
.output(0, "y", False, "required", "all", reshape_type="NC") \
.output(1, "batch_mean", False, "required", "all") \
.output(2, "batch_variance", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.F16_None, DataType.F32_None, DataType.F32_None,
DataType.F32_None, DataType.F32_None, DataType.F16_None,
DataType.F32_None, DataType.F32_None) \

View File

@ -26,7 +26,7 @@ concat_op_info = TBERegOp("Concat") \
.attr("axis", "required", "int", "all") \
.input(0, "input_values", False, "dynamic", "all") \
.output(0, "output_data", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -27,7 +27,7 @@ concat_ds_op_info = TBERegOp("Concat") \
.attr("axis", "required", "int", "all") \
.input(0, "input_values", False, "dynamic", "all") \
.output(0, "output_data", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -28,7 +28,7 @@ confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \
.attr("transpose_first", "required", "bool", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -23,7 +23,7 @@ conv2d_op_info = TBERegOp("Conv2D") \
.compute_cost(10) \
.kernel_name("conv2d") \
.partial_flag(True) \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.attr("stride", "required", "listInt", "all") \
.attr("pad_list", "required", "listInt", "all") \
.attr("dilation", "required", "listInt", "all") \
@ -35,7 +35,7 @@ conv2d_op_info = TBERegOp("Conv2D") \
.input(2, "bias", False, "optional", "all") \
.input(3, "offset_w", False, "optional", "all") \
.output(0, "y", True, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.I8_None, DataType.F16_None) \
.dtype_format(DataType.I8_None, DataType.I8_None, DataType.I32_None, DataType.I8_None, DataType.I32_None) \
.get_op_info()

View File

@ -27,7 +27,7 @@ drop_out_do_mask_op_info = TBERegOp("DropoutDoMask") \
.input(1, "mask", False, "required", "all") \
.input(2, "keep_prob", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -27,7 +27,7 @@ fused_mul_add_op_info = TBERegOp("FusedMulAdd") \
.input(1, "x2", False, "required", "all") \
.input(2, "x3", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -25,7 +25,7 @@ l2_loss_op_info = TBERegOp("L2Loss") \
.partial_flag(True) \
.input(0, "x", None, "required", None) \
.output(0, "y", True, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -32,7 +32,7 @@ layer_norm_op_info = TBERegOp("LayerNorm") \
.output(0, "y", False, "required", "all") \
.output(1, "mean", False, "required", "all") \
.output(2, "variance", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None,
DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,

View File

@ -30,7 +30,7 @@ layer_norm_beta_gamma_backprop_op_info = TBERegOp("LayerNormBetaGammaBackprop")
.input(3, "mean", False, "required", "all") \
.output(0, "pd_gamma", False, "required", "all") \
.output(1, "pd_beta", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None,
DataType.F32_None, DataType.F32_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,

View File

@ -29,7 +29,7 @@ layer_norm_x_backprop_op_info = TBERegOp("LayerNormXBackprop") \
.input(3, "mean", False, "required", "all") \
.input(4, "gamma", False, "required", "all") \
.output(0, "pd_x", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None,
DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None,

View File

@ -33,7 +33,7 @@ max_pool3d_op_info = TBERegOp("MaxPool3D") \
.attr("format", "optional", "str", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -26,7 +26,7 @@ mul_op_info = TBERegOp("Mul") \
.input(0, "x", False, "required", "all") \
.input(1, "y", False, "required", "all") \
.output(0, "output", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -27,7 +27,7 @@ mul_ds_op_info = TBERegOp("Mul") \
.input(0, "x", False, "required", "all") \
.input(1, "y", False, "required", "all") \
.output(0, "output", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -26,7 +26,7 @@ prelu_op_info = TBERegOp("PReLU") \
.input(0, "x", False, "required", "all") \
.input(1, "weight", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -28,7 +28,7 @@ prelu_grad_op_info = TBERegOp("PReLUGrad") \
.input(2, "weights", False, "required", "all") \
.output(0, "dx", False, "required", "all") \
.output(1, "da", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
DataType.None_None) \
.get_op_info()

View File

@ -27,6 +27,7 @@ realdiv_op_info = TBERegOp("RealDiv") \
.input(1, "y", False, "required", "all") \
.output(0, "z", False, "required", "all") \
.op_pattern("broadcast") \
.is_dynamic_format(True) \
.dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
.get_op_info()

View File

@ -25,7 +25,7 @@ reciprocal_op_info = TBERegOp("Reciprocal") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -23,7 +23,7 @@ reverse_v2_d_op_info = TBERegOp("ReverseV2") \
.compute_cost(10) \
.kernel_name("reverse_v2_d") \
.partial_flag(True) \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.attr("axis", "required", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \

View File

@ -27,7 +27,7 @@ select_op_info = TBERegOp("Select") \
.input(1, "x1", False, "required", "all") \
.input(2, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -23,7 +23,7 @@ slice_op_info = TBERegOp("Slice") \
.compute_cost(10) \
.kernel_name("slice_d") \
.partial_flag(True) \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.attr("begin", "required", "listInt", "all") \
.attr("size", "required", "listInt", "all") \
.input(0, "x", False, "required", "all") \

View File

@ -26,7 +26,7 @@ softmax_op_info = TBERegOp("Softmax") \
.attr("axis", "optional", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -23,14 +23,14 @@ softmax_grad_ext_op_info = TBERegOp("SoftmaxGradExt") \
.compute_cost(10) \
.kernel_name("softmax_grad_ext") \
.partial_flag(True) \
.dynamic_format(True) \
.is_dynamic_format(True) \
.attr("axis", "required", "listInt", "all") \
.attr("keepdims", "required", "bool", "all") \
.input(0, "grad", False, "required", "all") \
.input(1, "x1", False, "required", "all") \
.input(2, "x2", False, "required", "all") \
.output(0, "y", True, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None,
DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -27,7 +27,7 @@ split_d_op_info = TBERegOp("Split") \
.attr("output_num", "required", "int", "all") \
.input(0, "value", False, "required", "all") \
.output(0, "output", False, "dynamic", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -28,7 +28,7 @@ split_v_op_info = TBERegOp("SplitV") \
.attr("num_split", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "dynamic", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -33,7 +33,7 @@ strided_slice_d_op_info = TBERegOp("StridedSlice") \
.attr("shrink_axis_mask", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -26,7 +26,7 @@ tensor_add_op_info = TBERegOp("Add") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \

View File

@ -27,7 +27,7 @@ tensor_add_op_info = TBERegOp("Add") \
.input(0, "x1", False, "required", "all") \
.input(1, "x2", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \

View File

@ -26,7 +26,7 @@ tile_op_info = TBERegOp("Tile") \
.attr("multiples", "optional", "listInt", "all")\
.input(0, "x1", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -27,7 +27,7 @@ unstack_op_info = TBERegOp("Unstack") \
.attr("axis", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "dynamic", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -27,7 +27,7 @@ unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSum") \
.input(0, "x", False, "required", "all") \
.input(1, "segment_ids", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("dynamicFormat") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()

View File

@ -353,9 +353,9 @@ class TBERegOp(RegOp):
self.kernel_name_ = ''
self.partial_flag_ = False
self.reshape_type_ = ''
self.dynamic_format_ = False
self.dynamic_shape_ = False
self.need_check_supported_ = False
self.is_dynamic_format_ = False
self.op_pattern_ = ""
def async_flag(self, async_flag):
@ -424,17 +424,6 @@ class TBERegOp(RegOp):
self.reshape_type_ = reshape_type
return self
def dynamic_format(self, dynamic_format):
"""
Whether the operator supports dynamic selection of format and dtype or not.
Args:
dynamic_format (bool): Value of dynamic format. Default: false.
"""
self._is_bool(dynamic_format)
self.dynamic_format_ = dynamic_format
return self
def dynamic_shape(self, dynamic_shape):
"""
Whether the operator supports dynamic shape.
@ -451,12 +440,23 @@ class TBERegOp(RegOp):
Whether the operator needs to check support.
Args:
:param need_check_supported: (bool): Value of need_check_supported. Default: false.
need_check_supported (bool): Value of need_check_supported. Default: false.
"""
self._is_bool(need_check_supported)
self.need_check_supported_ = need_check_supported
return self
def is_dynamic_format(self, is_dynamic_format):
"""
Whether the operator needs to call the op_select_format api.
Args:
is_dynamic_format (bool): Value of is_dynamic_format_. Default: false.
"""
self._is_bool(is_dynamic_format)
self.is_dynamic_format_ = is_dynamic_format
return self
def op_pattern(self, pattern=None):
"""
The behavior type of operator, such as broadcast, reduce and so on.