Refactor Active Record to let Arel manage bind params

A common source of bugs and code bloat within Active Record has been the
need for us to maintain the list of bind values separately from the AST
they're associated with. This makes any sort of AST manipulation
incredibly difficult, as any time we want to potentially insert or
remove an AST node, we need to traverse the entire tree to find where
the associated bind parameters are.

With this change, the bind parameters now live on the AST directly.
Active Record does not need to know or care about them until the final
AST traversal for SQL construction. Rather than returning just the SQL,
the Arel collector will now return both the SQL and the bind parameters.
At this point the connection adapter will have all the values that it
had before.

A bit of this code is janky and something I'd like to refactor later. In
particular, I don't like how we're handling associations in the
predicate builder, the special casing of `StatementCache::Substitute` in
`QueryAttribute`, or generally how we're handling bind value replacement
in the statement cache when prepared statements are disabled.

This also mostly reverts #26378, as it moved all the code into a
location that I wanted to delete.

/cc @metaskills @yahonda, this change will affect the adapters

Fixes #29766.
Fixes #29804.
Fixes #26541.
Close #28539.
Close #24769.
Close #26468.
Close #26202.

There are probably other issues/PRs that can be closed because of this
commit, but that's all I could find on the first few pages.
This commit is contained in:
Sean Griffin 2017-07-24 08:19:35 -04:00
parent 0449d8b6bc
commit 213796fb49
37 changed files with 284 additions and 460 deletions

View File

@ -31,7 +31,7 @@ GIT
GIT
remote: https://github.com/rails/arel.git
revision: 67a51c62f4e19390cd8eb408596ca48bb0806362
revision: 7a29220c689feb0581e21d5324b85fc2f201ac5e
specs:
arel (8.0.0)

View File

@ -154,7 +154,7 @@ module ActiveRecord
stmt.from scope.klass.arel_table
stmt.wheres = arel.constraints
count = scope.klass.connection.delete(stmt, "SQL", scope.bound_attributes)
count = scope.klass.connection.delete(stmt, "SQL")
end
when :nullify
count = scope.update_all(source_reflection.foreign_key => nil)

View File

@ -23,11 +23,10 @@ module ActiveRecord
super && reflection == other.reflection
end
JoinInformation = Struct.new :joins, :binds
JoinInformation = Struct.new :joins
def join_constraints(foreign_table, foreign_klass, join_type, tables, chain)
joins = []
binds = []
tables = tables.reverse
# The chain starts with the target table, but we want to end with it here (makes
@ -43,7 +42,6 @@ module ActiveRecord
join_scope = reflection.join_scope(table, foreign_klass)
if join_scope.arel.constraints.any?
binds.concat join_scope.bound_attributes
joins.concat join_scope.arel.join_sources
right = joins.last.right
right.expr = right.expr.and(join_scope.arel.constraints)
@ -53,7 +51,7 @@ module ActiveRecord
foreign_table, foreign_klass = table, klass
end
JoinInformation.new joins, binds
JoinInformation.new joins
end
def table

View File

@ -29,7 +29,7 @@ module ActiveRecord
arel = query.arel
end
result = connection.select_one(arel, nil, query.bound_attributes)
result = connection.select_one(arel, nil)
if result.blank?
size = 0

View File

@ -9,30 +9,36 @@ module ActiveRecord
end
# Converts an arel AST to SQL
def to_sql(arel, binds = [])
if arel.respond_to?(:ast)
collected = visitor.accept(arel.ast, collector)
collected.compile(binds, self).freeze
def to_sql(arel_or_sql_string, binds = [])
if arel_or_sql_string.respond_to?(:ast)
unless binds.empty?
raise "Passing bind parameters with an arel AST is forbidden. " \
"The values must be stored on the AST directly"
end
sql, binds = visitor.accept(arel_or_sql_string.ast, collector).value
[sql.freeze, binds || []]
else
arel.dup.freeze
[arel_or_sql_string.dup.freeze, binds]
end
end
# This is used in the StatementCache object. It returns an object that
# can be used to query the database repeatedly.
def cacheable_query(klass, arel) # :nodoc:
collected = visitor.accept(arel.ast, collector)
if prepared_statements
klass.query(collected.value)
sql, binds = visitor.accept(arel.ast, collector).value
query = klass.query(sql)
else
klass.partial_query(collected.value)
query = klass.partial_query(arel.ast)
binds = []
end
[query, binds]
end
# Returns an ActiveRecord::Result instance.
def select_all(arel, name = nil, binds = [], preparable: nil)
arel, binds = binds_from_relation arel, binds
sql = to_sql(arel, binds)
arel = arel_from_relation(arel)
sql, binds = to_sql(arel, binds)
if !prepared_statements || (arel.is_a?(String) && preparable.nil?)
preparable = false
else
@ -131,20 +137,23 @@ module ActiveRecord
#
# If the next id was calculated in advance (as in Oracle), it should be
# passed in as +id_value+.
def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = [])
value = exec_insert(to_sql(arel, binds), name, binds, pk, sequence_name)
def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil)
sql, binds = to_sql(arel)
value = exec_insert(sql, name, binds, pk, sequence_name)
id_value || last_inserted_id(value)
end
alias create insert
# Executes the update statement and returns the number of rows affected.
def update(arel, name = nil, binds = [])
exec_update(to_sql(arel, binds), name, binds)
def update(arel, name = nil)
sql, binds = to_sql(arel)
exec_update(sql, name, binds)
end
# Executes the delete statement and returns the number of rows affected.
def delete(arel, name = nil, binds = [])
exec_delete(to_sql(arel, binds), name, binds)
def delete(arel, name = nil)
sql, binds = to_sql(arel)
exec_delete(sql, name, binds)
end
# Returns +true+ when the connection adapter supports prepared statement
@ -430,11 +439,12 @@ module ActiveRecord
row && row.first
end
def binds_from_relation(relation, binds)
if relation.is_a?(Relation) && binds.empty?
relation, binds = relation.arel, relation.bound_attributes
def arel_from_relation(relation)
if relation.is_a?(Relation)
relation.arel
else
relation
end
[relation, binds]
end
# Fixture value is quoted by Arel, however scalar values

