Refactor Active Record to let Arel manage bind params

A common source of bugs and code bloat within Active Record has been the
need for us to maintain the list of bind values separately from the AST
they're associated with. This makes any sort of AST manipulation
incredibly difficult, as any time we want to potentially insert or
remove an AST node, we need to traverse the entire tree to find where
the associated bind parameters are.

With this change, the bind parameters now live on the AST directly.
Active Record does not need to know or care about them until the final
AST traversal for SQL construction. Rather than returning just the SQL,
the Arel collector will now return both the SQL and the bind parameters.
At this point the connection adapter will have all the values that it
had before.

A bit of this code is janky and something I'd like to refactor later. In
particular, I don't like how we're handling associations in the
predicate builder, the special casing of `StatementCache::Substitute` in
`QueryAttribute`, or generally how we're handling bind value replacement
in the statement cache when prepared statements are disabled.

This also mostly reverts #26378, as it moved all the code into a
location that I wanted to delete.

/cc @metaskills @yahonda, this change will affect the adapters

Fixes #29766.
Fixes #29804.
Fixes #26541.
Close #28539.
Close #24769.
Close #26468.
Close #26202.

There are probably other issues/PRs that can be closed because of this
commit, but that's all I could find on the first few pages.
This commit is contained in:
Sean Griffin 2017-07-24 08:19:35 -04:00
parent 0449d8b6bc
commit 213796fb49
37 changed files with 284 additions and 460 deletions

View File

@ -31,7 +31,7 @@ GIT
GIT GIT
remote: https://github.com/rails/arel.git remote: https://github.com/rails/arel.git
revision: 67a51c62f4e19390cd8eb408596ca48bb0806362 revision: 7a29220c689feb0581e21d5324b85fc2f201ac5e
specs: specs:
arel (8.0.0) arel (8.0.0)

View File

@ -154,7 +154,7 @@ module ActiveRecord
stmt.from scope.klass.arel_table stmt.from scope.klass.arel_table
stmt.wheres = arel.constraints stmt.wheres = arel.constraints
count = scope.klass.connection.delete(stmt, "SQL", scope.bound_attributes) count = scope.klass.connection.delete(stmt, "SQL")
end end
when :nullify when :nullify
count = scope.update_all(source_reflection.foreign_key => nil) count = scope.update_all(source_reflection.foreign_key => nil)

View File

@ -23,11 +23,10 @@ module ActiveRecord
super && reflection == other.reflection super && reflection == other.reflection
end end
JoinInformation = Struct.new :joins, :binds JoinInformation = Struct.new :joins
def join_constraints(foreign_table, foreign_klass, join_type, tables, chain) def join_constraints(foreign_table, foreign_klass, join_type, tables, chain)
joins = [] joins = []
binds = []
tables = tables.reverse tables = tables.reverse
# The chain starts with the target table, but we want to end with it here (makes # The chain starts with the target table, but we want to end with it here (makes
@ -43,7 +42,6 @@ module ActiveRecord
join_scope = reflection.join_scope(table, foreign_klass) join_scope = reflection.join_scope(table, foreign_klass)
if join_scope.arel.constraints.any? if join_scope.arel.constraints.any?
binds.concat join_scope.bound_attributes
joins.concat join_scope.arel.join_sources joins.concat join_scope.arel.join_sources
right = joins.last.right right = joins.last.right
right.expr = right.expr.and(join_scope.arel.constraints) right.expr = right.expr.and(join_scope.arel.constraints)
@ -53,7 +51,7 @@ module ActiveRecord
foreign_table, foreign_klass = table, klass foreign_table, foreign_klass = table, klass
end end
JoinInformation.new joins, binds JoinInformation.new joins
end end
def table def table

View File

@ -29,7 +29,7 @@ module ActiveRecord
arel = query.arel arel = query.arel
end end
result = connection.select_one(arel, nil, query.bound_attributes) result = connection.select_one(arel, nil)
if result.blank? if result.blank?
size = 0 size = 0

View File

@ -9,30 +9,36 @@ module ActiveRecord
end end
# Converts an arel AST to SQL # Converts an arel AST to SQL
def to_sql(arel, binds = []) def to_sql(arel_or_sql_string, binds = [])
if arel.respond_to?(:ast) if arel_or_sql_string.respond_to?(:ast)
collected = visitor.accept(arel.ast, collector) unless binds.empty?
collected.compile(binds, self).freeze raise "Passing bind parameters with an arel AST is forbidden. " \
"The values must be stored on the AST directly"
end
sql, binds = visitor.accept(arel_or_sql_string.ast, collector).value
[sql.freeze, binds || []]
else else
arel.dup.freeze [arel_or_sql_string.dup.freeze, binds]
end end
end end
# This is used in the StatementCache object. It returns an object that # This is used in the StatementCache object. It returns an object that
# can be used to query the database repeatedly. # can be used to query the database repeatedly.
def cacheable_query(klass, arel) # :nodoc: def cacheable_query(klass, arel) # :nodoc:
collected = visitor.accept(arel.ast, collector)
if prepared_statements if prepared_statements
klass.query(collected.value) sql, binds = visitor.accept(arel.ast, collector).value
query = klass.query(sql)
else else
klass.partial_query(collected.value) query = klass.partial_query(arel.ast)
binds = []
end end
[query, binds]
end end
# Returns an ActiveRecord::Result instance. # Returns an ActiveRecord::Result instance.
def select_all(arel, name = nil, binds = [], preparable: nil) def select_all(arel, name = nil, binds = [], preparable: nil)
arel, binds = binds_from_relation arel, binds arel = arel_from_relation(arel)
sql = to_sql(arel, binds) sql, binds = to_sql(arel, binds)
if !prepared_statements || (arel.is_a?(String) && preparable.nil?) if !prepared_statements || (arel.is_a?(String) && preparable.nil?)
preparable = false preparable = false
else else
@ -131,20 +137,23 @@ module ActiveRecord
# #
# If the next id was calculated in advance (as in Oracle), it should be # If the next id was calculated in advance (as in Oracle), it should be
# passed in as +id_value+. # passed in as +id_value+.
def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = []) def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil)
value = exec_insert(to_sql(arel, binds), name, binds, pk, sequence_name) sql, binds = to_sql(arel)
value = exec_insert(sql, name, binds, pk, sequence_name)
id_value || last_inserted_id(value) id_value || last_inserted_id(value)
end end
alias create insert alias create insert
# Executes the update statement and returns the number of rows affected. # Executes the update statement and returns the number of rows affected.
def update(arel, name = nil, binds = []) def update(arel, name = nil)
exec_update(to_sql(arel, binds), name, binds) sql, binds = to_sql(arel)
exec_update(sql, name, binds)
end end
# Executes the delete statement and returns the number of rows affected. # Executes the delete statement and returns the number of rows affected.
def delete(arel, name = nil, binds = []) def delete(arel, name = nil)
exec_delete(to_sql(arel, binds), name, binds) sql, binds = to_sql(arel)
exec_delete(sql, name, binds)
end end
# Returns +true+ when the connection adapter supports prepared statement # Returns +true+ when the connection adapter supports prepared statement
@ -430,11 +439,12 @@ module ActiveRecord
row && row.first row && row.first
end end
def binds_from_relation(relation, binds) def arel_from_relation(relation)
if relation.is_a?(Relation) && binds.empty? if relation.is_a?(Relation)
relation, binds = relation.arel, relation.bound_attributes relation.arel
else
relation
end end
[relation, binds]
end end
# Fixture value is quoted by Arel, however scalar values # Fixture value is quoted by Arel, however scalar values

View File

@ -92,8 +92,8 @@ module ActiveRecord
def select_all(arel, name = nil, binds = [], preparable: nil) def select_all(arel, name = nil, binds = [], preparable: nil)
if @query_cache_enabled && !locked?(arel) if @query_cache_enabled && !locked?(arel)
arel, binds = binds_from_relation arel, binds arel = arel_from_relation(arel)
sql = to_sql(arel, binds) sql, binds = to_sql(arel, binds)
cache_sql(sql, name, binds) { super(sql, name, binds, preparable: preparable) } cache_sql(sql, name, binds) { super(sql, name, binds, preparable: preparable) }
else else
super super

