Arel: make `Or` nodes "Nary" like `And`

Fix: https://github.com/rails/rails/issues/51386

This significantly reduce the depth of the tree for large `OR`
conditions. I was initially a bit on the fence about that fix,
but given that `And` is already implemented this way, I see no
reasons not to do the same.

Amusingly, the reported repro script now makes SQLite fail:

```ruby
SQLite3::SQLException: Expression tree is too large (maximum depth 1000)
```
This commit is contained in:
Jean Boussier 2024-04-04 14:59:56 +02:00
parent d4c40b6cf9
commit bfcc13ab7c
15 changed files with 37 additions and 36 deletions

View File

@ -142,7 +142,7 @@ module ActiveRecord
queries.first
else
queries.map! { |query| query.reduce(&:and) }
queries = queries.reduce { |result, query| Arel::Nodes::Or.new(result, query) }
queries = queries.reduce { |result, query| Arel::Nodes::Or.new([result, query]) }
Arel::Nodes::Grouping.new(queries)
end
end

View File

@ -1098,7 +1098,7 @@ module ActiveRecord
raise ArgumentError, "Relation passed to #or must be structurally compatible. Incompatible values: #{incompatible_values}"
end
self.where_clause = self.where_clause.or(other.where_clause)
self.where_clause = where_clause.or(other.where_clause)
self.having_clause = having_clause.or(other.having_clause)
self.references_values |= other.references_values

View File

@ -47,7 +47,11 @@ module ActiveRecord
right = right.ast
right = right.expr if right.is_a?(Arel::Nodes::Grouping)
or_clause = Arel::Nodes::Or.new(left, right)
or_clause = if left.is_a?(Arel::Nodes::Or)
Arel::Nodes::Or.new(left.children + [right])
else
Arel::Nodes::Or.new([left, right])
end
common.predicates << Arel::Nodes::Grouping.new(or_clause)
common

View File

@ -41,8 +41,8 @@ require "arel/nodes/matches"
require "arel/nodes/regexp"
require "arel/nodes/cte"
# nary
require "arel/nodes/and"
# nary (And and Or)
require "arel/nodes/nary"
# function
# FIXME: Function + Alias can be rewritten as a Function and Alias node.

View File

@ -111,12 +111,6 @@ module Arel # :nodoc: all
end
end
class Or < Binary
def fetch_attribute(&block)
left.fetch_attribute(&block) && right.fetch_attribute(&block)
end
end
%w{
Assignment
Join

View File

@ -2,7 +2,7 @@
module Arel # :nodoc: all
module Nodes
class And < Arel::Nodes::NodeExpression
class Nary < Arel::Nodes::NodeExpression
attr_reader :children
def initialize(children)
@ -23,7 +23,7 @@ module Arel # :nodoc: all
end
def hash
children.hash
[self.class, children].hash
end
def eql?(other)
@ -32,5 +32,8 @@ module Arel # :nodoc: all
end
alias :== :eql?
end
And = Class.new(Nary)
Or = Class.new(Nary)
end
end

View File

@ -127,7 +127,7 @@ module Arel # :nodoc: all
# Factory method to create a Nodes::Grouping node that has an Nodes::Or
# node as a child.
def or(right)
Nodes::Grouping.new Nodes::Or.new(self, right)
Nodes::Grouping.new Nodes::Or.new([self, right])
end
###

View File

@ -232,7 +232,7 @@ module Arel # :nodoc: all
def grouping_any(method_id, others, *extras)
nodes = others.map { |expr| send(method_id, expr, *extras) }
Nodes::Grouping.new nodes.inject { |memo, node|
Nodes::Or.new(memo, node)
Nodes::Or.new([memo, node])
}
end

View File

@ -191,6 +191,7 @@ module Arel # :nodoc: all
end
end
alias :visit_Arel_Nodes_And :visit__children
alias :visit_Arel_Nodes_Or :visit__children
alias :visit_Arel_Nodes_With :visit__children
def visit_String(o)

View File

@ -622,18 +622,7 @@ module Arel # :nodoc: all
end
def visit_Arel_Nodes_Or(o, collector)
stack = [o.right, o.left]
while o = stack.pop
if o.is_a?(Arel::Nodes::Or)
stack.push o.right, o.left
else
visit o, collector
collector << " OR " unless stack.empty?
end
end
collector
inject_join o.children, collector, " OR "
end
def visit_Arel_Nodes_Assignment(o, collector)

View File

@ -804,7 +804,7 @@ module Arel
node = attribute.not_between(1..3)
_(node).must_equal Nodes::Grouping.new(
Nodes::Or.new(
Nodes::Or.new([
Nodes::LessThan.new(
attribute,
Nodes::Casted.new(1, attribute)
@ -813,7 +813,7 @@ module Arel
attribute,
Nodes::Casted.new(3, attribute)
)
)
])
)
end
@ -930,7 +930,7 @@ module Arel
node = attribute.not_between(0...3)
_(node).must_equal Nodes::Grouping.new(
Nodes::Or.new(
Nodes::Or.new([
Nodes::LessThan.new(
attribute,
Nodes::Casted.new(0, attribute)
@ -939,7 +939,7 @@ module Arel
attribute,
Nodes::Casted.new(3, attribute)
)
)
])
)
end
end

View File

@ -22,12 +22,12 @@ module Arel
describe "equality" do
it "is equal with equal ivars" do
array = [Or.new("foo", "bar"), Or.new("foo", "bar")]
array = [Or.new(["foo", "bar"]), Or.new(["foo", "bar"])]
assert_equal 1, array.uniq.size
end
it "is not equal with different ivars" do
array = [Or.new("foo", "bar"), Or.new("foo", "baz")]
array = [Or.new(["foo", "bar"]), Or.new(["foo", "baz"])]
assert_equal 2, array.uniq.size
end
end

View File

@ -65,7 +65,6 @@ module Arel
Arel::Nodes::Matches,
Arel::Nodes::NotEqual,
Arel::Nodes::NotIn,
Arel::Nodes::Or,
Arel::Nodes::TableAlias,
Arel::Nodes::As,
Arel::Nodes::JoinSource,
@ -77,6 +76,17 @@ module Arel
end
end
# nary ops
[
Arel::Nodes::And,
Arel::Nodes::Or,
].each do |klass|
define_method("test_#{klass.name.gsub('::', '_')}") do
binary = klass.new([:a, :b])
@visitor.accept binary, Collectors::PlainString.new
end
end
def test_Arel_Nodes_BindParam
node = Arel::Nodes::BindParam.new(1)
collector = Collectors::PlainString.new

View File

@ -361,7 +361,7 @@ module Arel
end
it "should visit_Arel_Nodes_Or" do
node = Nodes::Or.new @attr.eq(10), @attr.eq(11)
node = Nodes::Or.new [@attr.eq(10), @attr.eq(11)]
_(compile(node)).must_be_like %{
"users"."id" = 10 OR "users"."id" = 11
}

View File

@ -185,7 +185,7 @@ class ActiveRecord::Relation
other_clause = WhereClause.new([table["name"].eq(bind_param("Sean"))])
expected_ast =
Arel::Nodes::Grouping.new(
Arel::Nodes::Or.new(table["id"].eq(bind_param(1)), table["name"].eq(bind_param("Sean")))
Arel::Nodes::Or.new([table["id"].eq(bind_param(1)), table["name"].eq(bind_param("Sean"))])
)
assert_equal expected_ast.to_sql, where_clause.or(other_clause).ast.to_sql