View File

@ -92,8 +92,8 @@ module ActiveRecord
def select_all(arel, name = nil, binds = [], preparable: nil)
if @query_cache_enabled && !locked?(arel)
arel, binds = binds_from_relation arel, binds
sql = to_sql(arel, binds)
arel = arel_from_relation(arel)
sql, binds = to_sql(arel, binds)
cache_sql(sql, name, binds) { super(sql, name, binds, preparable: preparable) }
else
super

View File

@ -24,6 +24,10 @@ module ActiveRecord
return value.quoted_id
end
if value.respond_to?(:value_for_database)
value = value.value_for_database
end
_quote(value)
end

View File

@ -7,7 +7,9 @@ require_relative "sql_type_metadata"
require_relative "abstract/schema_dumper"
require_relative "abstract/schema_creation"
require "arel/collectors/bind"
require "arel/collectors/composite"
require "arel/collectors/sql_string"
require "arel/collectors/substitute_binds"
module ActiveRecord
module ConnectionAdapters # :nodoc:
@ -129,19 +131,6 @@ module ActiveRecord
end
end
class BindCollector < Arel::Collectors::Bind
def compile(bvs, conn)
casted_binds = bvs.map(&:value_for_database)
super(casted_binds.map { |value| conn.quote(value) })
end
end
class SQLString < Arel::Collectors::SQLString
def compile(bvs, conn)
super(bvs)
end
end
def valid_type?(type) # :nodoc:
!native_database_types[type].nil?
end
@ -432,14 +421,14 @@ module ActiveRecord
end
def case_sensitive_comparison(table, attribute, column, value) # :nodoc:
table[attribute].eq(value)
table[attribute].eq(Arel::Nodes::BindParam.new(value))
end
def case_insensitive_comparison(table, attribute, column, value) # :nodoc:
if can_perform_case_insensitive_comparison_for?(column)
table[attribute].lower.eq(table.lower(value))
table[attribute].lower.eq(table.lower(Arel::Nodes::BindParam.new(value)))
else
table[attribute].eq(value)
table[attribute].eq(Arel::Nodes::BindParam.new(value))
end
end
@ -457,24 +446,6 @@ module ActiveRecord
visitor.accept(node, collector).value
end
def combine_bind_parameters(
from_clause: [],
join_clause: [],
where_clause: [],
having_clause: [],
limit: nil,
offset: nil
) # :nodoc:
result = from_clause + join_clause + where_clause + having_clause
if limit
result << limit
end
if offset
result << offset
end
result
end
def default_index_type?(index) # :nodoc:
index.using.nil?
end
@ -609,9 +580,15 @@ module ActiveRecord
def collector
if prepared_statements
SQLString.new
Arel::Collectors::Composite.new(
Arel::Collectors::SQLString.new,
Arel::Collectors::Bind.new,
)
else
BindCollector.new
Arel::Collectors::SubstituteBinds.new(
self,
Arel::Collectors::SQLString.new,
)
end
end

View File

@ -177,7 +177,8 @@ module ActiveRecord
#++
def explain(arel, binds = [])
sql = "EXPLAIN #{to_sql(arel, binds)}"
sql, binds = to_sql(arel, binds)
sql = "EXPLAIN #{sql}"
start = Time.now
result = exec_query(sql, "EXPLAIN", binds)
elapsed = Time.now - start

View File

@ -5,7 +5,7 @@ module ActiveRecord
module MySQL
module DatabaseStatements
# Returns an ActiveRecord::Result instance.
def select_all(arel, name = nil, binds = [], preparable: nil) # :nodoc:
def select_all(*) # :nodoc:
result = if ExplainRegistry.collect? && prepared_statements
unprepared_statement { super }
else

View File

@ -5,7 +5,8 @@ module ActiveRecord
module PostgreSQL
module DatabaseStatements
def explain(arel, binds = [])
sql = "EXPLAIN #{to_sql(arel, binds)}"
sql, binds = to_sql(arel, binds)
sql = "EXPLAIN #{sql}"
PostgreSQL::ExplainPrettyPrinter.new.pp(exec_query(sql, "EXPLAIN", binds))
end

View File

@ -203,7 +203,8 @@ module ActiveRecord
#++
def explain(arel, binds = [])
sql = "EXPLAIN QUERY PLAN #{to_sql(arel, binds)}"
sql, binds = to_sql(arel, binds)
sql = "EXPLAIN QUERY PLAN #{sql}"
SQLite3::ExplainPrettyPrinter.new.pp(exec_query(sql, "EXPLAIN", []))
end

View File

@ -53,7 +53,7 @@ module ActiveRecord
im = arel.create_insert
im.into @table
substitutes, binds = substitute_values values
substitutes = substitute_values values
if values.empty? # empty insert
im.values = Arel.sql(connection.empty_insert_statement_value)
@ -67,11 +67,11 @@ module ActiveRecord
primary_key || false,
primary_key_value,
nil,
binds)
)
end
def _update_record(values, id, id_was) # :nodoc:
substitutes, binds = substitute_values values
substitutes = substitute_values values
scope = @klass.unscoped
@ -80,7 +80,6 @@ module ActiveRecord
end
relation = scope.where(@klass.primary_key => (id_was || id))
bvs = binds + relation.bound_attributes
um = relation
.arel
.compile_update(substitutes, @klass.primary_key)
@ -88,20 +87,14 @@ module ActiveRecord
@klass.connection.update(
um,
"SQL",
bvs,
)
end
def substitute_values(values) # :nodoc:
binds = []
substitutes = []
values.each do |arel_attr, value|
binds.push QueryAttribute.new(arel_attr.name, value, klass.type_for_attribute(arel_attr.name))
substitutes.push [arel_attr, Arel::Nodes::BindParam.new]
values.map do |arel_attr, value|
bind = QueryAttribute.new(arel_attr.name, value, klass.type_for_attribute(arel_attr.name))
[arel_attr, Arel::Nodes::BindParam.new(bind)]
end
[substitutes, binds]
end
def arel_attribute(name) # :nodoc:
@ -380,7 +373,7 @@ module ActiveRecord
stmt.wheres = arel.constraints
end
@klass.connection.update stmt, "SQL", bound_attributes
@klass.connection.update stmt, "SQL"
end
# Updates an object (or multiple objects) and saves it to the database, if validations pass.
@ -510,7 +503,7 @@ module ActiveRecord
stmt.wheres = arel.constraints
end
affected = @klass.connection.delete(stmt, "SQL", bound_attributes)
affected = @klass.connection.delete(stmt, "SQL")
reset
affected
@ -578,7 +571,8 @@ module ActiveRecord
conn = klass.connection
conn.unprepared_statement {
conn.to_sql(relation.arel, relation.bound_attributes)
sql, _ = conn.to_sql(relation.arel)
sql
}
end
end
@ -663,7 +657,7 @@ module ActiveRecord
def exec_queries(&block)
skip_query_cache_if_necessary do
@records = eager_loading? ? find_with_associations.freeze : @klass.find_by_sql(arel, bound_attributes, &block).freeze
@records = eager_loading? ? find_with_associations.freeze : @klass.find_by_sql(arel, &block).freeze
preload = preload_values
preload += includes_values unless eager_loading?