View File

@ -24,6 +24,10 @@ module ActiveRecord
return value.quoted_id return value.quoted_id
end end
if value.respond_to?(:value_for_database)
value = value.value_for_database
end
_quote(value) _quote(value)
end end

View File

@ -7,7 +7,9 @@ require_relative "sql_type_metadata"
require_relative "abstract/schema_dumper" require_relative "abstract/schema_dumper"
require_relative "abstract/schema_creation" require_relative "abstract/schema_creation"
require "arel/collectors/bind" require "arel/collectors/bind"
require "arel/collectors/composite"
require "arel/collectors/sql_string" require "arel/collectors/sql_string"
require "arel/collectors/substitute_binds"
module ActiveRecord module ActiveRecord
module ConnectionAdapters # :nodoc: module ConnectionAdapters # :nodoc:
@ -129,19 +131,6 @@ module ActiveRecord
end end
end end
class BindCollector < Arel::Collectors::Bind
def compile(bvs, conn)
casted_binds = bvs.map(&:value_for_database)
super(casted_binds.map { |value| conn.quote(value) })
end
end
class SQLString < Arel::Collectors::SQLString
def compile(bvs, conn)
super(bvs)
end
end
def valid_type?(type) # :nodoc: def valid_type?(type) # :nodoc:
!native_database_types[type].nil? !native_database_types[type].nil?
end end
@ -432,14 +421,14 @@ module ActiveRecord
end end
def case_sensitive_comparison(table, attribute, column, value) # :nodoc: def case_sensitive_comparison(table, attribute, column, value) # :nodoc:
table[attribute].eq(value) table[attribute].eq(Arel::Nodes::BindParam.new(value))
end end
def case_insensitive_comparison(table, attribute, column, value) # :nodoc: def case_insensitive_comparison(table, attribute, column, value) # :nodoc:
if can_perform_case_insensitive_comparison_for?(column) if can_perform_case_insensitive_comparison_for?(column)
table[attribute].lower.eq(table.lower(value)) table[attribute].lower.eq(table.lower(Arel::Nodes::BindParam.new(value)))
else else
table[attribute].eq(value) table[attribute].eq(Arel::Nodes::BindParam.new(value))
end end
end end
@ -457,24 +446,6 @@ module ActiveRecord
visitor.accept(node, collector).value visitor.accept(node, collector).value
end end
def combine_bind_parameters(
from_clause: [],
join_clause: [],
where_clause: [],
having_clause: [],
limit: nil,
offset: nil
) # :nodoc:
result = from_clause + join_clause + where_clause + having_clause
if limit
result << limit
end
if offset
result << offset
end
result
end
def default_index_type?(index) # :nodoc: def default_index_type?(index) # :nodoc:
index.using.nil? index.using.nil?
end end
@ -609,9 +580,15 @@ module ActiveRecord
def collector def collector
if prepared_statements if prepared_statements
SQLString.new Arel::Collectors::Composite.new(
Arel::Collectors::SQLString.new,
Arel::Collectors::Bind.new,
)
else else
BindCollector.new Arel::Collectors::SubstituteBinds.new(
self,
Arel::Collectors::SQLString.new,
)
end end
end end

View File

@ -177,7 +177,8 @@ module ActiveRecord
#++ #++
def explain(arel, binds = []) def explain(arel, binds = [])
sql = "EXPLAIN #{to_sql(arel, binds)}" sql, binds = to_sql(arel, binds)
sql = "EXPLAIN #{sql}"
start = Time.now start = Time.now
result = exec_query(sql, "EXPLAIN", binds) result = exec_query(sql, "EXPLAIN", binds)
elapsed = Time.now - start elapsed = Time.now - start

View File

@ -5,7 +5,7 @@ module ActiveRecord
module MySQL module MySQL
module DatabaseStatements module DatabaseStatements
# Returns an ActiveRecord::Result instance. # Returns an ActiveRecord::Result instance.
def select_all(arel, name = nil, binds = [], preparable: nil) # :nodoc: def select_all(*) # :nodoc:
result = if ExplainRegistry.collect? && prepared_statements result = if ExplainRegistry.collect? && prepared_statements
unprepared_statement { super } unprepared_statement { super }
else else

View File

@ -5,7 +5,8 @@ module ActiveRecord
module PostgreSQL module PostgreSQL
module DatabaseStatements module DatabaseStatements
def explain(arel, binds = []) def explain(arel, binds = [])
sql = "EXPLAIN #{to_sql(arel, binds)}" sql, binds = to_sql(arel, binds)
sql = "EXPLAIN #{sql}"
PostgreSQL::ExplainPrettyPrinter.new.pp(exec_query(sql, "EXPLAIN", binds)) PostgreSQL::ExplainPrettyPrinter.new.pp(exec_query(sql, "EXPLAIN", binds))
end end

View File

@ -203,7 +203,8 @@ module ActiveRecord
#++ #++
def explain(arel, binds = []) def explain(arel, binds = [])
sql = "EXPLAIN QUERY PLAN #{to_sql(arel, binds)}" sql, binds = to_sql(arel, binds)
sql = "EXPLAIN QUERY PLAN #{sql}"
SQLite3::ExplainPrettyPrinter.new.pp(exec_query(sql, "EXPLAIN", [])) SQLite3::ExplainPrettyPrinter.new.pp(exec_query(sql, "EXPLAIN", []))
end end

View File

@ -53,7 +53,7 @@ module ActiveRecord
im = arel.create_insert im = arel.create_insert
im.into @table im.into @table
substitutes, binds = substitute_values values substitutes = substitute_values values
if values.empty? # empty insert if values.empty? # empty insert
im.values = Arel.sql(connection.empty_insert_statement_value) im.values = Arel.sql(connection.empty_insert_statement_value)
@ -67,11 +67,11 @@ module ActiveRecord
primary_key || false, primary_key || false,
primary_key_value, primary_key_value,
nil, nil,
binds) )
end end
def _update_record(values, id, id_was) # :nodoc: def _update_record(values, id, id_was) # :nodoc:
substitutes, binds = substitute_values values substitutes = substitute_values values
scope = @klass.unscoped scope = @klass.unscoped
@ -80,7 +80,6 @@ module ActiveRecord
end end
relation = scope.where(@klass.primary_key => (id_was || id)) relation = scope.where(@klass.primary_key => (id_was || id))
bvs = binds + relation.bound_attributes
um = relation um = relation
.arel .arel
.compile_update(substitutes, @klass.primary_key) .compile_update(substitutes, @klass.primary_key)
@ -88,20 +87,14 @@ module ActiveRecord
@klass.connection.update( @klass.connection.update(
um, um,
"SQL", "SQL",
bvs,
) )
end end
def substitute_values(values) # :nodoc: def substitute_values(values) # :nodoc:
binds = [] values.map do |arel_attr, value|
substitutes = [] bind = QueryAttribute.new(arel_attr.name, value, klass.type_for_attribute(arel_attr.name))
[arel_attr, Arel::Nodes::BindParam.new(bind)]
values.each do |arel_attr, value|
binds.push QueryAttribute.new(arel_attr.name, value, klass.type_for_attribute(arel_attr.name))
substitutes.push [arel_attr, Arel::Nodes::BindParam.new]
end end
[substitutes, binds]
end end
def arel_attribute(name) # :nodoc: def arel_attribute(name) # :nodoc:
@ -380,7 +373,7 @@ module ActiveRecord
stmt.wheres = arel.constraints stmt.wheres = arel.constraints
end end
@klass.connection.update stmt, "SQL", bound_attributes @klass.connection.update stmt, "SQL"
end end
# Updates an object (or multiple objects) and saves it to the database, if validations pass. # Updates an object (or multiple objects) and saves it to the database, if validations pass.
@ -510,7 +503,7 @@ module ActiveRecord
stmt.wheres = arel.constraints stmt.wheres = arel.constraints
end end
affected = @klass.connection.delete(stmt, "SQL", bound_attributes) affected = @klass.connection.delete(stmt, "SQL")
reset reset
affected affected
@ -578,7 +571,8 @@ module ActiveRecord
conn = klass.connection conn = klass.connection
conn.unprepared_statement { conn.unprepared_statement {
conn.to_sql(relation.arel, relation.bound_attributes) sql, _ = conn.to_sql(relation.arel)
sql
} }
end end
end end
@ -663,7 +657,7 @@ module ActiveRecord
def exec_queries(&block) def exec_queries(&block)
skip_query_cache_if_necessary do skip_query_cache_if_necessary do
@records = eager_loading? ? find_with_associations.freeze : @klass.find_by_sql(arel, bound_attributes, &block).freeze @records = eager_loading? ? find_with_associations.freeze : @klass.find_by_sql(arel, &block).freeze
preload = preload_values preload = preload_values
preload += includes_values unless eager_loading? preload += includes_values unless eager_loading?

