Merge pull request #51174 from Shopify/connection-less-quoting

Don't require an active connection for table and column quoting
This commit is contained in:
Jean Boussier 2024-02-27 11:34:32 +01:00 committed by GitHub
commit af85f74418
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 307 additions and 225 deletions

View File

@ -92,7 +92,7 @@ module ActiveRecord
# Returns a quoted version of the primary key name, used to construct
# SQL statements.
def quoted_primary_key
@quoted_primary_key ||= connection.quote_column_name(primary_key)
@quoted_primary_key ||= adapter_class.quote_column_name(primary_key)
end
def reset_primary_key # :nodoc:

View File

@ -7,6 +7,67 @@ module ActiveRecord
module ConnectionAdapters # :nodoc:
# = Active Record Connection Adapters \Quoting
module Quoting
extend ActiveSupport::Concern
module ClassMethods # :nodoc:
# Regexp for column names (with or without a table name prefix).
# Matches the following:
#
# "#{table_name}.#{column_name}"
# "#{column_name}"
def column_name_matcher
/
\A
(
(?:
# table_name.column_name | function(one or no argument)
((?:\w+\.)?\w+ | \w+\((?:|\g<2>)\))
)
(?:(?:\s+AS)?\s+\w+)?
)
(?:\s*,\s*\g<1>)*
\z
/ix
end
# Regexp for column names with order (with or without a table name prefix,
# with or without various order modifiers). Matches the following:
#
# "#{table_name}.#{column_name}"
# "#{table_name}.#{column_name} #{direction}"
# "#{table_name}.#{column_name} #{direction} NULLS FIRST"
# "#{table_name}.#{column_name} NULLS LAST"
# "#{column_name}"
# "#{column_name} #{direction}"
# "#{column_name} #{direction} NULLS FIRST"
# "#{column_name} NULLS LAST"
def column_name_with_order_matcher
/
\A
(
(?:
# table_name.column_name | function(one or no argument)
((?:\w+\.)?\w+ | \w+\((?:|\g<2>)\))
)
(?:\s+ASC|\s+DESC)?
(?:\s+NULLS\s+(?:FIRST|LAST))?
)
(?:\s*,\s*\g<1>)*
\z
/ix
end
# Quotes the column name. Must be implemented by subclasses
def quote_column_name(column_name)
raise NotImplementedError
end
# Quotes the table name. Defaults to column name quoting.
def quote_table_name(table_name)
quote_column_name(table_name)
end
end
# Quotes the column value to help prevent
# {SQL injection attacks}[https://en.wikipedia.org/wiki/SQL_injection].
def quote(value)
@ -71,14 +132,14 @@ module ActiveRecord
s.gsub("\\", '\&\&').gsub("'", "''") # ' (for ruby-mode)
end
# Quotes the column name. Defaults to no quoting.
# Quotes the column name.
def quote_column_name(column_name)
column_name.to_s
self.class.quote_column_name(column_name)
end
# Quotes the table name. Defaults to column name quoting.
# Quotes the table name.
def quote_table_name(table_name)
quote_column_name(table_name)
self.class.quote_table_name(table_name)
end
# Override to return the quoted table name for assignment. Defaults to
@ -159,59 +220,6 @@ module ActiveRecord
comment
end
def column_name_matcher # :nodoc:
COLUMN_NAME
end
def column_name_with_order_matcher # :nodoc:
COLUMN_NAME_WITH_ORDER
end
# Regexp for column names (with or without a table name prefix).
# Matches the following:
#
# "#{table_name}.#{column_name}"
# "#{column_name}"
COLUMN_NAME = /
\A
(
(?:
# table_name.column_name | function(one or no argument)
((?:\w+\.)?\w+ | \w+\((?:|\g<2>)\))
)
(?:(?:\s+AS)?\s+\w+)?
)
(?:\s*,\s*\g<1>)*
\z
/ix
# Regexp for column names with order (with or without a table name prefix,
# with or without various order modifiers). Matches the following:
#
# "#{table_name}.#{column_name}"
# "#{table_name}.#{column_name} #{direction}"
# "#{table_name}.#{column_name} #{direction} NULLS FIRST"
# "#{table_name}.#{column_name} NULLS LAST"
# "#{column_name}"
# "#{column_name} #{direction}"
# "#{column_name} #{direction} NULLS FIRST"
# "#{column_name} NULLS LAST"
COLUMN_NAME_WITH_ORDER = /
\A
(
(?:
# table_name.column_name | function(one or no argument)
((?:\w+\.)?\w+ | \w+\((?:|\g<2>)\))
)
(?:\s+ASC|\s+DESC)?
(?:\s+NULLS\s+(?:FIRST|LAST))?
)
(?:\s*,\s*\g<1>)*
\z
/ix
private_constant :COLUMN_NAME, :COLUMN_NAME_WITH_ORDER
private
def type_casted_binds(binds)
binds.map do |value|