View File

@ -186,7 +186,7 @@ module ActiveRecord
relation.select_values = column_names.map { |cn|
@klass.has_attribute?(cn) || @klass.attribute_alias?(cn) ? arel_attribute(cn) : cn
}
result = skip_query_cache_if_necessary { klass.connection.select_all(relation.arel, nil, bound_attributes) }
result = skip_query_cache_if_necessary { klass.connection.select_all(relation.arel, nil) }
result.cast_values(klass.attribute_types)
end
end
@ -262,7 +262,7 @@ module ActiveRecord
query_builder = relation.arel
end
result = skip_query_cache_if_necessary { @klass.connection.select_all(query_builder, nil, bound_attributes) }
result = skip_query_cache_if_necessary { @klass.connection.select_all(query_builder, nil) }
row = result.first
value = row && row.values.first
type = result.column_types.fetch(column_alias) do
@ -313,7 +313,7 @@ module ActiveRecord
relation.group_values = group_fields
relation.select_values = select_values
calculated_data = skip_query_cache_if_necessary { @klass.connection.select_all(relation.arel, nil, relation.bound_attributes) }
calculated_data = skip_query_cache_if_necessary { @klass.connection.select_all(relation.arel, nil) }
if association
key_ids = calculated_data.collect { |row| row[group_aliases.first] }

View File

@ -317,7 +317,7 @@ module ActiveRecord
relation = construct_relation_for_exists(relation, conditions)
skip_query_cache_if_necessary { connection.select_value(relation.arel, "#{name} Exists", relation.bound_attributes) } ? true : false
skip_query_cache_if_necessary { connection.select_value(relation.arel, "#{name} Exists") } ? true : false
rescue ::RangeError
false
end
@ -378,7 +378,7 @@ module ActiveRecord
if ActiveRecord::NullRelation === relation
[]
else
rows = skip_query_cache_if_necessary { connection.select_all(relation.arel, "SQL", relation.bound_attributes) }
rows = skip_query_cache_if_necessary { connection.select_all(relation.arel, "SQL") }
join_dependency.instantiate(rows, aliases)
end
end
@ -426,7 +426,7 @@ module ActiveRecord
relation = relation.except(:select).select(values).distinct!
id_rows = skip_query_cache_if_necessary { @klass.connection.select_all(relation.arel, "SQL", relation.bound_attributes) }
id_rows = skip_query_cache_if_necessary { @klass.connection.select_all(relation.arel, "SQL") }
id_rows.map { |row| row[primary_key] }
end

View File

@ -10,14 +10,6 @@ module ActiveRecord
@name = name
end
def binds
if value.is_a?(Relation)
value.bound_attributes
else
[]
end
end
def merge(other)
self
end

View File

@ -8,10 +8,9 @@ module ActiveRecord
@table = table
@handlers = []
register_handler(BasicObject, BasicObjectHandler.new)
register_handler(BasicObject, BasicObjectHandler.new(self))
register_handler(Base, BaseHandler.new(self))
register_handler(Range, RangeHandler.new)
register_handler(RangeHandler::RangeWithBinds, RangeHandler.new)
register_handler(Range, RangeHandler.new(self))
register_handler(Relation, RelationHandler.new)
register_handler(Array, ArrayHandler.new(self))
end
@ -21,11 +20,6 @@ module ActiveRecord
expand_from_hash(attributes)
end
def create_binds(attributes)
attributes = convert_dot_notation_to_hash(attributes)
create_binds_for_hash(attributes)
end
def self.references(attributes)
attributes.map do |key, value|
if value.is_a?(Hash)
@ -56,8 +50,11 @@ module ActiveRecord
handler_for(value).call(attribute, value)
end
# TODO Change this to private once we've dropped Ruby 2.2 support.
# Workaround for Ruby 2.2 "private attribute?" warning.
def build_bind_attribute(column_name, value)
attr = Relation::QueryAttribute.new(column_name.to_s, value, table.type(column_name))
Arel::Nodes::BindParam.new(attr)
end
protected
attr_reader :table
@ -68,29 +65,13 @@ module ActiveRecord
attributes.flat_map do |key, value|
if value.is_a?(Hash) && !table.has_column?(key)
associated_predicate_builder(key).expand_from_hash(value)
else
build(table.arel_attribute(key), value)
end
end
end
def create_binds_for_hash(attributes)
result = attributes.dup
binds = []
attributes.each do |column_name, value|
case
when value.is_a?(Hash) && !table.has_column?(column_name)
attrs, bvs = associated_predicate_builder(column_name).create_binds_for_hash(value)
result[column_name] = attrs
binds += bvs
when table.associated_with?(column_name)
elsif table.associated_with?(key)
# Find the foreign key when using queries such as:
# Post.where(author: author)
#
# For polymorphic relationships, find the foreign key and type:
# PriceEstimate.where(estimate_of: treasure)
associated_table = table.associated_table(column_name)
associated_table = table.associated_table(key)
if associated_table.polymorphic_association?
case value.is_a?(Array) ? value.first : value
when Base, Relation
@ -100,41 +81,15 @@ module ActiveRecord
end
klass ||= AssociationQueryValue
result[column_name] = klass.new(associated_table, value).queries.map do |query|
attrs, bvs = create_binds_for_hash(query)
binds.concat(bvs)
attrs
queries = klass.new(associated_table, value).queries.map do |query|
expand_from_hash(query).reduce(&:and)
end
when value.is_a?(Range) && !table.type(column_name).respond_to?(:subtype)
first = value.begin
last = value.end
unless first.respond_to?(:infinite?) && first.infinite?
binds << build_bind_attribute(column_name, first)
first = Arel::Nodes::BindParam.new
end
unless last.respond_to?(:infinite?) && last.infinite?
binds << build_bind_attribute(column_name, last)
last = Arel::Nodes::BindParam.new
end
result[column_name] = RangeHandler::RangeWithBinds.new(first, last, value.exclude_end?)
when value.is_a?(Relation)
binds.concat(value.bound_attributes)
queries.reduce(&:or)
else
if can_be_bound?(column_name, value)
bind_attribute = build_bind_attribute(column_name, value)
if value.is_a?(StatementCache::Substitute) || !bind_attribute.value_for_database.nil?
result[column_name] = Arel::Nodes::BindParam.new
binds << bind_attribute
else
result[column_name] = nil
build(table.arel_attribute(key), value)
end
end
end
end
[result, binds]
end
private
@ -161,19 +116,6 @@ module ActiveRecord
def handler_for(object)
@handlers.detect { |klass, _| klass === object }.last
end
def can_be_bound?(column_name, value)
case value
when Array, Range
table.type(column_name).respond_to?(:subtype)
else
!value.nil? && handler_for(value).is_a?(BasicObjectHandler)
end
end
def build_bind_attribute(column_name, value)
Relation::QueryAttribute.new(column_name.to_s, value, table.type(column_name))
end
end
end