View File

@ -186,7 +186,7 @@ module ActiveRecord
relation.select_values = column_names.map { |cn| relation.select_values = column_names.map { |cn|
@klass.has_attribute?(cn) || @klass.attribute_alias?(cn) ? arel_attribute(cn) : cn @klass.has_attribute?(cn) || @klass.attribute_alias?(cn) ? arel_attribute(cn) : cn
} }
result = skip_query_cache_if_necessary { klass.connection.select_all(relation.arel, nil, bound_attributes) } result = skip_query_cache_if_necessary { klass.connection.select_all(relation.arel, nil) }
result.cast_values(klass.attribute_types) result.cast_values(klass.attribute_types)
end end
end end
@ -262,7 +262,7 @@ module ActiveRecord
query_builder = relation.arel query_builder = relation.arel
end end
result = skip_query_cache_if_necessary { @klass.connection.select_all(query_builder, nil, bound_attributes) } result = skip_query_cache_if_necessary { @klass.connection.select_all(query_builder, nil) }
row = result.first row = result.first
value = row && row.values.first value = row && row.values.first
type = result.column_types.fetch(column_alias) do type = result.column_types.fetch(column_alias) do
@ -313,7 +313,7 @@ module ActiveRecord
relation.group_values = group_fields relation.group_values = group_fields
relation.select_values = select_values relation.select_values = select_values
calculated_data = skip_query_cache_if_necessary { @klass.connection.select_all(relation.arel, nil, relation.bound_attributes) } calculated_data = skip_query_cache_if_necessary { @klass.connection.select_all(relation.arel, nil) }
if association if association
key_ids = calculated_data.collect { |row| row[group_aliases.first] } key_ids = calculated_data.collect { |row| row[group_aliases.first] }

View File

@ -317,7 +317,7 @@ module ActiveRecord
relation = construct_relation_for_exists(relation, conditions) relation = construct_relation_for_exists(relation, conditions)
skip_query_cache_if_necessary { connection.select_value(relation.arel, "#{name} Exists", relation.bound_attributes) } ? true : false skip_query_cache_if_necessary { connection.select_value(relation.arel, "#{name} Exists") } ? true : false
rescue ::RangeError rescue ::RangeError
false false
end end
@ -378,7 +378,7 @@ module ActiveRecord
if ActiveRecord::NullRelation === relation if ActiveRecord::NullRelation === relation
[] []
else else
rows = skip_query_cache_if_necessary { connection.select_all(relation.arel, "SQL", relation.bound_attributes) } rows = skip_query_cache_if_necessary { connection.select_all(relation.arel, "SQL") }
join_dependency.instantiate(rows, aliases) join_dependency.instantiate(rows, aliases)
end end
end end
@ -426,7 +426,7 @@ module ActiveRecord
relation = relation.except(:select).select(values).distinct! relation = relation.except(:select).select(values).distinct!
id_rows = skip_query_cache_if_necessary { @klass.connection.select_all(relation.arel, "SQL", relation.bound_attributes) } id_rows = skip_query_cache_if_necessary { @klass.connection.select_all(relation.arel, "SQL") }
id_rows.map { |row| row[primary_key] } id_rows.map { |row| row[primary_key] }
end end

View File

@ -10,14 +10,6 @@ module ActiveRecord
@name = name @name = name
end end
def binds
if value.is_a?(Relation)
value.bound_attributes
else
[]
end
end
def merge(other) def merge(other)
self self
end end

View File

@ -8,10 +8,9 @@ module ActiveRecord
@table = table @table = table
@handlers = [] @handlers = []
register_handler(BasicObject, BasicObjectHandler.new) register_handler(BasicObject, BasicObjectHandler.new(self))
register_handler(Base, BaseHandler.new(self)) register_handler(Base, BaseHandler.new(self))
register_handler(Range, RangeHandler.new) register_handler(Range, RangeHandler.new(self))
register_handler(RangeHandler::RangeWithBinds, RangeHandler.new)
register_handler(Relation, RelationHandler.new) register_handler(Relation, RelationHandler.new)
register_handler(Array, ArrayHandler.new(self)) register_handler(Array, ArrayHandler.new(self))
end end
@ -21,11 +20,6 @@ module ActiveRecord
expand_from_hash(attributes) expand_from_hash(attributes)
end end
def create_binds(attributes)
attributes = convert_dot_notation_to_hash(attributes)
create_binds_for_hash(attributes)
end
def self.references(attributes) def self.references(attributes)
attributes.map do |key, value| attributes.map do |key, value|
if value.is_a?(Hash) if value.is_a?(Hash)
@ -56,8 +50,11 @@ module ActiveRecord
handler_for(value).call(attribute, value) handler_for(value).call(attribute, value)
end end
# TODO Change this to private once we've dropped Ruby 2.2 support. def build_bind_attribute(column_name, value)
# Workaround for Ruby 2.2 "private attribute?" warning. attr = Relation::QueryAttribute.new(column_name.to_s, value, table.type(column_name))
Arel::Nodes::BindParam.new(attr)
end
protected protected
attr_reader :table attr_reader :table
@ -68,29 +65,13 @@ module ActiveRecord
attributes.flat_map do |key, value| attributes.flat_map do |key, value|
if value.is_a?(Hash) && !table.has_column?(key) if value.is_a?(Hash) && !table.has_column?(key)
associated_predicate_builder(key).expand_from_hash(value) associated_predicate_builder(key).expand_from_hash(value)
else elsif table.associated_with?(key)
build(table.arel_attribute(key), value)
end
end
end
def create_binds_for_hash(attributes)
result = attributes.dup
binds = []
attributes.each do |column_name, value|
case
when value.is_a?(Hash) && !table.has_column?(column_name)
attrs, bvs = associated_predicate_builder(column_name).create_binds_for_hash(value)
result[column_name] = attrs
binds += bvs
when table.associated_with?(column_name)
# Find the foreign key when using queries such as: # Find the foreign key when using queries such as:
# Post.where(author: author) # Post.where(author: author)
# #
# For polymorphic relationships, find the foreign key and type: # For polymorphic relationships, find the foreign key and type:
# PriceEstimate.where(estimate_of: treasure) # PriceEstimate.where(estimate_of: treasure)
associated_table = table.associated_table(column_name) associated_table = table.associated_table(key)
if associated_table.polymorphic_association? if associated_table.polymorphic_association?
case value.is_a?(Array) ? value.first : value case value.is_a?(Array) ? value.first : value
when Base, Relation when Base, Relation
@ -100,40 +81,14 @@ module ActiveRecord
end end
klass ||= AssociationQueryValue klass ||= AssociationQueryValue
result[column_name] = klass.new(associated_table, value).queries.map do |query| queries = klass.new(associated_table, value).queries.map do |query|
attrs, bvs = create_binds_for_hash(query) expand_from_hash(query).reduce(&:and)
binds.concat(bvs)
attrs
end end
when value.is_a?(Range) && !table.type(column_name).respond_to?(:subtype) queries.reduce(&:or)
first = value.begin
last = value.end
unless first.respond_to?(:infinite?) && first.infinite?
binds << build_bind_attribute(column_name, first)
first = Arel::Nodes::BindParam.new
end
unless last.respond_to?(:infinite?) && last.infinite?
binds << build_bind_attribute(column_name, last)
last = Arel::Nodes::BindParam.new
end
result[column_name] = RangeHandler::RangeWithBinds.new(first, last, value.exclude_end?)
when value.is_a?(Relation)
binds.concat(value.bound_attributes)
else else
if can_be_bound?(column_name, value) build(table.arel_attribute(key), value)
bind_attribute = build_bind_attribute(column_name, value)
if value.is_a?(StatementCache::Substitute) || !bind_attribute.value_for_database.nil?
result[column_name] = Arel::Nodes::BindParam.new
binds << bind_attribute
else
result[column_name] = nil
end
end
end end
end end
[result, binds]
end end
private private
@ -161,19 +116,6 @@ module ActiveRecord
def handler_for(object) def handler_for(object)
@handlers.detect { |klass, _| klass === object }.last @handlers.detect { |klass, _| klass === object }.last
end end
def can_be_bound?(column_name, value)
case value
when Array, Range
table.type(column_name).respond_to?(:subtype)
else
!value.nil? && handler_for(value).is_a?(BasicObjectHandler)
end
end
def build_bind_attribute(column_name, value)
Relation::QueryAttribute.new(column_name.to_s, value, table.type(column_name))
end
end end
end end