View File

@ -6,9 +6,52 @@ module ActiveRecord
module ConnectionAdapters
module MySQL
module Quoting # :nodoc:
extend ActiveSupport::Concern
QUOTED_COLUMN_NAMES = Concurrent::Map.new # :nodoc:
QUOTED_TABLE_NAMES = Concurrent::Map.new # :nodoc:
module ClassMethods # :nodoc:
def column_name_matcher
/
\A
(
(?:
# `table_name`.`column_name` | function(one or no argument)
((?:\w+\.|`\w+`\.)?(?:\w+|`\w+`) | \w+\((?:|\g<2>)\))
)
(?:(?:\s+AS)?\s+(?:\w+|`\w+`))?
)
(?:\s*,\s*\g<1>)*
\z
/ix
end
def column_name_with_order_matcher
/
\A
(
(?:
# `table_name`.`column_name` | function(one or no argument)
((?:\w+\.|`\w+`\.)?(?:\w+|`\w+`) | \w+\((?:|\g<2>)\))
)
(?:\s+COLLATE\s+(?:\w+|"\w+"))?
(?:\s+ASC|\s+DESC)?
)
(?:\s*,\s*\g<1>)*
\z
/ix
end
def quote_column_name(name)
QUOTED_COLUMN_NAMES[name] ||= "`#{name.to_s.gsub('`', '``')}`".freeze
end
def quote_table_name(name)
QUOTED_TABLE_NAMES[name] ||= "`#{name.to_s.gsub('`', '``').gsub(".", "`.`")}`".freeze
end
end
def cast_bound_value(value)
case value
when Rational
@ -26,14 +69,6 @@ module ActiveRecord
end
end
def quote_column_name(name)
QUOTED_COLUMN_NAMES[name] ||= "`#{super.gsub('`', '``')}`".freeze
end
def quote_table_name(name)
QUOTED_TABLE_NAMES[name] ||= -super.gsub(".", "`.`").freeze
end
def unquoted_true
1
end
@ -81,43 +116,6 @@ module ActiveRecord
super
end
end
def column_name_matcher
COLUMN_NAME
end
def column_name_with_order_matcher
COLUMN_NAME_WITH_ORDER
end
COLUMN_NAME = /
\A
(
(?:
# `table_name`.`column_name` | function(one or no argument)
((?:\w+\.|`\w+`\.)?(?:\w+|`\w+`) | \w+\((?:|\g<2>)\))
)
(?:(?:\s+AS)?\s+(?:\w+|`\w+`))?
)
(?:\s*,\s*\g<1>)*
\z
/ix
COLUMN_NAME_WITH_ORDER = /
\A
(
(?:
# `table_name`.`column_name` | function(one or no argument)
((?:\w+\.|`\w+`\.)?(?:\w+|`\w+`) | \w+\((?:|\g<2>)\))
)
(?:\s+COLLATE\s+(?:\w+|"\w+"))?
(?:\s+ASC|\s+DESC)?
)
(?:\s*,\s*\g<1>)*
\z
/ix
private_constant :COLUMN_NAME, :COLUMN_NAME_WITH_ORDER
end
end
end

View File