View File

@ -19,7 +19,11 @@ module ActiveRecord
case values.length
when 0 then NullPredicate
when 1 then predicate_builder.build(attribute, values.first)
else attribute.in(values)
else
bind_values = values.map do |v|
predicate_builder.build_bind_attribute(attribute.name, v)
end
attribute.in(bind_values)
end
unless nils.empty?
@ -31,8 +35,6 @@ module ActiveRecord
array_predicates.inject(&:or)
end
# TODO Change this to private once we've dropped Ruby 2.2 support.
# Workaround for Ruby 2.2 "private attribute?" warning.
protected
attr_reader :predicate_builder

View File

@ -11,8 +11,6 @@ module ActiveRecord
predicate_builder.build(attribute, value.id)
end
# TODO Change this to private once we've dropped Ruby 2.2 support.
# Workaround for Ruby 2.2 "private attribute?" warning.
protected
attr_reader :predicate_builder

View File

@ -3,9 +3,18 @@
module ActiveRecord
class PredicateBuilder
class BasicObjectHandler # :nodoc:
def call(attribute, value)
attribute.eq(value)
def initialize(predicate_builder)
@predicate_builder = predicate_builder
end
def call(attribute, value)
bind = predicate_builder.build_bind_attribute(attribute.name, value)
attribute.eq(bind)
end
protected
attr_reader :predicate_builder
end
end
end

View File

@ -3,25 +3,39 @@
module ActiveRecord
class PredicateBuilder
class RangeHandler # :nodoc:
RangeWithBinds = Struct.new(:begin, :end, :exclude_end?)
class RangeWithBinds < Struct.new(:begin, :end)
def exclude_end?
false
end
end
def initialize(predicate_builder)
@predicate_builder = predicate_builder
end
def call(attribute, value)
begin_bind = predicate_builder.build_bind_attribute(attribute.name, value.begin)
end_bind = predicate_builder.build_bind_attribute(attribute.name, value.end)
if value.begin.respond_to?(:infinite?) && value.begin.infinite?
if value.end.respond_to?(:infinite?) && value.end.infinite?
attribute.not_in([])
elsif value.exclude_end?
attribute.lt(value.end)
attribute.lt(end_bind)
else
attribute.lteq(value.end)
attribute.lteq(end_bind)
end
elsif value.end.respond_to?(:infinite?) && value.end.infinite?
attribute.gteq(value.begin)
attribute.gteq(begin_bind)
elsif value.exclude_end?
attribute.gteq(value.begin).and(attribute.lt(value.end))
attribute.gteq(begin_bind).and(attribute.lt(end_bind))
else
attribute.between(value)
attribute.between(RangeWithBinds.new(begin_bind, end_bind))
end
end
protected
attr_reader :predicate_builder
end
end
end

View File

@ -16,6 +16,11 @@ module ActiveRecord
def with_cast_value(value)
QueryAttribute.new(name, value, type)
end
def nil?
!value_before_type_cast.is_a?(StatementCache::Substitute) &&
(value_before_type_cast.nil? || value_for_database.nil?)
end
end
end
end

View File

@ -76,31 +76,6 @@ module ActiveRecord
CODE
end
def bound_attributes
if limit_value
limit_bind = Attribute.with_cast_value(
"LIMIT".freeze,
connection.sanitize_limit(limit_value),
Type.default_value,
)
end
if offset_value
offset_bind = Attribute.with_cast_value(
"OFFSET".freeze,
offset_value.to_i,
Type.default_value,
)
end
connection.combine_bind_parameters(
from_clause: from_clause.binds,
join_clause: arel.bind_values,
where_clause: where_clause.binds,
having_clause: having_clause.binds,
limit: limit_bind,
offset: offset_bind,
)
end
alias extensions extending_values
# Specify relationships to be included in the result set. For
@ -952,8 +927,22 @@ module ActiveRecord
arel.where(where_clause.ast) unless where_clause.empty?
arel.having(having_clause.ast) unless having_clause.empty?
arel.take(Arel::Nodes::BindParam.new) if limit_value
arel.skip(Arel::Nodes::BindParam.new) if offset_value
if limit_value
limit_attribute = Attribute.with_cast_value(
"LIMIT".freeze,
connection.sanitize_limit(limit_value),
Type.default_value,
)
arel.take(Arel::Nodes::BindParam.new(limit_attribute))
end
if offset_value
offset_attribute = Attribute.with_cast_value(
"OFFSET".freeze,
offset_value.to_i,
Type.default_value,
)
arel.skip(Arel::Nodes::BindParam.new(offset_attribute))
end
arel.group(*arel_columns(group_values.uniq.reject(&:blank?))) unless group_values.empty?
build_order(arel)
@ -1029,7 +1018,6 @@ module ActiveRecord
join_infos.each do |info|
info.joins.each { |join| manager.from(join) }
manager.bind_values.concat info.binds
end
manager.join_sources.concat(join_list)