View File

@ -19,7 +19,11 @@ module ActiveRecord
case values.length case values.length
when 0 then NullPredicate when 0 then NullPredicate
when 1 then predicate_builder.build(attribute, values.first) when 1 then predicate_builder.build(attribute, values.first)
else attribute.in(values) else
bind_values = values.map do |v|
predicate_builder.build_bind_attribute(attribute.name, v)
end
attribute.in(bind_values)
end end
unless nils.empty? unless nils.empty?
@ -31,8 +35,6 @@ module ActiveRecord
array_predicates.inject(&:or) array_predicates.inject(&:or)
end end
# TODO Change this to private once we've dropped Ruby 2.2 support.
# Workaround for Ruby 2.2 "private attribute?" warning.
protected protected
attr_reader :predicate_builder attr_reader :predicate_builder

View File

@ -11,8 +11,6 @@ module ActiveRecord
predicate_builder.build(attribute, value.id) predicate_builder.build(attribute, value.id)
end end
# TODO Change this to private once we've dropped Ruby 2.2 support.
# Workaround for Ruby 2.2 "private attribute?" warning.
protected protected
attr_reader :predicate_builder attr_reader :predicate_builder

View File

@ -3,9 +3,18 @@
module ActiveRecord module ActiveRecord
class PredicateBuilder class PredicateBuilder
class BasicObjectHandler # :nodoc: class BasicObjectHandler # :nodoc:
def call(attribute, value) def initialize(predicate_builder)
attribute.eq(value) @predicate_builder = predicate_builder
end end
def call(attribute, value)
bind = predicate_builder.build_bind_attribute(attribute.name, value)
attribute.eq(bind)
end
protected
attr_reader :predicate_builder
end end
end end
end end

View File

@ -3,25 +3,39 @@
module ActiveRecord module ActiveRecord
class PredicateBuilder class PredicateBuilder
class RangeHandler # :nodoc: class RangeHandler # :nodoc:
RangeWithBinds = Struct.new(:begin, :end, :exclude_end?) class RangeWithBinds < Struct.new(:begin, :end)
def exclude_end?
false
end
end
def initialize(predicate_builder)
@predicate_builder = predicate_builder
end
def call(attribute, value) def call(attribute, value)
begin_bind = predicate_builder.build_bind_attribute(attribute.name, value.begin)
end_bind = predicate_builder.build_bind_attribute(attribute.name, value.end)
if value.begin.respond_to?(:infinite?) && value.begin.infinite? if value.begin.respond_to?(:infinite?) && value.begin.infinite?
if value.end.respond_to?(:infinite?) && value.end.infinite? if value.end.respond_to?(:infinite?) && value.end.infinite?
attribute.not_in([]) attribute.not_in([])
elsif value.exclude_end? elsif value.exclude_end?
attribute.lt(value.end) attribute.lt(end_bind)
else else
attribute.lteq(value.end) attribute.lteq(end_bind)
end end
elsif value.end.respond_to?(:infinite?) && value.end.infinite? elsif value.end.respond_to?(:infinite?) && value.end.infinite?
attribute.gteq(value.begin) attribute.gteq(begin_bind)
elsif value.exclude_end? elsif value.exclude_end?
attribute.gteq(value.begin).and(attribute.lt(value.end)) attribute.gteq(begin_bind).and(attribute.lt(end_bind))
else else
attribute.between(value) attribute.between(RangeWithBinds.new(begin_bind, end_bind))
end end
end end
protected
attr_reader :predicate_builder
end end
end end
end end

View File

@ -16,6 +16,11 @@ module ActiveRecord
def with_cast_value(value) def with_cast_value(value)
QueryAttribute.new(name, value, type) QueryAttribute.new(name, value, type)
end end
def nil?
!value_before_type_cast.is_a?(StatementCache::Substitute) &&
(value_before_type_cast.nil? || value_for_database.nil?)
end
end end
end end
end end

View File

@ -76,31 +76,6 @@ module ActiveRecord
CODE CODE
end end
def bound_attributes
if limit_value
limit_bind = Attribute.with_cast_value(
"LIMIT".freeze,
connection.sanitize_limit(limit_value),
Type.default_value,
)
end
if offset_value
offset_bind = Attribute.with_cast_value(
"OFFSET".freeze,
offset_value.to_i,
Type.default_value,
)
end
connection.combine_bind_parameters(
from_clause: from_clause.binds,
join_clause: arel.bind_values,
where_clause: where_clause.binds,
having_clause: having_clause.binds,
limit: limit_bind,
offset: offset_bind,
)
end
alias extensions extending_values alias extensions extending_values
# Specify relationships to be included in the result set. For # Specify relationships to be included in the result set. For
@ -952,8 +927,22 @@ module ActiveRecord
arel.where(where_clause.ast) unless where_clause.empty? arel.where(where_clause.ast) unless where_clause.empty?
arel.having(having_clause.ast) unless having_clause.empty? arel.having(having_clause.ast) unless having_clause.empty?
arel.take(Arel::Nodes::BindParam.new) if limit_value if limit_value
arel.skip(Arel::Nodes::BindParam.new) if offset_value limit_attribute = Attribute.with_cast_value(
"LIMIT".freeze,
connection.sanitize_limit(limit_value),
Type.default_value,
)
arel.take(Arel::Nodes::BindParam.new(limit_attribute))
end
if offset_value
offset_attribute = Attribute.with_cast_value(
"OFFSET".freeze,
offset_value.to_i,
Type.default_value,
)
arel.skip(Arel::Nodes::BindParam.new(offset_attribute))
end
arel.group(*arel_columns(group_values.uniq.reject(&:blank?))) unless group_values.empty? arel.group(*arel_columns(group_values.uniq.reject(&:blank?))) unless group_values.empty?
build_order(arel) build_order(arel)
@ -1029,7 +1018,6 @@ module ActiveRecord
join_infos.each do |info| join_infos.each do |info|
info.joins.each { |join| manager.from(join) } info.joins.each { |join| manager.from(join) }
manager.bind_values.concat info.binds
end end
manager.join_sources.concat(join_list) manager.join_sources.concat(join_list)

View File

