diff --git a/activerecord/CHANGELOG.md b/activerecord/CHANGELOG.md
index d6f081ff186..0eb071eb6f2 100644
--- a/activerecord/CHANGELOG.md
+++ b/activerecord/CHANGELOG.md
@@ -1,3 +1,28 @@
+* Allow passing SQL as `on_duplicate` value to `#upsert_all` to make it possible to use raw SQL to update columns on conflict:
+
+ ```ruby
+ Book.upsert_all(
+ [{ id: 1, status: 1 }, { id: 2, status: 1 }],
+ on_duplicate: Arel.sql("status = GREATEST(books.status, EXCLUDED.status)")
+ )
+ ```
+
+ *Vladimir Dementyev*
+
+* Allow passing SQL as `returning` statement to `#upsert_all`:
+
+ ```ruby
+ Article.insert_all(
+ [
+ { title: "Article 1", slug: "article-1", published: false },
+ { title: "Article 2", slug: "article-2", published: false }
+ ],
+ returning: Arel.sql("id, (xmax = '0') as inserted, name as new_name")
+ )
+ ```
+
+ *Vladimir Dementyev*
+
* Deprecate `legacy_connection_handling`.
*Eileen M. Uchitelle*
diff --git a/activerecord/lib/active_record/connection_adapters/abstract_mysql_adapter.rb b/activerecord/lib/active_record/connection_adapters/abstract_mysql_adapter.rb
index 637358807c5..1f9d796fc2f 100644
--- a/activerecord/lib/active_record/connection_adapters/abstract_mysql_adapter.rb
+++ b/activerecord/lib/active_record/connection_adapters/abstract_mysql_adapter.rb
@@ -551,8 +551,12 @@ module ActiveRecord
sql << " ON DUPLICATE KEY UPDATE #{no_op_column}=#{no_op_column}"
elsif insert.update_duplicates?
sql << " ON DUPLICATE KEY UPDATE "
- sql << insert.touch_model_timestamps_unless { |column| "#{column}<=>VALUES(#{column})" }
- sql << insert.updatable_columns.map { |column| "#{column}=VALUES(#{column})" }.join(",")
+ if insert.raw_update_sql?
+ sql << insert.raw_update_sql
+ else
+ sql << insert.touch_model_timestamps_unless { |column| "#{column}<=>VALUES(#{column})" }
+ sql << insert.updatable_columns.map { |column| "#{column}=VALUES(#{column})" }.join(",")
+ end
end
sql
diff --git a/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb b/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb
index 4bdb230a28e..b19bf9baf1e 100644
--- a/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb
+++ b/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb
@@ -439,8 +439,12 @@ module ActiveRecord
sql << " ON CONFLICT #{insert.conflict_target} DO NOTHING"
elsif insert.update_duplicates?
sql << " ON CONFLICT #{insert.conflict_target} DO UPDATE SET "
- sql << insert.touch_model_timestamps_unless { |column| "#{insert.model.quoted_table_name}.#{column} IS NOT DISTINCT FROM excluded.#{column}" }
- sql << insert.updatable_columns.map { |column| "#{column}=excluded.#{column}" }.join(",")
+ if insert.raw_update_sql?
+ sql << insert.raw_update_sql
+ else
+ sql << insert.touch_model_timestamps_unless { |column| "#{insert.model.quoted_table_name}.#{column} IS NOT DISTINCT FROM excluded.#{column}" }
+ sql << insert.updatable_columns.map { |column| "#{column}=excluded.#{column}" }.join(",")
+ end
end
sql << " RETURNING #{insert.returning}" if insert.returning
diff --git a/activerecord/lib/active_record/connection_adapters/sqlite3_adapter.rb b/activerecord/lib/active_record/connection_adapters/sqlite3_adapter.rb
index 965839d752d..a51bb5d1cfb 100644
--- a/activerecord/lib/active_record/connection_adapters/sqlite3_adapter.rb
+++ b/activerecord/lib/active_record/connection_adapters/sqlite3_adapter.rb
@@ -313,8 +313,12 @@ module ActiveRecord
sql << " ON CONFLICT #{insert.conflict_target} DO NOTHING"
elsif insert.update_duplicates?
sql << " ON CONFLICT #{insert.conflict_target} DO UPDATE SET "
- sql << insert.touch_model_timestamps_unless { |column| "#{column} IS excluded.#{column}" }
- sql << insert.updatable_columns.map { |column| "#{column}=excluded.#{column}" }.join(",")
+ if insert.raw_update_sql?
+ sql << insert.raw_update_sql
+ else
+ sql << insert.touch_model_timestamps_unless { |column| "#{column} IS excluded.#{column}" }
+ sql << insert.updatable_columns.map { |column| "#{column}=excluded.#{column}" }.join(",")
+ end
end
sql
diff --git a/activerecord/lib/active_record/insert_all.rb b/activerecord/lib/active_record/insert_all.rb
index 97c440cd9c1..78c81f52cc4 100644
--- a/activerecord/lib/active_record/insert_all.rb
+++ b/activerecord/lib/active_record/insert_all.rb
@@ -5,7 +5,7 @@ require "active_support/core_ext/enumerable"
module ActiveRecord
class InsertAll # :nodoc:
attr_reader :model, :connection, :inserts, :keys
- attr_reader :on_duplicate, :returning, :unique_by
+ attr_reader :on_duplicate, :returning, :unique_by, :update_sql
def initialize(model, inserts, on_duplicate:, returning: nil, unique_by: nil)
raise ArgumentError, "Empty list of attributes passed" if inserts.blank?
@@ -13,6 +13,14 @@ module ActiveRecord
@model, @connection, @inserts, @keys = model, model.connection, inserts, inserts.first.keys.map(&:to_s)
@on_duplicate, @returning, @unique_by = on_duplicate, returning, unique_by
+ disallow_raw_sql!(returning)
+ disallow_raw_sql!(on_duplicate)
+
+ if Arel.arel_node?(on_duplicate)
+ @update_sql = on_duplicate
+ @on_duplicate = :update
+ end
+
if model.scope_attributes?
@scope_attributes = model.scope_attributes
@keys |= @scope_attributes.keys
@@ -127,6 +135,15 @@ module ActiveRecord
end
end
+ def disallow_raw_sql!(value)
+ return if !value.is_a?(String) || Arel.arel_node?(value)
+
+ raise ArgumentError, "Dangerous query method (method whose arguments are used as raw " \
+ "SQL) called: #{value}. " \
+ "Known-safe values can be passed " \
+ "by wrapping them in Arel.sql()."
+ end
+
class Builder # :nodoc:
attr_reader :model
@@ -151,7 +168,13 @@ module ActiveRecord
end
def returning
- format_columns(insert_all.returning) if insert_all.returning
+ return unless insert_all.returning
+
+ if insert_all.returning.is_a?(String)
+ insert_all.returning
+ else
+ format_columns(insert_all.returning)
+ end
end
def conflict_target
@@ -176,6 +199,12 @@ module ActiveRecord
end.compact.join
end
+ def raw_update_sql
+ insert_all.update_sql
+ end
+
+ alias raw_update_sql? raw_update_sql
+
private
attr_reader :connection, :insert_all
diff --git a/activerecord/lib/active_record/persistence.rb b/activerecord/lib/active_record/persistence.rb
index 9406b842742..4a0da374fa9 100644
--- a/activerecord/lib/active_record/persistence.rb
+++ b/activerecord/lib/active_record/persistence.rb
@@ -91,6 +91,9 @@ module ActiveRecord
# or returning: false to omit the underlying RETURNING SQL
# clause entirely.
#
+ # You can also pass an SQL string if you need more control on the return values
+ # (for example, returning: "id, name as new_name").
+ #
# [:unique_by]
# (PostgreSQL and SQLite only) By default rows are considered to be unique
# by every unique index on the table. Any duplicate rows are skipped.
@@ -168,6 +171,9 @@ module ActiveRecord
# or returning: false to omit the underlying RETURNING SQL
# clause entirely.
#
+ # You can also pass an SQL string if you need more control on the return values
+ # (for example, returning: "id, name as new_name").
+ #
# ==== Examples
#
# # Insert multiple records
@@ -192,8 +198,8 @@ module ActiveRecord
# go through Active Record's type casting and serialization.
#
# See ActiveRecord::Persistence#upsert_all for documentation.
- def upsert(attributes, returning: nil, unique_by: nil)
- upsert_all([ attributes ], returning: returning, unique_by: unique_by)
+ def upsert(attributes, on_duplicate: :update, returning: nil, unique_by: nil)
+ upsert_all([ attributes ], on_duplicate: on_duplicate, returning: returning, unique_by: unique_by)
end
# Updates or inserts (upserts) multiple records into the database in a
@@ -216,6 +222,9 @@ module ActiveRecord
# or returning: false to omit the underlying RETURNING SQL
# clause entirely.
#
+ # You can also pass an SQL string if you need more control on the return values
+ # (for example, returning: "id, name as new_name").
+ #
# [:unique_by]
# (PostgreSQL and SQLite only) By default rows are considered to be unique
# by every unique index on the table. Any duplicate rows are skipped.
@@ -236,6 +245,11 @@ module ActiveRecord
# :unique_by is recommended to be paired with
# Active Record's schema_cache.
#
+ # [:on_duplicate]
+ # Specify a custom SQL for updating rows on conflict.
+ #
+ # NOTE: in this case you must provide all the columns you want to update by yourself.
+ #
# ==== Examples
#
# # Inserts multiple records, performing an upsert when records have duplicate ISBNs.
@@ -247,8 +261,8 @@ module ActiveRecord
# ], unique_by: :isbn)
#
# Book.find_by(isbn: "1").title # => "Eloquent Ruby"
- def upsert_all(attributes, returning: nil, unique_by: nil)
- InsertAll.new(self, attributes, on_duplicate: :update, returning: returning, unique_by: unique_by).execute
+ def upsert_all(attributes, on_duplicate: :update, returning: nil, unique_by: nil, update_sql: nil)
+ InsertAll.new(self, attributes, on_duplicate: on_duplicate, returning: returning, unique_by: unique_by).execute
end
# Given an attributes hash, +instantiate+ returns a new instance of
diff --git a/activerecord/test/cases/insert_all_test.rb b/activerecord/test/cases/insert_all_test.rb
index b35c4724fee..902f0157e3d 100644
--- a/activerecord/test/cases/insert_all_test.rb
+++ b/activerecord/test/cases/insert_all_test.rb
@@ -109,6 +109,13 @@ class InsertAllTest < ActiveRecord::TestCase
assert_equal %w[ Rework ], result.pluck("name")
end
+ def test_insert_all_returns_requested_sql_fields
+ skip unless supports_insert_returning?
+
+ result = Book.insert_all! [{ name: "Rework", author_id: 1 }], returning: Arel.sql("UPPER(name) as name")
+ assert_equal %w[ REWORK ], result.pluck("name")
+ end
+
def test_insert_all_can_skip_duplicate_records
skip unless supports_insert_on_duplicate_skip?
@@ -466,6 +473,19 @@ class InsertAllTest < ActiveRecord::TestCase
assert_raise(ArgumentError) { book.subscribers.upsert_all([ { nick: "Jimmy" } ]) }
end
+ def test_upsert_all_updates_using_provided_sql
+ skip unless supports_insert_on_duplicate_update?
+
+ operator = sqlite? ? "MAX" : "GREATEST"
+
+ Book.upsert_all(
+ [{ id: 1, status: 1 }, { id: 2, status: 1 }],
+ on_duplicate: Arel.sql("status = #{operator}(books.status, 1)")
+ )
+ assert_equal "published", Book.find(1).status
+ assert_equal "written", Book.find(2).status
+ end
+
private
def capture_log_output
output = StringIO.new
@@ -477,4 +497,8 @@ class InsertAllTest < ActiveRecord::TestCase
ActiveRecord::Base.logger = old_logger
end
end
+
+ def sqlite?
+ ActiveRecord::Base.connection.adapter_name.match?(/sqlite/i)
+ end
end