View File

@ -3,31 +3,26 @@
module ActiveRecord
class Relation
class WhereClause # :nodoc:
attr_reader :binds
delegate :any?, :empty?, to: :predicates
def initialize(predicates, binds)
def initialize(predicates)
@predicates = predicates
@binds = binds
end
def +(other)
WhereClause.new(
predicates + other.predicates,
binds + other.binds,
)
end
def merge(other)
WhereClause.new(
predicates_unreferenced_by(other) + other.predicates,
non_conflicting_binds(other) + other.binds,
)
end
def except(*columns)
WhereClause.new(*except_predicates_and_binds(columns))
WhereClause.new(except_predicates(columns))
end
def or(other)
@ -38,7 +33,6 @@ module ActiveRecord
else
WhereClause.new(
[ast.or(other.ast)],
binds + other.binds
)
end
end
@ -51,17 +45,10 @@ module ActiveRecord
end
end
binds = self.binds.map { |attr| [attr.name, attr.value] }.to_h
equalities.map { |node|
name = node.left.name.to_s
[name, binds.fetch(name) {
case node.right
when Array then node.right.map(&:val)
when Arel::Nodes::Casted, Arel::Nodes::Quoted
node.right.val
end
}]
value = extract_node_value(node.right)
[name, value]
}.to_h
end
@ -71,20 +58,17 @@ module ActiveRecord
def ==(other)
other.is_a?(WhereClause) &&
predicates == other.predicates &&
binds == other.binds
predicates == other.predicates
end
def invert
WhereClause.new(inverted_predicates, binds)
WhereClause.new(inverted_predicates)
end
def self.empty
@empty ||= new([], [])
@empty ||= new([])
end
# TODO Change this to private once we've dropped Ruby 2.2 support.
# Workaround for Ruby 2.2 "private attribute?" warning.
protected
attr_reader :predicates
@ -108,12 +92,6 @@ module ActiveRecord
node.respond_to?(:operator) && node.operator == :==
end
def non_conflicting_binds(other)
conflicts = referenced_columns & other.referenced_columns
conflicts.map! { |node| node.name.to_s }
binds.reject { |attr| conflicts.include?(attr.name) }
end
def inverted_predicates
predicates.map { |node| invert_predicate(node) }
end
@ -133,44 +111,22 @@ module ActiveRecord
end
end
def except_predicates_and_binds(columns)
except_binds = []
binds_index = 0
predicates = self.predicates.reject do |node|
binds_contains = node.grep(Arel::Nodes::BindParam).size if node.is_a?(Arel::Nodes::Node)
except = \
def except_predicates(columns)
self.predicates.reject do |node|
case node
when Arel::Nodes::Between, Arel::Nodes::In, Arel::Nodes::NotIn, Arel::Nodes::Equality, Arel::Nodes::NotEqual, Arel::Nodes::LessThan, Arel::Nodes::LessThanOrEqual, Arel::Nodes::GreaterThan, Arel::Nodes::GreaterThanOrEqual
subrelation = (node.left.kind_of?(Arel::Attributes::Attribute) ? node.left : node.right)
columns.include?(subrelation.name.to_s)
end
if except && binds_contains > 0
(binds_index...(binds_index + binds_contains)).each do |i|
except_binds[i] = true
end
end
binds_index += binds_contains if binds_contains
except
end
binds = self.binds.reject.with_index do |_, i|
except_binds[i]
end
[predicates, binds]
end
def predicates_with_wrapped_sql_literals
non_empty_predicates.map do |node|
if Arel::Nodes::Equality === node
node
else
case node
when Arel::Nodes::SqlLiteral, ::String
wrap_sql_literal(node)
else node
end
end
end
@ -186,6 +142,22 @@ module ActiveRecord
end
Arel::Nodes::Grouping.new(node)
end
def extract_node_value(node)
case node
when Array
node.map { |v| extract_node_value(v) }
when Arel::Nodes::Casted, Arel::Nodes::Quoted
node.val
when Arel::Nodes::BindParam
value = node.value
if value.respond_to?(:value_before_type_cast)
value.value_before_type_cast
else
value
end
end
end
end
end
end

View File

@ -17,63 +17,19 @@ module ActiveRecord
attributes = klass.send(:expand_hash_conditions_for_aggregates, attributes)
attributes.stringify_keys!
if perform_case_sensitive?(options = other.last)
parts, binds = build_for_case_sensitive(attributes, options)
else
attributes, binds = predicate_builder.create_binds(attributes)
parts = predicate_builder.build_from_hash(attributes)
end
when Arel::Nodes::Node
parts = [opts]
else
raise ArgumentError, "Unsupported argument type: #{opts} (#{opts.class})"
end
WhereClause.new(parts, binds || [])
WhereClause.new(parts)
end
# TODO Change this to private once we've dropped Ruby 2.2 support.
# Workaround for Ruby 2.2 "private attribute?" warning.
protected
attr_reader :klass, :predicate_builder
private
def perform_case_sensitive?(options)
options && options.key?(:case_sensitive)
end
def build_for_case_sensitive(attributes, options)
parts, binds = [], []
table = klass.arel_table
attributes.each do |attribute, value|
if reflection = klass._reflect_on_association(attribute)
attribute = reflection.foreign_key.to_s
value = value[reflection.klass.primary_key] unless value.nil?
end
if value.nil?
parts << table[attribute].eq(value)
else
column = klass.column_for_attribute(attribute)
binds << predicate_builder.send(:build_bind_attribute, attribute, value)
value = Arel::Nodes::BindParam.new
predicate = if options[:case_sensitive]
klass.connection.case_sensitive_comparison(table, attribute, column, value)
else
klass.connection.case_insensitive_comparison(table, attribute, column, value)
end
parts << predicate
end
end
[parts, binds]
end
end
end
end