@ -3,31 +3,26 @@
module ActiveRecord module ActiveRecord
class Relation class Relation
class WhereClause # :nodoc: class WhereClause # :nodoc:
attr_reader :binds
delegate :any?, :empty?, to: :predicates delegate :any?, :empty?, to: :predicates
def initialize(predicates, binds) def initialize(predicates)
@predicates = predicates @predicates = predicates
@binds = binds
end end
def +(other) def +(other)
WhereClause.new( WhereClause.new(
predicates + other.predicates, predicates + other.predicates,
binds + other.binds,
) )
end end
def merge(other) def merge(other)
WhereClause.new( WhereClause.new(
predicates_unreferenced_by(other) + other.predicates, predicates_unreferenced_by(other) + other.predicates,
non_conflicting_binds(other) + other.binds,
) )
end end
def except(*columns) def except(*columns)
WhereClause.new(*except_predicates_and_binds(columns)) WhereClause.new(except_predicates(columns))
end end
def or(other) def or(other)
@ -38,7 +33,6 @@ module ActiveRecord
else else
WhereClause.new( WhereClause.new(
[ast.or(other.ast)], [ast.or(other.ast)],
binds + other.binds
) )
end end
end end
@ -51,17 +45,10 @@ module ActiveRecord
end end
end end
binds = self.binds.map { |attr| [attr.name, attr.value] }.to_h
equalities.map { |node| equalities.map { |node|
name = node.left.name.to_s name = node.left.name.to_s
[name, binds.fetch(name) { value = extract_node_value(node.right)
case node.right [name, value]
when Array then node.right.map(&:val)
when Arel::Nodes::Casted, Arel::Nodes::Quoted
node.right.val
end
}]
}.to_h }.to_h
end end
@ -71,20 +58,17 @@ module ActiveRecord
def ==(other) def ==(other)
other.is_a?(WhereClause) && other.is_a?(WhereClause) &&
predicates == other.predicates && predicates == other.predicates
binds == other.binds
end end
def invert def invert
WhereClause.new(inverted_predicates, binds) WhereClause.new(inverted_predicates)
end end
def self.empty def self.empty
@empty ||= new([], []) @empty ||= new([])
end end
# TODO Change this to private once we've dropped Ruby 2.2 support.
# Workaround for Ruby 2.2 "private attribute?" warning.
protected protected
attr_reader :predicates attr_reader :predicates
@ -108,12 +92,6 @@ module ActiveRecord
node.respond_to?(:operator) && node.operator == :== node.respond_to?(:operator) && node.operator == :==
end end
def non_conflicting_binds(other)
conflicts = referenced_columns & other.referenced_columns
conflicts.map! { |node| node.name.to_s }
binds.reject { |attr| conflicts.include?(attr.name) }
end
def inverted_predicates def inverted_predicates
predicates.map { |node| invert_predicate(node) } predicates.map { |node| invert_predicate(node) }
end end
@ -133,44 +111,22 @@ module ActiveRecord
end end
end end
def except_predicates_and_binds(columns) def except_predicates(columns)
except_binds = [] self.predicates.reject do |node|
binds_index = 0 case node
when Arel::Nodes::Between, Arel::Nodes::In, Arel::Nodes::NotIn, Arel::Nodes::Equality, Arel::Nodes::NotEqual, Arel::Nodes::LessThan, Arel::Nodes::LessThanOrEqual, Arel::Nodes::GreaterThan, Arel::Nodes::GreaterThanOrEqual
predicates = self.predicates.reject do |node| subrelation = (node.left.kind_of?(Arel::Attributes::Attribute) ? node.left : node.right)
binds_contains = node.grep(Arel::Nodes::BindParam).size if node.is_a?(Arel::Nodes::Node) columns.include?(subrelation.name.to_s)
except = \
case node
when Arel::Nodes::Between, Arel::Nodes::In, Arel::Nodes::NotIn, Arel::Nodes::Equality, Arel::Nodes::NotEqual, Arel::Nodes::LessThan, Arel::Nodes::LessThanOrEqual, Arel::Nodes::GreaterThan, Arel::Nodes::GreaterThanOrEqual
subrelation = (node.left.kind_of?(Arel::Attributes::Attribute) ? node.left : node.right)
columns.include?(subrelation.name.to_s)
end
if except && binds_contains > 0
(binds_index...(binds_index + binds_contains)).each do |i|
except_binds[i] = true
end
end end
binds_index += binds_contains if binds_contains
except
end end
binds = self.binds.reject.with_index do |_, i|
except_binds[i]
end
[predicates, binds]
end end
def predicates_with_wrapped_sql_literals def predicates_with_wrapped_sql_literals
non_empty_predicates.map do |node| non_empty_predicates.map do |node|
if Arel::Nodes::Equality === node case node
node when Arel::Nodes::SqlLiteral, ::String
else
wrap_sql_literal(node) wrap_sql_literal(node)
else node
end end
end end
end end
@ -186,6 +142,22 @@ module ActiveRecord
end end
Arel::Nodes::Grouping.new(node) Arel::Nodes::Grouping.new(node)
end end
def extract_node_value(node)
case node
when Array
node.map { |v| extract_node_value(v) }
when Arel::Nodes::Casted, Arel::Nodes::Quoted
node.val
when Arel::Nodes::BindParam
value = node.value
if value.respond_to?(:value_before_type_cast)
value.value_before_type_cast
else
value
end
end
end
end end
end end
end end

View File

@ -17,63 +17,19 @@ module ActiveRecord
attributes = klass.send(:expand_hash_conditions_for_aggregates, attributes) attributes = klass.send(:expand_hash_conditions_for_aggregates, attributes)
attributes.stringify_keys! attributes.stringify_keys!
if perform_case_sensitive?(options = other.last) parts = predicate_builder.build_from_hash(attributes)
parts, binds = build_for_case_sensitive(attributes, options)
else
attributes, binds = predicate_builder.create_binds(attributes)
parts = predicate_builder.build_from_hash(attributes)
end
when Arel::Nodes::Node when Arel::Nodes::Node
parts = [opts] parts = [opts]
else else
raise ArgumentError, "Unsupported argument type: #{opts} (#{opts.class})" raise ArgumentError, "Unsupported argument type: #{opts} (#{opts.class})"
end end
WhereClause.new(parts, binds || []) WhereClause.new(parts)
end end
# TODO Change this to private once we've dropped Ruby 2.2 support.
# Workaround for Ruby 2.2 "private attribute?" warning.
protected protected
attr_reader :klass, :predicate_builder attr_reader :klass, :predicate_builder
private
def perform_case_sensitive?(options)
options && options.key?(:case_sensitive)
end
def build_for_case_sensitive(attributes, options)
parts, binds = [], []
table = klass.arel_table
attributes.each do |attribute, value|
if reflection = klass._reflect_on_association(attribute)
attribute = reflection.foreign_key.to_s
value = value[reflection.klass.primary_key] unless value.nil?
end
if value.nil?
parts << table[attribute].eq(value)
else
column = klass.column_for_attribute(attribute)
binds << predicate_builder.send(:build_bind_attribute, attribute, value)
value = Arel::Nodes::BindParam.new
predicate = if options[:case_sensitive]
klass.connection.case_sensitive_comparison(table, attribute, column, value)
else
klass.connection.case_insensitive_comparison(table, attribute, column, value)
end
parts << predicate
end
end
[parts, binds]
end
end end
end end
end end

View File