@ -4,9 +4,62 @@ module ActiveRecord
module ConnectionAdapters
module PostgreSQL
module Quoting
extend ActiveSupport::Concern
QUOTED_COLUMN_NAMES = Concurrent::Map.new # :nodoc:
QUOTED_TABLE_NAMES = Concurrent::Map.new # :nodoc:
module ClassMethods # :nodoc:
def column_name_matcher
/
\A
(
(?:
# "schema_name"."table_name"."column_name"::type_name | function(one or no argument)::type_name
((?:\w+\.|"\w+"\.){,2}(?:\w+|"\w+")(?:::\w+)? | \w+\((?:|\g<2>)\)(?:::\w+)?)
)
(?:(?:\s+AS)?\s+(?:\w+|"\w+"))?
)
(?:\s*,\s*\g<1>)*
\z
/ix
end
def column_name_with_order_matcher
/
\A
(
(?:
# "schema_name"."table_name"."column_name"::type_name | function(one or no argument)::type_name
((?:\w+\.|"\w+"\.){,2}(?:\w+|"\w+")(?:::\w+)? | \w+\((?:|\g<2>)\)(?:::\w+)?)
)
(?:\s+COLLATE\s+"\w+")?
(?:\s+ASC|\s+DESC)?
(?:\s+NULLS\s+(?:FIRST|LAST))?
)
(?:\s*,\s*\g<1>)*
\z
/ix
end
# Quotes column names for use in SQL queries.
def quote_column_name(name) # :nodoc:
QUOTED_COLUMN_NAMES[name] ||= PG::Connection.quote_ident(name.to_s).freeze
end
# Checks the following cases:
#
# - table_name
# - "table.name"
# - schema_name.table_name
# - schema_name."table.name"
# - "schema.name".table_name
# - "schema.name"."table.name"
def quote_table_name(name) # :nodoc:
QUOTED_TABLE_NAMES[name] ||= Utils.extract_schema_qualified_name(name.to_s).quoted.freeze
end
end
class IntegerOutOf64BitRange < StandardError
def initialize(msg)
super(msg)
@ -77,29 +130,14 @@ module ActiveRecord
end
end
# Checks the following cases:
#
# - table_name
# - "table.name"
# - schema_name.table_name
# - schema_name."table.name"
# - "schema.name".table_name
# - "schema.name"."table.name"
def quote_table_name(name) # :nodoc:
QUOTED_TABLE_NAMES[name] ||= -Utils.extract_schema_qualified_name(name.to_s).quoted.freeze
end
def quote_table_name_for_assignment(table, attr)
quote_column_name(attr)
end
# Quotes column names for use in SQL queries.
def quote_column_name(name) # :nodoc:
QUOTED_COLUMN_NAMES[name] ||= PG::Connection.quote_ident(super).freeze
end
# Quotes schema names for use in SQL queries.
alias_method :quote_schema_name, :quote_column_name
def quote_schema_name(schema_name)
quote_column_name(schema_name)
end
# Quote date/time values for use in SQL input.
def quoted_date(value) # :nodoc:
@ -153,44 +191,6 @@ module ActiveRecord
type_map.lookup(column.oid, column.fmod, column.sql_type)
end
def column_name_matcher
COLUMN_NAME
end
def column_name_with_order_matcher
COLUMN_NAME_WITH_ORDER
end
COLUMN_NAME = /
\A
(
(?:
# "schema_name"."table_name"."column_name"::type_name | function(one or no argument)::type_name
((?:\w+\.|"\w+"\.){,2}(?:\w+|"\w+")(?:::\w+)? | \w+\((?:|\g<2>)\)(?:::\w+)?)
)
(?:(?:\s+AS)?\s+(?:\w+|"\w+"))?
)
(?:\s*,\s*\g<1>)*
\z
/ix
COLUMN_NAME_WITH_ORDER = /
\A
(
(?:
# "schema_name"."table_name"."column_name"::type_name | function(one or no argument)::type_name
((?:\w+\.|"\w+"\.){,2}(?:\w+|"\w+")(?:::\w+)? | \w+\((?:|\g<2>)\)(?:::\w+)?)
)
(?:\s+COLLATE\s+"\w+")?
(?:\s+ASC|\s+DESC)?
(?:\s+NULLS\s+(?:FIRST|LAST))?
)
(?:\s*,\s*\g<1>)*
\z
/ix
private_constant :COLUMN_NAME, :COLUMN_NAME_WITH_ORDER
private
def lookup_cast_type(sql_type)
super(query_value("SELECT #{quote(sql_type)}::regtype::oid", "SCHEMA").to_i)