View File

@ -41,18 +41,20 @@ module ActiveRecord
end
class PartialQuery < Query # :nodoc:
def initialize(values)
@values = values
@indexes = values.each_with_index.find_all { |thing, i|
Arel::Nodes::BindParam === thing
}.map(&:last)
def initialize(arel)
@arel = arel
end
def sql_for(binds, connection)
val = @values.dup
casted_binds = binds.map(&:value_for_database)
@indexes.each { |i| val[i] = connection.quote(casted_binds.shift) }
val.join
val = @arel.dup
val.grep(Arel::Nodes::BindParam) do |node|
node.value = binds.shift
if binds.empty?
break
end
end
sql, _ = connection.visitor.accept(val, connection.send(:collector)).value
sql
end
end
@ -91,8 +93,8 @@ module ActiveRecord
def self.create(connection, block = Proc.new)
relation = block.call Params.new
bind_map = BindMap.new relation.bound_attributes
query_builder = connection.cacheable_query(self, relation.arel)
query_builder, binds = connection.cacheable_query(self, relation.arel)
bind_map = BindMap.new(binds)
new query_builder, bind_map
end

View File

@ -52,7 +52,37 @@ module ActiveRecord
end
def build_relation(klass, attribute, value)
klass.unscoped.where!({ attribute => value }, options)
if reflection = klass._reflect_on_association(attribute)
attribute = reflection.foreign_key
value = value.attributes[reflection.klass.primary_key] unless value.nil?
end
if value.nil?
return klass.unscoped.where!(attribute => value)
end
# the attribute may be an aliased attribute
if klass.attribute_alias?(attribute)
attribute = klass.attribute_alias(attribute)
end
attribute_name = attribute.to_s
table = klass.arel_table
column = klass.columns_hash[attribute_name]
cast_type = klass.type_for_attribute(attribute_name)
value = Relation::QueryAttribute.new(attribute_name, value, cast_type)
comparison = if !options[:case_sensitive]
# will use SQL LOWER function before comparison, unless it detects a case insensitive collation
klass.connection.case_insensitive_comparison(table, attribute, column, value)
else
klass.connection.case_sensitive_comparison(table, attribute, column, value)
end
klass.unscoped.tap do |scope|
parts = [comparison]
scope.where_clause += Relation::WhereClause.new(parts)
end
end
def scope_relation(record, relation)

View File

@ -244,7 +244,7 @@ module ActiveRecord
def test_select_all_with_legacy_binds
post = Post.create!(title: "foo", body: "bar")
expected = @connection.select_all("SELECT * FROM posts WHERE id = #{post.id}")
result = @connection.select_all("SELECT * FROM posts WHERE id = #{Arel::Nodes::BindParam.new.to_sql}", nil, [[nil, post.id]])
result = @connection.select_all("SELECT * FROM posts WHERE id = #{Arel::Nodes::BindParam.new(nil).to_sql}", nil, [[nil, post.id]])
assert_equal expected.to_hash, result.to_hash
end
end
@ -253,7 +253,6 @@ module ActiveRecord
author = Author.create!(name: "john")
Post.create!(author: author, title: "foo", body: "bar")
query = author.posts.where(title: "foo").select(:title)
assert_equal({ "title" => "foo" }, @connection.select_one(query.arel, nil, query.bound_attributes))
assert_equal({ "title" => "foo" }, @connection.select_one(query))
assert @connection.select_all(query).is_a?(ActiveRecord::Result)
assert_equal "foo", @connection.select_value(query)
@ -263,7 +262,6 @@ module ActiveRecord
def test_select_methods_passing_a_relation
Post.create!(title: "foo", body: "bar")
query = Post.where(title: "foo").select(:title)
assert_equal({ "title" => "foo" }, @connection.select_one(query.arel, nil, query.bound_attributes))
assert_equal({ "title" => "foo" }, @connection.select_one(query))
assert @connection.select_all(query).is_a?(ActiveRecord::Result)
assert_equal "foo", @connection.select_value(query)

View File

@ -1,17 +0,0 @@
# frozen_string_literal: true
require "cases/helper"
require "models/post"
require "models/author"
module ActiveRecord
module Associations
class AssociationScopeTest < ActiveRecord::TestCase
test "does not duplicate conditions" do
scope = AssociationScope.scope(Author.new.association(:welcome_posts))
binds = scope.where_clause.binds.map(&:value)
assert_equal binds.uniq, binds
end
end
end
end

View File

@ -41,7 +41,7 @@ if ActiveRecord::Base.connection.prepared_statements
end
def test_binds_are_logged
sub = Arel::Nodes::BindParam.new
sub = Arel::Nodes::BindParam.new(1)
binds = [Relation::QueryAttribute.new("id", 1, Type::Value.new)]
sql = "select * from topics where id = #{sub.to_sql}"

View File

