Change impl::printBinaryOp() to consider operand and result type

The operand and result types of binary ops are not necessarily the
same. For those binary ops, we cannot print in the short-form assembly.

Enhance impl::printBinaryOp() to consider operand and result types
to select which assembly form to use.

PiperOrigin-RevId: 229608142
This commit is contained in:
Lei Zhang 2019-01-16 12:49:11 -08:00 committed by jpienaar
parent 5843e5a7c0
commit 3766332533
3 changed files with 17 additions and 2 deletions

View File

@ -895,6 +895,8 @@ namespace impl {
// Populates `result` with the state for a binary op taking `lhs` and `rhs`.
void buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
Value *rhs);
// Parses a binary op from `parser` into `result`.
bool parseBinaryOp(OpAsmParser *parser, OperationState *result);
// Prints the given binary `op` in short-hand notation if both the two operands
// and the result have the same type. Otherwise, prints the long-hand notation.
void printBinaryOp(const OperationInst *op, OpAsmPrinter *p);
} // namespace impl

View File

@ -345,9 +345,22 @@ bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
}
// Prints the binary `op` to `p`. Uses the short-hand form (`name lhs, rhs : T`)
// only when both operand types and the result type are identical; otherwise
// falls back to the generic (long-hand) form so no type information is lost.
void impl::printBinaryOp(const OperationInst *op, OpAsmPrinter *p) {
  assert(op->getNumOperands() == 2 && "binary op should have two operands");
  assert(op->getNumResults() == 1 && "binary op should have one result");

  // If not all the operand and result types are the same, just use the
  // canonical form to avoid omitting information in printing.
  auto resultType = op->getResult(0)->getType();
  if (op->getOperand(0)->getType() != resultType ||
      op->getOperand(1)->getType() != resultType) {
    p->printDefaultOp(op);
    return;
  }

  *p << op->getName() << ' ' << *op->getOperand(0) << ", "
     << *op->getOperand(1);
  p->printOptionalAttrDict(op->getAttrs());
  // All operand types and the result type match, so print the already-computed
  // `resultType` once rather than re-fetching it from the op.
  *p << " : " << resultType;
}

View File

@ -5,8 +5,8 @@
// CHECK-LABEL: @broadcast_scalar_scalar_scalar
func @broadcast_scalar_scalar_scalar(tensor<i32>, tensor<i32>) -> tensor<i32> {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
// CHECK: %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function: "RELU6"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%0 = "tfl.add"(%arg0, %arg1) {fused_activation_function: "RELU6"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %0 = tfl.add %arg0, %arg1 {fused_activation_function: "RELU6"} : tensor<i32>
%0 = tfl.add %arg0, %arg1 {fused_activation_function: "RELU6"} : tensor<i32>
return %0 : tensor<i32>
}