@ -41,18 +41,20 @@ module ActiveRecord
end end
class PartialQuery < Query # :nodoc: class PartialQuery < Query # :nodoc:
def initialize(values) def initialize(arel)
@values = values @arel = arel
@indexes = values.each_with_index.find_all { |thing, i|
Arel::Nodes::BindParam === thing
}.map(&:last)
end end
def sql_for(binds, connection) def sql_for(binds, connection)
val = @values.dup val = @arel.dup
casted_binds = binds.map(&:value_for_database) val.grep(Arel::Nodes::BindParam) do |node|
@indexes.each { |i| val[i] = connection.quote(casted_binds.shift) } node.value = binds.shift
val.join if binds.empty?
break
end
end
sql, _ = connection.visitor.accept(val, connection.send(:collector)).value
sql
end end
end end
@ -90,9 +92,9 @@ module ActiveRecord
attr_reader :bind_map, :query_builder attr_reader :bind_map, :query_builder
def self.create(connection, block = Proc.new) def self.create(connection, block = Proc.new)
relation = block.call Params.new relation = block.call Params.new
bind_map = BindMap.new relation.bound_attributes query_builder, binds = connection.cacheable_query(self, relation.arel)
query_builder = connection.cacheable_query(self, relation.arel) bind_map = BindMap.new(binds)
new query_builder, bind_map new query_builder, bind_map
end end

View File

@ -52,7 +52,37 @@ module ActiveRecord
end end
def build_relation(klass, attribute, value) def build_relation(klass, attribute, value)
klass.unscoped.where!({ attribute => value }, options) if reflection = klass._reflect_on_association(attribute)
attribute = reflection.foreign_key
value = value.attributes[reflection.klass.primary_key] unless value.nil?
end
if value.nil?
return klass.unscoped.where!(attribute => value)
end
# the attribute may be an aliased attribute
if klass.attribute_alias?(attribute)
attribute = klass.attribute_alias(attribute)
end
attribute_name = attribute.to_s
table = klass.arel_table
column = klass.columns_hash[attribute_name]
cast_type = klass.type_for_attribute(attribute_name)
value = Relation::QueryAttribute.new(attribute_name, value, cast_type)
comparison = if !options[:case_sensitive]
# will use SQL LOWER function before comparison, unless it detects a case insensitive collation
klass.connection.case_insensitive_comparison(table, attribute, column, value)
else
klass.connection.case_sensitive_comparison(table, attribute, column, value)
end
klass.unscoped.tap do |scope|
parts = [comparison]
scope.where_clause += Relation::WhereClause.new(parts)
end
end end
def scope_relation(record, relation) def scope_relation(record, relation)

View File

@ -244,7 +244,7 @@ module ActiveRecord
def test_select_all_with_legacy_binds def test_select_all_with_legacy_binds
post = Post.create!(title: "foo", body: "bar") post = Post.create!(title: "foo", body: "bar")
expected = @connection.select_all("SELECT * FROM posts WHERE id = #{post.id}") expected = @connection.select_all("SELECT * FROM posts WHERE id = #{post.id}")
result = @connection.select_all("SELECT * FROM posts WHERE id = #{Arel::Nodes::BindParam.new.to_sql}", nil, [[nil, post.id]]) result = @connection.select_all("SELECT * FROM posts WHERE id = #{Arel::Nodes::BindParam.new(nil).to_sql}", nil, [[nil, post.id]])
assert_equal expected.to_hash, result.to_hash assert_equal expected.to_hash, result.to_hash
end end
end end
@ -253,7 +253,6 @@ module ActiveRecord
author = Author.create!(name: "john") author = Author.create!(name: "john")
Post.create!(author: author, title: "foo", body: "bar") Post.create!(author: author, title: "foo", body: "bar")
query = author.posts.where(title: "foo").select(:title) query = author.posts.where(title: "foo").select(:title)
assert_equal({ "title" => "foo" }, @connection.select_one(query.arel, nil, query.bound_attributes))
assert_equal({ "title" => "foo" }, @connection.select_one(query)) assert_equal({ "title" => "foo" }, @connection.select_one(query))
assert @connection.select_all(query).is_a?(ActiveRecord::Result) assert @connection.select_all(query).is_a?(ActiveRecord::Result)
assert_equal "foo", @connection.select_value(query) assert_equal "foo", @connection.select_value(query)
@ -263,7 +262,6 @@ module ActiveRecord
def test_select_methods_passing_a_relation def test_select_methods_passing_a_relation
Post.create!(title: "foo", body: "bar") Post.create!(title: "foo", body: "bar")
query = Post.where(title: "foo").select(:title) query = Post.where(title: "foo").select(:title)
assert_equal({ "title" => "foo" }, @connection.select_one(query.arel, nil, query.bound_attributes))
assert_equal({ "title" => "foo" }, @connection.select_one(query)) assert_equal({ "title" => "foo" }, @connection.select_one(query))
assert @connection.select_all(query).is_a?(ActiveRecord::Result) assert @connection.select_all(query).is_a?(ActiveRecord::Result)
assert_equal "foo", @connection.select_value(query) assert_equal "foo", @connection.select_value(query)

View File

@ -1,17 +0,0 @@
# frozen_string_literal: true
require "cases/helper"
require "models/post"
require "models/author"
module ActiveRecord
module Associations
class AssociationScopeTest < ActiveRecord::TestCase
test "does not duplicate conditions" do
scope = AssociationScope.scope(Author.new.association(:welcome_posts))
binds = scope.where_clause.binds.map(&:value)
assert_equal binds.uniq, binds
end
end
end
end

View File

@ -41,7 +41,7 @@ if ActiveRecord::Base.connection.prepared_statements
end end
def test_binds_are_logged def test_binds_are_logged
sub = Arel::Nodes::BindParam.new sub = Arel::Nodes::BindParam.new(1)
binds = [Relation::QueryAttribute.new("id", 1, Type::Value.new)] binds = [Relation::QueryAttribute.new("id", 1, Type::Value.new)]
sql = "select * from topics where id = #{sub.to_sql}" sql = "select * from topics where id = #{sub.to_sql}"

View File