@ -420,7 +420,7 @@ class InheritanceTest < ActiveRecord::TestCase
def test_eager_load_belongs_to_primary_key_quoting
con = Account.connection
bind_param = Arel::Nodes::BindParam.new
bind_param = Arel::Nodes::BindParam.new(nil)
assert_sql(/#{con.quote_table_name('companies')}\.#{con.quote_column_name('id')} = (?:#{Regexp.quote(bind_param.to_sql)}|1)/) do
Account.all.merge!(includes: :firm).find(1)
end

View File

@ -84,10 +84,8 @@ class RelationMergingTest < ActiveRecord::TestCase
left = Post.where(title: "omg").where(comments_count: 1)
right = Post.where(title: "wtf").where(title: "bbq")
expected = [left.bound_attributes[1]] + right.bound_attributes
merged = left.merge(right)
assert_equal expected, merged.bound_attributes
assert_not_includes merged.to_sql, "omg"
assert_includes merged.to_sql, "wtf"
assert_includes merged.to_sql, "bbq"

View File

@ -27,7 +27,7 @@ module ActiveRecord
end
def test_association_not_eq
expected = Arel::Nodes::Grouping.new(Comment.arel_table[@name].not_eq(Arel::Nodes::BindParam.new))
expected = Comment.arel_table[@name].not_eq(Arel::Nodes::BindParam.new(1))
relation = Post.joins(:comments).where.not(comments: { title: "hello" })
assert_equal(expected.to_sql, relation.where_clause.ast.to_sql)
end

View File

@ -5,76 +5,82 @@ require "cases/helper"
class ActiveRecord::Relation
class WhereClauseTest < ActiveRecord::TestCase
test "+ combines two where clauses" do
first_clause = WhereClause.new([table["id"].eq(bind_param)], [["id", 1]])
second_clause = WhereClause.new([table["name"].eq(bind_param)], [["name", "Sean"]])
first_clause = WhereClause.new([table["id"].eq(bind_param(1))])
second_clause = WhereClause.new([table["name"].eq(bind_param("Sean"))])
combined = WhereClause.new(
[table["id"].eq(bind_param), table["name"].eq(bind_param)],
[["id", 1], ["name", "Sean"]],
[table["id"].eq(bind_param(1)), table["name"].eq(bind_param("Sean"))],
)
assert_equal combined, first_clause + second_clause
end
test "+ is associative, but not commutative" do
a = WhereClause.new(["a"], ["bind a"])
b = WhereClause.new(["b"], ["bind b"])
c = WhereClause.new(["c"], ["bind c"])
a = WhereClause.new(["a"])
b = WhereClause.new(["b"])
c = WhereClause.new(["c"])
assert_equal a + (b + c), (a + b) + c
assert_not_equal a + b, b + a
end
test "an empty where clause is the identity value for +" do
clause = WhereClause.new([table["id"].eq(bind_param)], [["id", 1]])
clause = WhereClause.new([table["id"].eq(bind_param(1))])
assert_equal clause, clause + WhereClause.empty
end
test "merge combines two where clauses" do
a = WhereClause.new([table["id"].eq(1)], [])
b = WhereClause.new([table["name"].eq("Sean")], [])
expected = WhereClause.new([table["id"].eq(1), table["name"].eq("Sean")], [])
a = WhereClause.new([table["id"].eq(1)])
b = WhereClause.new([table["name"].eq("Sean")])
expected = WhereClause.new([table["id"].eq(1), table["name"].eq("Sean")])
assert_equal expected, a.merge(b)
end
test "merge keeps the right side, when two equality clauses reference the same column" do
a = WhereClause.new([table["id"].eq(1), table["name"].eq("Sean")], [])
b = WhereClause.new([table["name"].eq("Jim")], [])
expected = WhereClause.new([table["id"].eq(1), table["name"].eq("Jim")], [])
a = WhereClause.new([table["id"].eq(1), table["name"].eq("Sean")])
b = WhereClause.new([table["name"].eq("Jim")])
expected = WhereClause.new([table["id"].eq(1), table["name"].eq("Jim")])
assert_equal expected, a.merge(b)
end
test "merge removes bind parameters matching overlapping equality clauses" do
a = WhereClause.new(
[table["id"].eq(bind_param), table["name"].eq(bind_param)],
[attribute("id", 1), attribute("name", "Sean")],
[table["id"].eq(bind_param(1)), table["name"].eq(bind_param("Sean"))],
)
b = WhereClause.new(
[table["name"].eq(bind_param)],
[attribute("name", "Jim")]
[table["name"].eq(bind_param("Jim"))],
)
expected = WhereClause.new(
[table["id"].eq(bind_param), table["name"].eq(bind_param)],
[attribute("id", 1), attribute("name", "Jim")],
[table["id"].eq(bind_param(1)), table["name"].eq(bind_param("Jim"))],
)
assert_equal expected, a.merge(b)
end
test "merge allows for columns with the same name from different tables" do
skip "This is not possible as of 4.2, and the binds do not yet contain sufficient information for this to happen"
# We might be able to change the implementation to remove conflicts by index, rather than column name
table2 = Arel::Table.new("table2")
a = WhereClause.new(
[table["id"].eq(bind_param(1)), table2["id"].eq(bind_param(2))],
)
b = WhereClause.new(
[table["id"].eq(bind_param(3))],
)
expected = WhereClause.new(
[table2["id"].eq(bind_param(2)), table["id"].eq(bind_param(3))],
)
assert_equal expected, a.merge(b)
end
test "a clause knows if it is empty" do
assert WhereClause.empty.empty?
assert_not WhereClause.new(["anything"], []).empty?
assert_not WhereClause.new(["anything"]).empty?
end
test "invert cannot handle nil" do
where_clause = WhereClause.new([nil], [])
where_clause = WhereClause.new([nil])
assert_raises ArgumentError do
where_clause.invert
@ -88,13 +94,13 @@ class ActiveRecord::Relation
table["id"].eq(1),
"sql literal",
random_object
], [])
])
expected = WhereClause.new([
table["id"].not_in([1, 2, 3]),
table["id"].not_eq(1),
Arel::Nodes::Not.new(Arel::Nodes::SqlLiteral.new("sql literal")),
Arel::Nodes::Not.new(random_object)
], [])
])
assert_equal expected, original.invert
end
@ -102,20 +108,17 @@ class ActiveRecord::Relation
test "except removes binary predicates referencing a given column" do
where_clause = WhereClause.new([
table["id"].in([1, 2, 3]),
table["name"].eq(bind_param),
table["age"].gteq(bind_param),
], [
attribute("name", "Sean"),
attribute("age", 30),
table["name"].eq(bind_param("Sean")),
table["age"].gteq(bind_param(30)),
])
expected = WhereClause.new([table["age"].gteq(bind_param)], [attribute("age", 30)])
expected = WhereClause.new([table["age"].gteq(bind_param(30))])
assert_equal expected, where_clause.except("id", "name")
end
test "except jumps over unhandled binds (like with OR) correctly" do
wcs = (0..9).map do |i|
WhereClause.new([table["id#{i}"].eq(bind_param)], [attribute("id#{i}", i)])
WhereClause.new([table["id#{i}"].eq(bind_param(i))])
end
wc = wcs[0] + wcs[1] + wcs[2].or(wcs[3]) + wcs[4] + wcs[5] + wcs[6].or(wcs[7]) + wcs[8] + wcs[9]
@ -123,18 +126,15 @@ class ActiveRecord::Relation
expected = wcs[0] + wcs[2].or(wcs[3]) + wcs[5] + wcs[6].or(wcs[7]) + wcs[9]
actual = wc.except("id1", "id2", "id4", "id7", "id8")
# Easier to read than the inspect of where_clause
assert_equal expected.ast.to_sql, actual.ast.to_sql
assert_equal expected.binds.map(&:value), actual.binds.map(&:value)
assert_equal expected, actual
end
test "ast groups its predicates with AND" do
predicates = [
table["id"].in([1, 2, 3]),
table["name"].eq(bind_param),
table["name"].eq(bind_param(nil)),
]
where_clause = WhereClause.new(predicates, [])
where_clause = WhereClause.new(predicates)
expected = Arel::Nodes::And.new(predicates)
assert_equal expected, where_clause.ast
@ -146,38 +146,36 @@ class ActiveRecord::Relation
table["id"].in([1, 2, 3]),
"foo = bar",
random_object,
], [])
])
expected = Arel::Nodes::And.new([
table["id"].in([1, 2, 3]),
Arel::Nodes::Grouping.new(Arel.sql("foo = bar")),
Arel::Nodes::Grouping.new(random_object),
random_object,
])
assert_equal expected, where_clause.ast
end
test "ast removes any empty strings" do
where_clause = WhereClause.new([table["id"].in([1, 2, 3])], [])
where_clause_with_empty = WhereClause.new([table["id"].in([1, 2, 3]), ""], [])
where_clause = WhereClause.new([table["id"].in([1, 2, 3])])
where_clause_with_empty = WhereClause.new([table["id"].in([1, 2, 3]), ""])
assert_equal where_clause.ast, where_clause_with_empty.ast
end
test "or joins the two clauses using OR" do
where_clause = WhereClause.new([table["id"].eq(bind_param)], [attribute("id", 1)])
other_clause = WhereClause.new([table["name"].eq(bind_param)], [attribute("name", "Sean")])
where_clause = WhereClause.new([table["id"].eq(bind_param(1))])
other_clause = WhereClause.new([table["name"].eq(bind_param("Sean"))])
expected_ast =
Arel::Nodes::Grouping.new(
Arel::Nodes::Or.new(table["id"].eq(bind_param), table["name"].eq(bind_param))
Arel::Nodes::Or.new(table["id"].eq(bind_param(1)), table["name"].eq(bind_param("Sean")))
)
expected_binds = where_clause.binds + other_clause.binds
assert_equal expected_ast.to_sql, where_clause.or(other_clause).ast.to_sql
assert_equal expected_binds, where_clause.or(other_clause).binds
end
test "or returns an empty where clause when either side is empty" do
where_clause = WhereClause.new([table["id"].eq(bind_param)], [attribute("id", 1)])
where_clause = WhereClause.new([table["id"].eq(bind_param(1))])
assert_equal WhereClause.empty, where_clause.or(WhereClause.empty)
assert_equal WhereClause.empty, WhereClause.empty.or(where_clause)
@ -189,12 +187,8 @@ class ActiveRecord::Relation
Arel::Table.new("table")
end
def bind_param
Arel::Nodes::BindParam.new
end
def attribute(name, value)
ActiveRecord::Attribute.with_cast_value(name, value, ActiveRecord::Type::Value.new)
def bind_param(value)
Arel::Nodes::BindParam.new(value)
end
end
end