View File

@ -4,9 +4,52 @@ module ActiveRecord
module ConnectionAdapters
module SQLite3
module Quoting # :nodoc:
extend ActiveSupport::Concern
QUOTED_COLUMN_NAMES = Concurrent::Map.new # :nodoc:
QUOTED_TABLE_NAMES = Concurrent::Map.new # :nodoc:
module ClassMethods # :nodoc:
def column_name_matcher
/
\A
(
(?:
# "table_name"."column_name" | function(one or no argument)
((?:\w+\.|"\w+"\.)?(?:\w+|"\w+") | \w+\((?:|\g<2>)\))
)
(?:(?:\s+AS)?\s+(?:\w+|"\w+"))?
)
(?:\s*,\s*\g<1>)*
\z
/ix
end
def column_name_with_order_matcher
/
\A
(
(?:
# "table_name"."column_name" | function(one or no argument)
((?:\w+\.|"\w+"\.)?(?:\w+|"\w+") | \w+\((?:|\g<2>)\))
)
(?:\s+COLLATE\s+(?:\w+|"\w+"))?
(?:\s+ASC|\s+DESC)?
)
(?:\s*,\s*\g<1>)*
\z
/ix
end
def quote_column_name(name)
QUOTED_COLUMN_NAMES[name] ||= %Q("#{name.to_s.gsub('"', '""')}").freeze
end
def quote_table_name(name)
QUOTED_TABLE_NAMES[name] ||= %Q("#{name.to_s.gsub('"', '""').gsub(".", "\".\"")}").freeze
end
end
def quote_string(s)
::SQLite3::Database.quote(s)
end
@ -15,14 +58,6 @@ module ActiveRecord
quote_column_name(attr)
end
def quote_table_name(name)
QUOTED_TABLE_NAMES[name] ||= -super.gsub(".", "\".\"").freeze
end
def quote_column_name(name)
QUOTED_COLUMN_NAMES[name] ||= %Q("#{super.gsub('"', '""')}").freeze
end
def quoted_time(value)
value = value.change(year: 2000, month: 1, day: 1)
quoted_date(value).sub(/\A\d\d\d\d-\d\d-\d\d /, "2000-01-01 ")
@ -75,43 +110,6 @@ module ActiveRecord
super
end
end
def column_name_matcher
COLUMN_NAME
end
def column_name_with_order_matcher
COLUMN_NAME_WITH_ORDER
end
COLUMN_NAME = /
\A
(
(?:
# "table_name"."column_name" | function(one or no argument)
((?:\w+\.|"\w+"\.)?(?:\w+|"\w+") | \w+\((?:|\g<2>)\))
)
(?:(?:\s+AS)?\s+(?:\w+|"\w+"))?
)
(?:\s*,\s*\g<1>)*
\z
/ix
COLUMN_NAME_WITH_ORDER = /
\A
(
(?:
# "table_name"."column_name" | function(one or no argument)
((?:\w+\.|"\w+"\.)?(?:\w+|"\w+") | \w+\((?:|\g<2>)\))
)
(?:\s+COLLATE\s+(?:\w+|"\w+"))?
(?:\s+ASC|\s+DESC)?
)
(?:\s*,\s*\g<1>)*
\z
/ix
private_constant :COLUMN_NAME, :COLUMN_NAME_WITH_ORDER
end
end
end

View File

@ -286,6 +286,10 @@ module ActiveRecord
connection_pool.db_config
end
def adapter_class # :nodoc:
connection_pool.db_config.adapter_class
end
def connection_pool
connection_handler.retrieve_connection_pool(connection_specification_name, role: current_role, shard: current_shard, strict: true)
end

View File

@ -279,7 +279,7 @@ module ActiveRecord
# Returns a quoted version of the table name, used to construct SQL statements.
def quoted_table_name
@quoted_table_name ||= connection.quote_table_name(table_name)
@quoted_table_name ||= adapter_class.quote_table_name(table_name)
end
# Computes the table name, (re)sets it internally, and returns it.

View File

@ -469,7 +469,7 @@ module ActiveRecord
collection = eager_loading? ? apply_join_dependency : self
column = connection.visitor.compile(table[timestamp_column])
select_values = "COUNT(*) AS #{connection.quote_column_name("size")}, MAX(%s) AS timestamp"
select_values = "COUNT(*) AS #{adapter_class.quote_column_name("size")}, MAX(%s) AS timestamp"
if collection.has_limit_or_offset?
query = collection.select("#{column} AS collection_cache_key_timestamp")

View File

@ -511,13 +511,13 @@ module ActiveRecord
column = aggregate_column(column_name)
column_alias = column_alias_tracker.alias_for("#{operation} #{column_name.to_s.downcase}")
select_value = operation_over_aggregate_column(column, operation, distinct)
select_value.as(connection.quote_column_name(column_alias))
select_value.as(adapter_class.quote_column_name(column_alias))
select_values = [select_value]
select_values += self.select_values unless having_clause.empty?
select_values.concat group_columns.map { |aliaz, field|
aliaz = connection.quote_column_name(aliaz)
aliaz = adapter_class.quote_column_name(aliaz)
if field.respond_to?(:as)
field.as(aliaz)
else

View File

@ -634,7 +634,7 @@ module ActiveRecord
# # END ASC
#
def in_order_of(column, values)
klass.disallow_raw_sql!([column], permit: connection.column_name_with_order_matcher)
klass.disallow_raw_sql!([column], permit: model.adapter_class.column_name_with_order_matcher)
return spawn.none! if values.empty?
references = column_references([column])
@ -1848,7 +1848,7 @@ module ActiveRecord
case field
when Symbol
arel_column(field.to_s) do |attr_name|
connection.quote_table_name(attr_name)
adapter_class.quote_table_name(attr_name)
end
when String
arel_column(field, &:itself)
@ -1878,7 +1878,7 @@ module ActiveRecord
def table_name_matches?(from)
table_name = Regexp.escape(table.name)
quoted_table_name = Regexp.escape(connection.quote_table_name(table.name))
quoted_table_name = Regexp.escape(adapter_class.quote_table_name(table.name))
/(?:\A|(?<!FROM)\s)(?:\b#{table_name}\b|#{quoted_table_name})(?!\.)/i.match?(from.to_s)
end
@ -1951,7 +1951,7 @@ module ActiveRecord
def preprocess_order_args(order_args)
@klass.disallow_raw_sql!(
flattened_args(order_args),
permit: connection.column_name_with_order_matcher
permit: model.adapter_class.column_name_with_order_matcher
)
validate_order_args(order_args)
@ -2013,7 +2013,7 @@ module ActiveRecord
if attr_name == "count" && !group_values.empty?
table[attr_name]
else
Arel.sql(connection.quote_table_name(attr_name))
Arel.sql(adapter_class.quote_table_name(attr_name))
end
end
end

View File

@ -85,7 +85,7 @@ module ActiveRecord
if condition.is_a?(Array) && condition.first.to_s.include?("?")
disallow_raw_sql!(
[condition.first],
permit: connection.column_name_with_order_matcher
permit: adapter_class.column_name_with_order_matcher
)
# Ensure we aren't dealing with a subclass of String that might
@ -173,7 +173,7 @@ module ActiveRecord
end
end
def disallow_raw_sql!(args, permit: connection.column_name_matcher) # :nodoc:
def disallow_raw_sql!(args, permit: adapter_class.column_name_matcher) # :nodoc:
unexpected = nil
args.each do |arg|
next if arg.is_a?(Symbol) || Arel.arel_node?(arg) || permit.match?(arg.to_s.strip)

View File

@ -27,4 +27,23 @@ class QuotingTest < ActiveRecord::AbstractMysqlTestCase
def test_cast_bound_false
assert_equal "0", @conn.cast_bound_value(false)
end
def test_quote_string
assert_equal "\\'", @conn.quote_string("'")
end
def test_quote_column_name
[@conn, @conn.class].each do |adapter|
assert_equal "`foo`", adapter.quote_column_name("foo")
assert_equal '`hel"lo`', adapter.quote_column_name(%{hel"lo})
end
end
def test_quote_table_name
[@conn, @conn.class].each do |adapter|
assert_equal "`foo`", adapter.quote_table_name("foo")
assert_equal "`foo`.`bar`", adapter.quote_table_name("foo.bar")
assert_equal '`hel"lo.wol\\d`', adapter.quote_column_name('hel"lo.wol\\d')
end
end
end

View File

@ -46,6 +46,25 @@ module ActiveRecord
assert_equal "\"user posts\"", @conn.quote_table_name(value)
end
def test_quote_string
assert_equal "''", @conn.quote_string("'")
end
def test_quote_column_name
[@conn, @conn.class].each do |adapter|
assert_equal '"foo"', adapter.quote_column_name("foo")
assert_equal '"hel""lo"', adapter.quote_column_name(%{hel"lo})
end
end
def test_quote_table_name
[@conn, @conn.class].each do |adapter|
assert_equal '"foo"', adapter.quote_table_name("foo")
assert_equal '"foo"."bar"', adapter.quote_table_name("foo.bar")
assert_equal '"hel""lo.wol\\d"', adapter.quote_column_name('hel"lo.wol\\d')
end
end
def test_raise_when_int_is_wider_than_64bit
value = 9223372036854775807 + 1
assert_raise ActiveRecord::ConnectionAdapters::PostgreSQL::Quoting::IntegerOutOf64BitRange do

View File

@ -10,6 +10,25 @@ class SQLite3QuotingTest < ActiveRecord::SQLite3TestCase
@conn = ActiveRecord::Base.connection
end
def test_quote_string
assert_equal "''", @conn.quote_string("'")
end
def test_quote_column_name
[@conn, @conn.class].each do |adapter|
assert_equal '"foo"', adapter.quote_column_name("foo")
assert_equal '"hel""lo"', adapter.quote_column_name(%{hel"lo})
end
end
def test_quote_table_name
[@conn, @conn.class].each do |adapter|
assert_equal '"foo"', adapter.quote_table_name("foo")
assert_equal '"foo"."bar"', adapter.quote_table_name("foo.bar")
assert_equal '"hel""lo.wol\\d"', adapter.quote_column_name('hel"lo.wol\\d')
end
end
def test_type_cast_binary_encoding_without_logger
@conn.extend(Module.new { def logger; end })
binary = SecureRandom.hex

View File

@ -497,10 +497,6 @@ module ActiveRecord
end
end
def test_quote_string
assert_equal "''", @conn.quote_string("'")
end
def test_insert_logged
with_example_table do
sql = "INSERT INTO ex (number) VALUES (10)"

View File

@ -5,11 +5,24 @@ require "cases/helper"
module ActiveRecord
module ConnectionAdapters
class ColumnDefinitionTest < ActiveRecord::TestCase
def setup
@adapter = AbstractAdapter.new(nil)
def @adapter.native_database_types
class DummyAdapter < AbstractAdapter
class << self
def quote_table_name(table_name)
table_name.to_s
end
def quote_column_name(column_name)
column_name.to_s
end
end
def native_database_types
{ string: "varchar" }
end
end
def setup
@adapter = DummyAdapter.new(nil)
@viz = @adapter.send(:schema_creation)
end

View File

@ -24,15 +24,19 @@ module ActiveRecord
end
def test_quote_column_name
assert_equal "foo", @quoter.quote_column_name("foo")
assert_raises NotImplementedError do
@quoter.quote_column_name("foo")
end
end
def test_quote_table_name
assert_equal "foo", @quoter.quote_table_name("foo")
assert_raises NotImplementedError do
@quoter.quote_table_name("foo")
end
end
def test_quote_table_name_calls_quote_column_name
@quoter.extend(Module.new {
@quoter.class.extend(Module.new {
def quote_column_name(string)
"lol"
end

View File

@ -337,6 +337,10 @@ class FakeKlass
ActiveRecord::Scoping::ScopeRegistry.instance
end
def adapter_class
Post.adapter_class
end
def connection
Post.connection
end