@ -420,7 +420,7 @@ class InheritanceTest < ActiveRecord::TestCase
def test_eager_load_belongs_to_primary_key_quoting def test_eager_load_belongs_to_primary_key_quoting
con = Account.connection con = Account.connection
bind_param = Arel::Nodes::BindParam.new bind_param = Arel::Nodes::BindParam.new(nil)
assert_sql(/#{con.quote_table_name('companies')}\.#{con.quote_column_name('id')} = (?:#{Regexp.quote(bind_param.to_sql)}|1)/) do assert_sql(/#{con.quote_table_name('companies')}\.#{con.quote_column_name('id')} = (?:#{Regexp.quote(bind_param.to_sql)}|1)/) do
Account.all.merge!(includes: :firm).find(1) Account.all.merge!(includes: :firm).find(1)
end end

View File

@ -81,13 +81,11 @@ class RelationMergingTest < ActiveRecord::TestCase
end end
test "merge collapses wheres from the LHS only" do test "merge collapses wheres from the LHS only" do
left = Post.where(title: "omg").where(comments_count: 1) left = Post.where(title: "omg").where(comments_count: 1)
right = Post.where(title: "wtf").where(title: "bbq") right = Post.where(title: "wtf").where(title: "bbq")
expected = [left.bound_attributes[1]] + right.bound_attributes merged = left.merge(right)
merged = left.merge(right)
assert_equal expected, merged.bound_attributes
assert_not_includes merged.to_sql, "omg" assert_not_includes merged.to_sql, "omg"
assert_includes merged.to_sql, "wtf" assert_includes merged.to_sql, "wtf"
assert_includes merged.to_sql, "bbq" assert_includes merged.to_sql, "bbq"

View File

@ -27,7 +27,7 @@ module ActiveRecord
end end
def test_association_not_eq def test_association_not_eq
expected = Arel::Nodes::Grouping.new(Comment.arel_table[@name].not_eq(Arel::Nodes::BindParam.new)) expected = Comment.arel_table[@name].not_eq(Arel::Nodes::BindParam.new(1))
relation = Post.joins(:comments).where.not(comments: { title: "hello" }) relation = Post.joins(:comments).where.not(comments: { title: "hello" })
assert_equal(expected.to_sql, relation.where_clause.ast.to_sql) assert_equal(expected.to_sql, relation.where_clause.ast.to_sql)
end end

View File

@ -5,76 +5,82 @@ require "cases/helper"
class ActiveRecord::Relation class ActiveRecord::Relation
class WhereClauseTest < ActiveRecord::TestCase class WhereClauseTest < ActiveRecord::TestCase
test "+ combines two where clauses" do test "+ combines two where clauses" do
first_clause = WhereClause.new([table["id"].eq(bind_param)], [["id", 1]]) first_clause = WhereClause.new([table["id"].eq(bind_param(1))])
second_clause = WhereClause.new([table["name"].eq(bind_param)], [["name", "Sean"]]) second_clause = WhereClause.new([table["name"].eq(bind_param("Sean"))])
combined = WhereClause.new( combined = WhereClause.new(
[table["id"].eq(bind_param), table["name"].eq(bind_param)], [table["id"].eq(bind_param(1)), table["name"].eq(bind_param("Sean"))],
[["id", 1], ["name", "Sean"]],
) )
assert_equal combined, first_clause + second_clause assert_equal combined, first_clause + second_clause
end end
test "+ is associative, but not commutative" do test "+ is associative, but not commutative" do
a = WhereClause.new(["a"], ["bind a"]) a = WhereClause.new(["a"])
b = WhereClause.new(["b"], ["bind b"]) b = WhereClause.new(["b"])
c = WhereClause.new(["c"], ["bind c"]) c = WhereClause.new(["c"])
assert_equal a + (b + c), (a + b) + c assert_equal a + (b + c), (a + b) + c
assert_not_equal a + b, b + a assert_not_equal a + b, b + a
end end
test "an empty where clause is the identity value for +" do test "an empty where clause is the identity value for +" do
clause = WhereClause.new([table["id"].eq(bind_param)], [["id", 1]]) clause = WhereClause.new([table["id"].eq(bind_param(1))])
assert_equal clause, clause + WhereClause.empty assert_equal clause, clause + WhereClause.empty
end end
test "merge combines two where clauses" do test "merge combines two where clauses" do
a = WhereClause.new([table["id"].eq(1)], []) a = WhereClause.new([table["id"].eq(1)])
b = WhereClause.new([table["name"].eq("Sean")], []) b = WhereClause.new([table["name"].eq("Sean")])
expected = WhereClause.new([table["id"].eq(1), table["name"].eq("Sean")], []) expected = WhereClause.new([table["id"].eq(1), table["name"].eq("Sean")])
assert_equal expected, a.merge(b) assert_equal expected, a.merge(b)
end end
test "merge keeps the right side, when two equality clauses reference the same column" do test "merge keeps the right side, when two equality clauses reference the same column" do
a = WhereClause.new([table["id"].eq(1), table["name"].eq("Sean")], []) a = WhereClause.new([table["id"].eq(1), table["name"].eq("Sean")])
b = WhereClause.new([table["name"].eq("Jim")], []) b = WhereClause.new([table["name"].eq("Jim")])
expected = WhereClause.new([table["id"].eq(1), table["name"].eq("Jim")], []) expected = WhereClause.new([table["id"].eq(1), table["name"].eq("Jim")])
assert_equal expected, a.merge(b) assert_equal expected, a.merge(b)
end end
test "merge removes bind parameters matching overlapping equality clauses" do test "merge removes bind parameters matching overlapping equality clauses" do
a = WhereClause.new( a = WhereClause.new(
[table["id"].eq(bind_param), table["name"].eq(bind_param)], [table["id"].eq(bind_param(1)), table["name"].eq(bind_param("Sean"))],
[attribute("id", 1), attribute("name", "Sean")],
) )
b = WhereClause.new( b = WhereClause.new(
[table["name"].eq(bind_param)], [table["name"].eq(bind_param("Jim"))],
[attribute("name", "Jim")]
) )
expected = WhereClause.new( expected = WhereClause.new(
[table["id"].eq(bind_param), table["name"].eq(bind_param)], [table["id"].eq(bind_param(1)), table["name"].eq(bind_param("Jim"))],
[attribute("id", 1), attribute("name", "Jim")],
) )
assert_equal expected, a.merge(b) assert_equal expected, a.merge(b)
end end
test "merge allows for columns with the same name from different tables" do test "merge allows for columns with the same name from different tables" do
skip "This is not possible as of 4.2, and the binds do not yet contain sufficient information for this to happen" table2 = Arel::Table.new("table2")
# We might be able to change the implementation to remove conflicts by index, rather than column name a = WhereClause.new(
[table["id"].eq(bind_param(1)), table2["id"].eq(bind_param(2))],
)
b = WhereClause.new(
[table["id"].eq(bind_param(3))],
)
expected = WhereClause.new(
[table2["id"].eq(bind_param(2)), table["id"].eq(bind_param(3))],
)
assert_equal expected, a.merge(b)
end end
test "a clause knows if it is empty" do test "a clause knows if it is empty" do
assert WhereClause.empty.empty? assert WhereClause.empty.empty?
assert_not WhereClause.new(["anything"], []).empty? assert_not WhereClause.new(["anything"]).empty?
end end
test "invert cannot handle nil" do test "invert cannot handle nil" do
where_clause = WhereClause.new([nil], []) where_clause = WhereClause.new([nil])
assert_raises ArgumentError do assert_raises ArgumentError do
where_clause.invert where_clause.invert
@ -88,13 +94,13 @@ class ActiveRecord::Relation
table["id"].eq(1), table["id"].eq(1),
"sql literal", "sql literal",
random_object random_object
], []) ])
expected = WhereClause.new([ expected = WhereClause.new([
table["id"].not_in([1, 2, 3]), table["id"].not_in([1, 2, 3]),
table["id"].not_eq(1), table["id"].not_eq(1),
Arel::Nodes::Not.new(Arel::Nodes::SqlLiteral.new("sql literal")), Arel::Nodes::Not.new(Arel::Nodes::SqlLiteral.new("sql literal")),
Arel::Nodes::Not.new(random_object) Arel::Nodes::Not.new(random_object)
], []) ])
assert_equal expected, original.invert assert_equal expected, original.invert
end end
@ -102,20 +108,17 @@ class ActiveRecord::Relation
test "except removes binary predicates referencing a given column" do test "except removes binary predicates referencing a given column" do
where_clause = WhereClause.new([ where_clause = WhereClause.new([
table["id"].in([1, 2, 3]), table["id"].in([1, 2, 3]),
table["name"].eq(bind_param), table["name"].eq(bind_param("Sean")),
table["age"].gteq(bind_param), table["age"].gteq(bind_param(30)),
], [
attribute("name", "Sean"),
attribute("age", 30),
]) ])
expected = WhereClause.new([table["age"].gteq(bind_param)], [attribute("age", 30)]) expected = WhereClause.new([table["age"].gteq(bind_param(30))])
assert_equal expected, where_clause.except("id", "name") assert_equal expected, where_clause.except("id", "name")
end end
test "except jumps over unhandled binds (like with OR) correctly" do test "except jumps over unhandled binds (like with OR) correctly" do
wcs = (0..9).map do |i| wcs = (0..9).map do |i|
WhereClause.new([table["id#{i}"].eq(bind_param)], [attribute("id#{i}", i)]) WhereClause.new([table["id#{i}"].eq(bind_param(i))])
end end
wc = wcs[0] + wcs[1] + wcs[2].or(wcs[3]) + wcs[4] + wcs[5] + wcs[6].or(wcs[7]) + wcs[8] + wcs[9] wc = wcs[0] + wcs[1] + wcs[2].or(wcs[3]) + wcs[4] + wcs[5] + wcs[6].or(wcs[7]) + wcs[8] + wcs[9]
@ -123,18 +126,15 @@ class ActiveRecord::Relation
expected = wcs[0] + wcs[2].or(wcs[3]) + wcs[5] + wcs[6].or(wcs[7]) + wcs[9] expected = wcs[0] + wcs[2].or(wcs[3]) + wcs[5] + wcs[6].or(wcs[7]) + wcs[9]
actual = wc.except("id1", "id2", "id4", "id7", "id8") actual = wc.except("id1", "id2", "id4", "id7", "id8")
# Easier to read than the inspect of where_clause
assert_equal expected.ast.to_sql, actual.ast.to_sql
assert_equal expected.binds.map(&:value), actual.binds.map(&:value)
assert_equal expected, actual assert_equal expected, actual
end end
test "ast groups its predicates with AND" do test "ast groups its predicates with AND" do
predicates = [ predicates = [
table["id"].in([1, 2, 3]), table["id"].in([1, 2, 3]),
table["name"].eq(bind_param), table["name"].eq(bind_param(nil)),
] ]
where_clause = WhereClause.new(predicates, []) where_clause = WhereClause.new(predicates)
expected = Arel::Nodes::And.new(predicates) expected = Arel::Nodes::And.new(predicates)
assert_equal expected, where_clause.ast assert_equal expected, where_clause.ast
@ -146,38 +146,36 @@ class ActiveRecord::Relation
table["id"].in([1, 2, 3]), table["id"].in([1, 2, 3]),
"foo = bar", "foo = bar",
random_object, random_object,
], []) ])
expected = Arel::Nodes::And.new([ expected = Arel::Nodes::And.new([
table["id"].in([1, 2, 3]), table["id"].in([1, 2, 3]),
Arel::Nodes::Grouping.new(Arel.sql("foo = bar")), Arel::Nodes::Grouping.new(Arel.sql("foo = bar")),
Arel::Nodes::Grouping.new(random_object), random_object,
]) ])
assert_equal expected, where_clause.ast assert_equal expected, where_clause.ast
end end
test "ast removes any empty strings" do test "ast removes any empty strings" do
where_clause = WhereClause.new([table["id"].in([1, 2, 3])], []) where_clause = WhereClause.new([table["id"].in([1, 2, 3])])
where_clause_with_empty = WhereClause.new([table["id"].in([1, 2, 3]), ""], []) where_clause_with_empty = WhereClause.new([table["id"].in([1, 2, 3]), ""])
assert_equal where_clause.ast, where_clause_with_empty.ast assert_equal where_clause.ast, where_clause_with_empty.ast
end end
test "or joins the two clauses using OR" do test "or joins the two clauses using OR" do
where_clause = WhereClause.new([table["id"].eq(bind_param)], [attribute("id", 1)]) where_clause = WhereClause.new([table["id"].eq(bind_param(1))])
other_clause = WhereClause.new([table["name"].eq(bind_param)], [attribute("name", "Sean")]) other_clause = WhereClause.new([table["name"].eq(bind_param("Sean"))])
expected_ast = expected_ast =
Arel::Nodes::Grouping.new( Arel::Nodes::Grouping.new(
Arel::Nodes::Or.new(table["id"].eq(bind_param), table["name"].eq(bind_param)) Arel::Nodes::Or.new(table["id"].eq(bind_param(1)), table["name"].eq(bind_param("Sean")))
) )
expected_binds = where_clause.binds + other_clause.binds
assert_equal expected_ast.to_sql, where_clause.or(other_clause).ast.to_sql assert_equal expected_ast.to_sql, where_clause.or(other_clause).ast.to_sql
assert_equal expected_binds, where_clause.or(other_clause).binds
end end
test "or returns an empty where clause when either side is empty" do test "or returns an empty where clause when either side is empty" do
where_clause = WhereClause.new([table["id"].eq(bind_param)], [attribute("id", 1)]) where_clause = WhereClause.new([table["id"].eq(bind_param(1))])
assert_equal WhereClause.empty, where_clause.or(WhereClause.empty) assert_equal WhereClause.empty, where_clause.or(WhereClause.empty)
assert_equal WhereClause.empty, WhereClause.empty.or(where_clause) assert_equal WhereClause.empty, WhereClause.empty.or(where_clause)
@ -189,12 +187,8 @@ class ActiveRecord::Relation
Arel::Table.new("table") Arel::Table.new("table")
end end
def bind_param def bind_param(value)
Arel::Nodes::BindParam.new Arel::Nodes::BindParam.new(value)
end
def attribute(name, value)
ActiveRecord::Attribute.with_cast_value(name, value, ActiveRecord::Type::Value.new)
end end
end end
end end