View File

@ -185,7 +185,7 @@ module ActiveRecord
relation = Relation.new(klass, :b, nil)
relation.merge!(where: ["foo = ?", "bar"])
assert_equal Relation::WhereClause.new(["foo = bar"], []), relation.where_clause
assert_equal Relation::WhereClause.new(["foo = bar"]), relation.where_clause
end
def test_merging_readonly_false

View File

@ -1761,7 +1761,7 @@ class RelationTest < ActiveRecord::TestCase
test "relations with cached arel can't be mutated [internal API]" do
relation = Post.all
relation.count
relation.arel
assert_raises(ActiveRecord::ImmutableRelation) { relation.limit!(5) }
assert_raises(ActiveRecord::ImmutableRelation) { relation.where!("1 = 2") }
@ -1860,33 +1860,6 @@ class RelationTest < ActiveRecord::TestCase
assert_equal 1, posts.unscope(where: :body).count
end
def test_unscope_removes_binds
left = Post.where(id: 20)
assert_equal 1, left.bound_attributes.length
relation = left.unscope(where: :id)
assert_equal [], relation.bound_attributes
end
def test_merging_removes_rhs_binds
left = Post.where(id: 20)
right = Post.where(id: [1, 2, 3, 4])
assert_equal 1, left.bound_attributes.length
merged = left.merge(right)
assert_equal [], merged.bound_attributes
end
def test_merging_keeps_lhs_binds
right = Post.where(id: 20)
left = Post.where(id: 10)
merged = left.merge(right)
assert_equal [20], merged.bound_attributes.map(&:value)
end
def test_locked_should_not_build_arel
posts = Post.locked
assert posts.locked?
@ -1897,24 +1870,6 @@ class RelationTest < ActiveRecord::TestCase
assert_equal "Thank you for the welcome,Thank you again for the welcome", Post.first.comments.join(",")
end
def test_connection_adapters_can_reorder_binds
posts = Post.limit(1).offset(2)
stubbed_connection = Post.connection.dup
def stubbed_connection.combine_bind_parameters(**kwargs)
offset = kwargs[:offset]
kwargs[:offset] = kwargs[:limit]
kwargs[:limit] = offset
super(**kwargs)
end
posts.define_singleton_method(:connection) do
stubbed_connection
end
assert_equal 2, posts.to_a.length
end
test "#skip_query_cache!" do
Post.cache do
assert_queries(1) do