# canvas-lms/lib/smart_search.rb
#
# 270 lines
# 9.9 KiB
# Ruby
#
# frozen_string_literal: true
# Copyright (C) 2023 - present Instructure, Inc.
#
# This file is part of Canvas.
#
# Canvas is free software: you can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License as published by the Free
# Software Foundation, version 3 of the License.
#
# Canvas is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
# details.
#
# You should have received a copy of the GNU Affero General Public License along
# with this program. If not, see <http://www.gnu.org/licenses/>.
require "aws-sdk-bedrockruntime"
module SmartSearch
  # Embedding version that newly generated embeddings use. A course whose
  # search_embedding_version differs is (re)indexed by index_course.
  EMBEDDING_VERSION = 2
  CHUNK_MAX_LENGTH = 1500

  class << self
    # OpenAI API token used by version 1 embeddings, read from Rails
    # credentials at smart_search.openai_api_token.
    def api_key
      Rails.application.credentials.dig(:smart_search, :openai_api_token)
    end

    # Memoized Bedrock runtime client used by version 2 embeddings.
    #
    # Region and vault credential path come from the private "bedrock.yml"
    # dynamic setting (region defaults to us-west-2). When the credential
    # provider reports no credentials set, nil is memoized instead, which makes
    # smart_search_available? return false.
    def bedrock_client
      return @bedrock_client if instance_variable_defined?(:@bedrock_client)

      # for local dev, assume that we are using creds from inseng (us-west-2)
      settings = YAML.safe_load(DynamicSettings.find(tree: :private)["bedrock.yml"] || "{}")
      config = {
        region: settings["bedrock_region"] || "us-west-2"
      }
      # Will load creds from vault (prod) or rails credential store (local / oss).
      # Credentials stored in rails credential store in the `bedrock_creds` key
      # with `aws_access_key_id` and `aws_secret_access_key` keys
      config[:credentials] = Canvas::AwsCredentialProvider.new("bedrock_creds", settings["vault_credential_path"])
      @bedrock_client = if config[:credentials].set?
                          Aws::BedrockRuntime::Client.new(config)
                        end
    end

    # Truthy when the :smart_search feature is enabled for +context+ and a
    # Bedrock client is available; falsy otherwise (including a nil context).
    def smart_search_available?(context)
      context&.feature_enabled?(:smart_search) && bedrock_client.present?
    end

    # Registers a searchable model.
    #
    # @param klass [Class] the model class
    # @param index_scope_proc [Proc] called with (course); returns the records to index
    # @param search_scope_proc [Proc] called with (course, user); returns the
    #   records that user may search
    def register_class(klass, index_scope_proc, search_scope_proc)
      @search_info ||= []
      @search_info << [klass, index_scope_proc, search_scope_proc]
    end

    # Index scopes of every registered class for +course+, in registration
    # order. Returns [] when nothing has been registered.
    def index_scopes(course)
      Array(@search_info).map do |_klass, index_proc, _search_proc|
        index_proc.call(course)
      end
    end

    # [klass, scope] pairs of searchable records for +user+ in +course+, one
    # per registered class. Returns [] when nothing has been registered.
    def search_scopes(course, user)
      Array(@search_info).map do |klass, _index_proc, search_proc|
        [klass, search_proc.call(course, user)]
      end
    end

    # Returns the embedding vector for +input+ as parsed from the provider's
    # JSON response.
    #
    # @param query [Boolean] true when embedding a search query rather than a
    #   document (only version 2 distinguishes the two)
    # @param version [Integer] 1 = OpenAI ada-002, 2 = Bedrock cohere.embed-multilingual-v3
    # @raise [ArgumentError] for any other version
    def generate_embedding(input, query: false, version: EMBEDDING_VERSION)
      case version
      when 1
        generate_embedding_v1(input)
      when 2
        generate_embedding_v2(input, query)
      else
        raise ArgumentError, "Unsupported embedding version #{version}"
      end
    end

    # Version 1 embedding via the OpenAI embeddings endpoint.
    # Raises a RuntimeError carrying the API's message when the response
    # contains an "error" entry.
    def generate_embedding_v1(input)
      # NOTE: openai does not differentiate between query and document embeddings
      url = "https://api.openai.com/v1/embeddings"
      headers = {
        "Authorization" => "Bearer #{api_key}",
        "Content-Type" => "application/json"
      }
      data = {
        input:,
        model: "text-embedding-ada-002"
      }
      response = JSON.parse(Net::HTTP.post(URI(url), data.to_json, headers).body)
      raise response["error"]["message"] if response["error"]

      # first (and only) input's embedding
      response.dig("data", 0, "embedding")
    end

    # Version 2 embedding via Bedrock's cohere.embed-multilingual-v3 model.
    # +query+ selects the "search_query" vs "search_document" input type.
    def generate_embedding_v2(input, query)
      resp = bedrock_client.invoke_model({
                                           content_type: "application/json",
                                           accept: "application/json",
                                           model_id: "cohere.embed-multilingual-v3",
                                           body: {
                                             texts: [input],
                                             input_type: query ? "search_query" : "search_document"
                                           }.to_json,
                                         })
      json = JSON.parse(resp.body.string)
      json["embeddings"][0]
    end

    # Searches every registered class for records close to +query+.
    #
    # The query is embedded with the context's search_embedding_version
    # (falling back to EMBEDDING_VERSION) and compared with stored embeddings of
    # that same version using pgvector's <=> operator; a record's distance is
    # the minimum over its embeddings. Per-class results are combined with
    # BookmarkedCollection.merge, bookmarked on distance then id.
    #
    # @param type_filter [Array<String>] API type names to restrict to; empty means all
    def perform_search(context, user, query, type_filter = [])
      version = context.search_embedding_version || EMBEDDING_VERSION
      embedding = SmartSearch.generate_embedding(query, version:, query: true)
      collections = []
      ActiveRecord::Base.with_pgvector do
        SmartSearch.search_scopes(context, user).each do |klass, item_scope|
          item_scope = apply_filter(klass, item_scope, type_filter)
          next unless item_scope

          # the <=> operator is schema-qualified with wherever the vector extension lives
          vector_schema = PG::Connection.quote_ident(ActiveRecord::Base.connection.extension("vector").schema)
          distance_select = ActiveRecord::Base.send(
            :sanitize_sql,
            ["#{klass.table_name}.*, MIN(embedding OPERATOR(#{vector_schema}.<=>) ?) AS distance", embedding.to_s]
          )
          item_scope = item_scope.select(distance_select)
                                 .joins(:embeddings)
                                 .where(klass.embedding_class.table_name => { version: })
                                 .group("#{klass.table_name}.id")
                                 .reorder("distance ASC")
          bookmarker = BookmarkedCollection::SimpleBookmarker.new(klass, { distance: { type: :float, null: false } }, :id)
          collections << [klass.name, BookmarkedCollection.wrap(bookmarker, item_scope)]
        end
      end
      BookmarkedCollection.merge(*collections)
    end

    # Narrows +scope+ according to +filter+ (API type names).
    #
    # Returns +scope+ unchanged when the filter is empty or includes klass's
    # type, a narrowed scope for DiscussionTopic (plain topics have type nil,
    # announcements have type "Announcement"), or nil when klass is excluded.
    def apply_filter(klass, scope, filter)
      return scope if filter.empty?

      if klass == DiscussionTopic
        if filter.include?("discussion_topics") && filter.include?("announcements")
          scope
        elsif filter.include?("discussion_topics")
          scope.where(type: nil)
        elsif filter.include?("announcements")
          scope.where(type: "Announcement")
        end
      elsif filter.include?(Context.api_type_name(klass))
        scope
      end
    end

    # Integer relevance percentage for a search result carrying a +distance+.
    #
    # Uses the formula matching the embedding version of the result's context
    # (the same attribute perform_search reads). Returns nil for an unknown version.
    def result_relevance(object)
      version = object.context.try(:search_embedding_version) || EMBEDDING_VERSION
      case version
      when 1
        (100.0 * (1.0 - object.distance)).round
      when 2
        # this function stretches out the useful range of distances;
        # otherwise everything would be 40-60% relevant using the old formula
        (100.0 * ((2.0 / (1.0 + Math.exp(-18.0 * ((1.0 - object.distance)**3)))) - 1.0)).round
      end
    end

    # True when smart search is available for +course+ and its embeddings are
    # at the current EMBEDDING_VERSION.
    def up_to_date?(course)
      smart_search_available?(course) && course.search_embedding_version == SmartSearch::EMBEDDING_VERSION
    end

    # Reports indexing status for +course+, queueing index_course when stale.
    #
    # returns [ready, progress]
    # ready is true when the course has some search_embedding_version set;
    # progress may be < 100 while ready if upgrading embeddings.
    # NOTE(review): returns -1 (not an Array) when smart search is unavailable
    # for the course; callers destructuring [ready, progress] must handle that.
    def check_course(course)
      return -1 unless smart_search_available?(course)

      if course.search_embedding_version == EMBEDDING_VERSION
        [true, 100]
      else
        # queue the index job (the singleton will ensure it's only queued once)
        delay(singleton: "smart_search_index_course_#{course.global_id}").index_course(course)
        [course.search_embedding_version.present?, indexing_progress(course)]
      end
    end

    # Generates embeddings for every indexable record whose newest embedding
    # version is below EMBEDDING_VERSION (or that has none), then marks the
    # course as current and queues delete_old_embeddings at low priority.
    def index_course(course)
      return if course.search_embedding_version == EMBEDDING_VERSION

      # TODO: investigate pipelining this
      index_scopes(course).each do |scope|
        scope.left_joins(:embeddings)
             .group("#{scope.table_name}.id")
             .having("COALESCE(MAX(#{scope.embedding_class.table_name}.version), 0) < ?", EMBEDDING_VERSION)
             .find_each(strategy: :pluck_ids) do |item|
          item.generate_embeddings(synchronous: true)
        end
      end
      course.update!(search_embedding_version: EMBEDDING_VERSION)
      delay_if_production(priority: Delayed::LOW_PRIORITY).delete_old_embeddings(course)
    end

    # Deletes, in batches, embeddings older than EMBEDDING_VERSION that belong
    # to records in the course's index scopes. Always returns nil.
    def delete_old_embeddings(course)
      index_scopes(course).each do |scope|
        scope.embedding_class
             .where(scope.embedding_foreign_key => scope.except(:order).select(:id))
             .where(version: ...EMBEDDING_VERSION)
             .in_batches
             .delete_all
      end
      nil
    end

    # Percentage (0-100, truncated) of records in the course's index scopes
    # that have an embedding of the current EMBEDDING_VERSION. Returns 100 when
    # the course is already current or there is nothing to index.
    def indexing_progress(course)
      return 100 if course.search_embedding_version == EMBEDDING_VERSION

      total = 0
      indexed = 0
      index_scopes(course).each do |scope|
        n, i = scope.except(:order).left_joins(:embeddings).pick(Arel.sql(<<~SQL.squish))
          COUNT(DISTINCT #{scope.table_name}.id),
          COUNT(DISTINCT CASE WHEN #{scope.embedding_class.table_name}.version = #{EMBEDDING_VERSION} THEN #{scope.table_name}.id END)
        SQL
        total += n
        indexed += i
      end
      # avoid 0/0 => NaN, whose #to_i raises FloatDomainError
      return 100 if total.zero?

      (indexed * 100.0 / total).to_i
    end

    # After a course copy, copies current-version embeddings from the source
    # course's records onto their imported counterparts instead of regenerating
    # them. Current-version embeddings already on the destination records are
    # deleted first. No-op unless this is a course copy from a source at the
    # current EMBEDDING_VERSION and smart search is available for the destination.
    def copy_embeddings(content_migration)
      return unless content_migration.for_course_copy? &&
                    content_migration.source_course&.search_embedding_version == EMBEDDING_VERSION &&
                    SmartSearch.smart_search_available?(content_migration.context)

      content_migration.imported_asset_id_map&.each do |class_name, id_mapping|
        klass = class_name.safe_constantize
        next unless klass.respond_to?(:embedding_class)

        fk = klass.embedding_foreign_key # e.g. :wiki_page_id
        content_migration.context.shard.activate do
          klass.embedding_class
               .where(:version => EMBEDDING_VERSION, fk => id_mapping.values)
               .in_batches
               .delete_all
        end
        content_migration.source_course.shard.activate do
          klass.embedding_class.where(:version => EMBEDDING_VERSION, fk => id_mapping.keys)
               .find_in_batches(batch_size: 50) do |src_embeddings|
            # one timestamp per batch so created_at and updated_at agree
            now = Time.now.utc
            dest_embeddings = src_embeddings.map do |src_embedding|
              {
                :embedding => src_embedding.embedding,
                fk => id_mapping[src_embedding[fk]],
                :version => EMBEDDING_VERSION,
                :root_account_id => content_migration.context.root_account_id,
                :created_at => now,
                :updated_at => now
              }
            end
            content_migration.context.shard.activate do
              klass.embedding_class.insert_all(dest_embeddings)
            end
          end
        end
      end
    end
  end
end