View File

@ -127,7 +127,7 @@ module ActiveRecord
car = cars(:honda) car = cars(:honda)
expected = [price_estimates(:diamond), price_estimates(:sapphire_1), price_estimates(:sapphire_2), price_estimates(:honda)].sort expected = [price_estimates(:diamond), price_estimates(:sapphire_1), price_estimates(:sapphire_2), price_estimates(:honda)].sort
actual = PriceEstimate.where(estimate_of: [treasure_1, treasure_2, car]).to_a.sort actual = PriceEstimate.where(estimate_of: [treasure_1, treasure_2, car]).to_a.sort
assert_equal expected, actual assert_equal expected, actual
end end

View File

@ -185,7 +185,7 @@ module ActiveRecord
relation = Relation.new(klass, :b, nil) relation = Relation.new(klass, :b, nil)
relation.merge!(where: ["foo = ?", "bar"]) relation.merge!(where: ["foo = ?", "bar"])
assert_equal Relation::WhereClause.new(["foo = bar"], []), relation.where_clause assert_equal Relation::WhereClause.new(["foo = bar"]), relation.where_clause
end end
def test_merging_readonly_false def test_merging_readonly_false

View File

@ -1761,7 +1761,7 @@ class RelationTest < ActiveRecord::TestCase
test "relations with cached arel can't be mutated [internal API]" do test "relations with cached arel can't be mutated [internal API]" do
relation = Post.all relation = Post.all
relation.count relation.arel
assert_raises(ActiveRecord::ImmutableRelation) { relation.limit!(5) } assert_raises(ActiveRecord::ImmutableRelation) { relation.limit!(5) }
assert_raises(ActiveRecord::ImmutableRelation) { relation.where!("1 = 2") } assert_raises(ActiveRecord::ImmutableRelation) { relation.where!("1 = 2") }
@ -1860,33 +1860,6 @@ class RelationTest < ActiveRecord::TestCase
assert_equal 1, posts.unscope(where: :body).count assert_equal 1, posts.unscope(where: :body).count
end end
def test_unscope_removes_binds
left = Post.where(id: 20)
assert_equal 1, left.bound_attributes.length
relation = left.unscope(where: :id)
assert_equal [], relation.bound_attributes
end
def test_merging_removes_rhs_binds
left = Post.where(id: 20)
right = Post.where(id: [1, 2, 3, 4])
assert_equal 1, left.bound_attributes.length
merged = left.merge(right)
assert_equal [], merged.bound_attributes
end
def test_merging_keeps_lhs_binds
right = Post.where(id: 20)
left = Post.where(id: 10)
merged = left.merge(right)
assert_equal [20], merged.bound_attributes.map(&:value)
end
def test_locked_should_not_build_arel def test_locked_should_not_build_arel
posts = Post.locked posts = Post.locked
assert posts.locked? assert posts.locked?
@ -1897,24 +1870,6 @@ class RelationTest < ActiveRecord::TestCase
assert_equal "Thank you for the welcome,Thank you again for the welcome", Post.first.comments.join(",") assert_equal "Thank you for the welcome,Thank you again for the welcome", Post.first.comments.join(",")
end end
def test_connection_adapters_can_reorder_binds
posts = Post.limit(1).offset(2)
stubbed_connection = Post.connection.dup
def stubbed_connection.combine_bind_parameters(**kwargs)
offset = kwargs[:offset]
kwargs[:offset] = kwargs[:limit]
kwargs[:limit] = offset
super(**kwargs)
end
posts.define_singleton_method(:connection) do
stubbed_connection
end
assert_equal 2, posts.to_a.length
end
test "#skip_query_cache!" do test "#skip_query_cache!" do
Post.cache do Post.cache do
assert_queries(1) do assert_queries(1) do