From a9603992bdffa88bc9c38395c98a37b96e618fd9 Mon Sep 17 00:00:00 2001 From: Marat Akhmetov Date: Mon, 11 Aug 2025 01:58:18 +0300 Subject: [PATCH] [DOP-23195] update ts --- .../versions/2025-08-10_0012_update_ts.py | 155 ++++++++++++++++++ syncmaster/db/models/connection.py | 14 +- syncmaster/db/models/group.py | 10 +- syncmaster/db/models/queue.py | 9 +- syncmaster/db/models/transfer.py | 31 ++-- syncmaster/db/repositories/base.py | 14 +- syncmaster/db/repositories/connection.py | 6 +- syncmaster/db/repositories/group.py | 10 +- syncmaster/db/repositories/queue.py | 4 +- syncmaster/db/repositories/search.py | 96 +++++++++++ syncmaster/db/repositories/transfer.py | 6 +- 11 files changed, 320 insertions(+), 35 deletions(-) create mode 100644 syncmaster/db/migrations/versions/2025-08-10_0012_update_ts.py create mode 100644 syncmaster/db/repositories/search.py diff --git a/syncmaster/db/migrations/versions/2025-08-10_0012_update_ts.py b/syncmaster/db/migrations/versions/2025-08-10_0012_update_ts.py new file mode 100644 index 00000000..a08a04c1 --- /dev/null +++ b/syncmaster/db/migrations/versions/2025-08-10_0012_update_ts.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: 2023-2024 MTS PJSC +# SPDX-License-Identifier: Apache-2.0 +"""Update text search + +Revision ID: 0012 +Revises: 0011 +Create Date: 2025-08-10 20:03:02.105470 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
# revision identifiers, used by Alembic.
revision = "0012"
down_revision = "0011"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Replace english-stemmed search vectors with combined russian+simple ones.

    A generated column's expression cannot be altered in place, so every
    ``search_vector`` column (and its GIN index, where one exists) is first
    dropped and then recreated with the new expression.
    """
    # drop old generated columns and their indexes
    op.drop_index(op.f("idx_connection_search_vector"), table_name="connection", postgresql_using="gin")
    op.drop_column("connection", "search_vector")
    op.drop_column("group", "search_vector")
    op.drop_index(op.f("idx_transfer_search_vector"), table_name="transfer", postgresql_using="gin")
    op.drop_column("transfer", "search_vector")
    op.drop_column("queue", "search_vector")

    # connection: stemmed + exact name, plus host both as-is and split on punctuation
    op.add_column(
        "connection",
        sa.Column(
            "search_vector",
            postgresql.TSVECTOR(),
            sa.Computed(
                "\n to_tsvector('russian', coalesce(name, ''))\n || to_tsvector('simple', coalesce(name, '')) \n || to_tsvector('simple', coalesce(data->>'host', ''))\n || to_tsvector(\n 'simple',\n translate(\n coalesce(data->>'host', ''),\n './-_:\\', ' '\n )\n )\n ",
                persisted=True,
            ),
            nullable=False,
        ),
    )
    op.create_index(
        "idx_connection_search_vector",
        "connection",
        ["search_vector"],
        unique=False,
        postgresql_using="gin",
    )

    # group: stemmed + exact name only
    op.add_column(
        "group",
        sa.Column(
            "search_vector",
            postgresql.TSVECTOR(),
            sa.Computed(
                "\n to_tsvector('russian', coalesce(name, ''))\n || to_tsvector('simple', coalesce(name, '')) \n ",
                persisted=True,
            ),
            nullable=False,
        ),
    )

    # transfer: name plus table names and directory paths of both endpoints,
    # each also split on path/host punctuation for partial matching
    op.add_column(
        "transfer",
        sa.Column(
            "search_vector",
            postgresql.TSVECTOR(),
            sa.Computed(
                "\n to_tsvector('russian', coalesce(name, ''))\n\n || to_tsvector('simple', coalesce(name, ''))\n || to_tsvector('simple', coalesce(source_params->>'table_name', ''))\n || to_tsvector('simple', coalesce(target_params->>'table_name', ''))\n || to_tsvector('simple', coalesce(source_params->>'directory_path', ''))\n || to_tsvector('simple', coalesce(target_params->>'directory_path', ''))\n\n || to_tsvector('simple',\n translate(coalesce(source_params->>'table_name', ''), './-_:\\', ' ')\n )\n || to_tsvector('simple',\n translate(coalesce(target_params->>'table_name', ''), './-_:\\', ' ')\n )\n || to_tsvector('simple',\n translate(coalesce(source_params->>'directory_path', ''), './-_:\\', ' ')\n )\n || to_tsvector('simple',\n translate(coalesce(target_params->>'directory_path', ''), './-_:\\', ' ')\n )\n ",
                persisted=True,
            ),
            nullable=False,
        ),
    )
    op.create_index("idx_transfer_search_vector", "transfer", ["search_vector"], unique=False, postgresql_using="gin")

    # queue: stemmed + exact name only
    op.add_column(
        "queue",
        sa.Column(
            "search_vector",
            postgresql.TSVECTOR(),
            sa.Computed(
                "\n to_tsvector('russian', coalesce(name, ''))\n || to_tsvector('simple', coalesce(name, ''))\n ",
                persisted=True,
            ),
            nullable=False,
        ),
    )


def downgrade() -> None:
    """Restore the previous english-stemmed generated columns and indexes."""
    op.drop_index("idx_transfer_search_vector", table_name="transfer", postgresql_using="gin")
    op.drop_column("transfer", "search_vector")
    op.drop_column("group", "search_vector")
    op.drop_index("idx_connection_search_vector", table_name="connection", postgresql_using="gin")
    op.drop_column("connection", "search_vector")
    op.drop_column("queue", "search_vector")

    op.add_column(
        "transfer",
        sa.Column(
            "search_vector",
            postgresql.TSVECTOR(),
            sa.Computed(
                "to_tsvector('english'::regconfig, (((((((((((((((((((name)::text || ' '::text) || COALESCE(json_extract_path_text(source_params, VARIADIC ARRAY['table_name'::text]), ''::text)) || ' '::text) || COALESCE(json_extract_path_text(target_params, VARIADIC ARRAY['table_name'::text]), ''::text)) || ' '::text) || COALESCE(json_extract_path_text(source_params, VARIADIC ARRAY['directory_path'::text]), ''::text)) || ' '::text) || COALESCE(json_extract_path_text(target_params, VARIADIC ARRAY['directory_path'::text]), ''::text)) || ' '::text) || translate((name)::text, './'::text, ' '::text)) || ' '::text) || COALESCE(translate(json_extract_path_text(source_params, VARIADIC ARRAY['table_name'::text]), './'::text, ' '::text), ''::text)) || ' '::text) || COALESCE(translate(json_extract_path_text(target_params, VARIADIC ARRAY['table_name'::text]), './'::text, ' '::text), ''::text)) || ' '::text) || COALESCE(translate(json_extract_path_text(source_params, VARIADIC ARRAY['directory_path'::text]), './'::text, ' '::text), ''::text)) || ' '::text) || COALESCE(translate(json_extract_path_text(target_params, VARIADIC ARRAY['directory_path'::text]), './'::text, ' '::text), ''::text)))",
                persisted=True,
            ),
            autoincrement=False,
            nullable=False,
        ),
    )
    op.create_index(
        op.f("idx_transfer_search_vector"),
        "transfer",
        ["search_vector"],
        unique=False,
        postgresql_using="gin",
    )
    op.add_column(
        "group",
        sa.Column(
            "search_vector",
            postgresql.TSVECTOR(),
            sa.Computed("to_tsvector('english'::regconfig, (name)::text)", persisted=True),
            autoincrement=False,
            nullable=False,
        ),
    )
    op.add_column(
        "connection",
        sa.Column(
            "search_vector",
            postgresql.TSVECTOR(),
            sa.Computed(
                "to_tsvector('english'::regconfig, (((((name)::text || ' '::text) || COALESCE(json_extract_path_text(data, VARIADIC ARRAY['host'::text]), ''::text)) || ' '::text) || COALESCE(translate(json_extract_path_text(data, VARIADIC ARRAY['host'::text]), '.'::text, ' '::text), ''::text)))",
                persisted=True,
            ),
            autoincrement=False,
            nullable=False,
        ),
    )
    op.create_index(
        op.f("idx_connection_search_vector"),
        "connection",
        ["search_vector"],
        unique=False,
        postgresql_using="gin",
    )
    op.add_column(
        "queue",
        sa.Column(
            "search_vector",
            postgresql.TSVECTOR(),
            sa.Computed("to_tsvector('english'::regconfig, (name)::text)", persisted=True),
            autoincrement=False,
            nullable=False,
        ),
    )
COALESCE(translate(json_extract_path_text(data, 'host'), '.', ' '), '') + to_tsvector('russian', coalesce(name, '')) + || to_tsvector('simple', coalesce(name, '')) + || to_tsvector('simple', coalesce(data->>'host', '')) + || to_tsvector( + 'simple', + translate( + coalesce(data->>'host', ''), + './-_:\\', ' ' + ) ) """, persisted=True, diff --git a/syncmaster/db/models/group.py b/syncmaster/db/models/group.py index 0effb47d..d6c0339e 100644 --- a/syncmaster/db/models/group.py +++ b/syncmaster/db/models/group.py @@ -77,13 +77,17 @@ class Group(Base, TimestampMixin): owner: Mapped[User] = relationship(User) queue: Mapped[Queue] = relationship(back_populates="group", cascade="all, delete-orphan") - search_vector: Mapped[str] = mapped_column( TSVECTOR, - Computed("to_tsvector('english'::regconfig, name)", persisted=True), + Computed( + """ + to_tsvector('russian', coalesce(name, '')) + || to_tsvector('simple', coalesce(name, '')) + """, + persisted=True, + ), nullable=False, deferred=True, - doc="Full-text search vector", ) def __repr__(self) -> str: diff --git a/syncmaster/db/models/queue.py b/syncmaster/db/models/queue.py index 00104fc4..c21ab9f0 100644 --- a/syncmaster/db/models/queue.py +++ b/syncmaster/db/models/queue.py @@ -25,10 +25,15 @@ class Queue(Base, ResourceMixin, TimestampMixin): search_vector: Mapped[str] = mapped_column( TSVECTOR, - Computed("to_tsvector('english'::regconfig, name)", persisted=True), + Computed( + """ + to_tsvector('russian', coalesce(name, '')) + || to_tsvector('simple', coalesce(name, '')) + """, + persisted=True, + ), nullable=False, deferred=True, - doc="Full-text search vector", ) def __repr__(self): diff --git a/syncmaster/db/models/transfer.py b/syncmaster/db/models/transfer.py index f68c253b..b8849c24 100644 --- a/syncmaster/db/models/transfer.py +++ b/syncmaster/db/models/transfer.py @@ -65,18 +65,25 @@ class Transfer( TSVECTOR, Computed( """ - to_tsvector( - 'english'::regconfig, - name || ' ' || - 
COALESCE(json_extract_path_text(source_params, 'table_name'), '') || ' ' || - COALESCE(json_extract_path_text(target_params, 'table_name'), '') || ' ' || - COALESCE(json_extract_path_text(source_params, 'directory_path'), '') || ' ' || - COALESCE(json_extract_path_text(target_params, 'directory_path'), '') || ' ' || - translate(name, './', ' ') || ' ' || - COALESCE(translate(json_extract_path_text(source_params, 'table_name'), './', ' '), '') || ' ' || - COALESCE(translate(json_extract_path_text(target_params, 'table_name'), './', ' '), '') || ' ' || - COALESCE(translate(json_extract_path_text(source_params, 'directory_path'), './', ' '), '') || ' ' || - COALESCE(translate(json_extract_path_text(target_params, 'directory_path'), './', ' '), '') + to_tsvector('russian', coalesce(name, '')) + + || to_tsvector('simple', coalesce(name, '')) + || to_tsvector('simple', coalesce(source_params->>'table_name', '')) + || to_tsvector('simple', coalesce(target_params->>'table_name', '')) + || to_tsvector('simple', coalesce(source_params->>'directory_path', '')) + || to_tsvector('simple', coalesce(target_params->>'directory_path', '')) + + || to_tsvector('simple', + translate(coalesce(source_params->>'table_name', ''), './-_:\\', ' ') + ) + || to_tsvector('simple', + translate(coalesce(target_params->>'table_name', ''), './-_:\\', ' ') + ) + || to_tsvector('simple', + translate(coalesce(source_params->>'directory_path', ''), './-_:\\', ' ') + ) + || to_tsvector('simple', + translate(coalesce(target_params->>'directory_path', ''), './-_:\\', ' ') ) """, persisted=True, diff --git a/syncmaster/db/repositories/base.py b/syncmaster/db/repositories/base.py index 43494ef3..f60ece14 100644 --- a/syncmaster/db/repositories/base.py +++ b/syncmaster/db/repositories/base.py @@ -3,7 +3,16 @@ from abc import ABC from typing import Any, Generic, TypeVar -from sqlalchemy import ScalarResult, Select, delete, func, insert, select, update +from sqlalchemy import ( + ColumnElement, + 
ScalarResult, + Select, + delete, + func, + insert, + select, + update, +) from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncSession @@ -98,8 +107,7 @@ async def _paginate_scalar_result(self, query: Select, page: int, page_size: int page_size=page_size, ) - def _construct_vector_search(self, query: Select, search_query: str) -> Select: - ts_query = func.plainto_tsquery("english", search_query) + def _construct_vector_search(self, query: Select, ts_query: ColumnElement) -> Select: query = ( query.where(self._model.search_vector.op("@@")(ts_query)) .add_columns(func.ts_rank(self._model.search_vector, ts_query).label("rank")) diff --git a/syncmaster/db/repositories/connection.py b/syncmaster/db/repositories/connection.py index 8887a557..e03baf21 100644 --- a/syncmaster/db/repositories/connection.py +++ b/syncmaster/db/repositories/connection.py @@ -8,6 +8,7 @@ from syncmaster.db.models import Connection from syncmaster.db.repositories.repository_with_owner import RepositoryWithOwner +from syncmaster.db.repositories.search import make_tsquery from syncmaster.db.utils import Pagination from syncmaster.exceptions import EntityNotFoundError, SyncmasterError from syncmaster.exceptions.connection import ( @@ -35,9 +36,8 @@ async def paginate( Connection.group_id == group_id, ) if search_query: - processed_query = search_query.replace(".", " ") - combined_query = f"{search_query} {processed_query}" - stmt = self._construct_vector_search(stmt, combined_query) + ts_query = make_tsquery(search_query) + stmt = self._construct_vector_search(stmt, ts_query) if connection_type is not None: stmt = stmt.where(Connection.type.in_(connection_type)) diff --git a/syncmaster/db/repositories/group.py b/syncmaster/db/repositories/group.py index 7a23fa1e..3dd4aa05 100644 --- a/syncmaster/db/repositories/group.py +++ b/syncmaster/db/repositories/group.py @@ -10,6 +10,7 @@ from syncmaster.db.models import Group, GroupMemberRole, User, UserGroup from 
syncmaster.db.repositories.base import Repository +from syncmaster.db.repositories.search import make_tsquery from syncmaster.db.utils import Pagination, Permission from syncmaster.exceptions import EntityNotFoundError, SyncmasterError from syncmaster.exceptions.group import ( @@ -33,7 +34,8 @@ async def paginate_all( ) -> Pagination: stmt = select(Group) if search_query: - stmt = self._construct_vector_search(stmt, search_query) + ts_query = make_tsquery(search_query) + stmt = self._construct_vector_search(stmt, ts_query) paginated_result = await self._paginate_scalar_result( query=stmt.order_by(Group.name), @@ -78,7 +80,8 @@ async def paginate_for_user( # apply search filtering if a search query is provided if search_query: - owned_groups_stmt = self._construct_vector_search(owned_groups_stmt, search_query) + ts_query = make_tsquery(search_query) + owned_groups_stmt = self._construct_vector_search(owned_groups_stmt, ts_query) # get total count of owned groups total_owned_groups = ( @@ -114,7 +117,8 @@ async def paginate_for_user( # apply search filtering if a search query is provided if search_query: - user_groups_stmt = self._construct_vector_search(user_groups_stmt, search_query) + ts_query = make_tsquery(search_query) + user_groups_stmt = self._construct_vector_search(user_groups_stmt, ts_query) # get total count of user groups total_user_groups = ( diff --git a/syncmaster/db/repositories/queue.py b/syncmaster/db/repositories/queue.py index e2fe56f5..ed745e95 100644 --- a/syncmaster/db/repositories/queue.py +++ b/syncmaster/db/repositories/queue.py @@ -9,6 +9,7 @@ from syncmaster.db.models import Group, GroupMemberRole, Queue, User, UserGroup from syncmaster.db.repositories.repository_with_owner import RepositoryWithOwner +from syncmaster.db.repositories.search import make_tsquery from syncmaster.db.utils import Permission from syncmaster.exceptions import EntityNotFoundError, SyncmasterError from syncmaster.exceptions.group import GroupNotFoundError @@ -59,7 
# SPDX-FileCopyrightText: 2024-2025 MTS PJSC
# SPDX-License-Identifier: Apache-2.0
"""Helpers for building PostgreSQL full-text-search queries from user input."""
from enum import IntFlag
from string import punctuation

from sqlalchemy import ColumnElement, func
from sqlalchemy.orm import InstrumentedAttribute

# left some punctuation to match file paths, URLs and host names
TSQUERY_UNSUPPORTED_CHARS = "".join(sorted(set(punctuation) - {"/", ".", "_", "-"}))
TSQUERY_UNSUPPORTED_CHARS_REPLACEMENT = str.maketrans(TSQUERY_UNSUPPORTED_CHARS, " " * len(TSQUERY_UNSUPPORTED_CHARS))
TSQUERY_ALL_PUNCTUATION_REPLACEMENT = str.maketrans(punctuation, " " * len(punctuation))


class SearchRankNormalization(IntFlag):
    """See https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-RANKING"""

    IGNORE_LENGTH = 0
    DOCUMENT_LENGTH_LOGARITHM = 1
    DOCUMENT_LENGTH = 2
    HARMONIC_DISTANCE = 4
    UNIQUE_WORDS = 8
    UNIQUE_WORDS_LOGARITHM = 16
    RANK_PLUS_ONE = 32


def ts_rank(search_vector: InstrumentedAttribute, ts_query: ColumnElement) -> ColumnElement:
    """Get ts_rank for search query ranking.

    Places results with smaller number of total words (like table name) to the top,
    and long results (as file paths) to the bottom.

    Also places on top results with lexemes order matching the tsvector order.
    """
    return func.ts_rank_cd(search_vector, ts_query, SearchRankNormalization.UNIQUE_WORDS)


def make_tsquery(user_input: str) -> ColumnElement:
    """Convert user input to tsquery.

    - wraps tokens with `:*` for prefix matching,
    - combines unstemmed 'simple' query with stemmed 'russian' via OR.
    """
    simple_query = func.to_tsquery("simple", build_tsquery(user_input))

    stemmed_query = func.plainto_tsquery("russian", user_input)

    combined_query = simple_query.op("||")(stemmed_query)

    return combined_query


def ts_match(search_vector: InstrumentedAttribute, ts_query: ColumnElement) -> ColumnElement:
    """Build an expression to get only search_vector matching ts_query."""
    return search_vector.op("@@")(ts_query)


def build_tsquery(user_input: str) -> str:
    """Build a raw tsquery string: exact words OR prefix words OR fully de-punctuated words."""
    original_words = words_with_supported_punctuation(user_input)
    only_words = words_without_any_punctuation(user_input)

    return combine_queries(
        combine_words(*original_words, by_prefix=False),
        combine_words(*original_words),
        # the third sub-query is redundant when stripping punctuation changed nothing;
        # "" is dropped by combine_queries (was a type-incorrect `[]` relying on falsiness)
        combine_words(*only_words) if only_words != original_words else "",
    )


def combine_words(*words: str, by_prefix: bool = True) -> str:
    """AND the words into one tsquery conjunction, optionally as prefix matches."""
    # Convert this ['some', 'query']
    # to this `'some' & 'query'` or `'some':* & 'query':*`
    modifier = ":*" if by_prefix else ""
    return " & ".join(f"'{word}'{modifier}" for word in words if word)


def combine_queries(*queries: str) -> str:
    """OR the non-empty sub-queries, parenthesizing each one."""
    # Convert this ['/some/file/path:*', 'some:* & file:* & path:*']
    # to this '(/some/file/path:*) | (some:* & file:* & path:*)'
    return " | ".join(f"({query})" for query in queries if query)


def words_with_supported_punctuation(query: str) -> list[str]:
    """Split the query into tokens, keeping path/host punctuation (/ . _ -)."""
    # convert '@/some/path.or.domain!' -> ['/some/path.or.domain']
    converted = query.translate(TSQUERY_UNSUPPORTED_CHARS_REPLACEMENT)
    # drop tokens consisting purely of punctuation chars, like './'
    return [part for part in converted.split() if not all(char in punctuation for char in part)]


def words_without_any_punctuation(query: str) -> list[str]:
    """Split the query into bare words with all punctuation stripped."""
    # convert '@/some/path.or.domain!' -> ['some', 'path', 'or', 'domain']
    return query.translate(TSQUERY_ALL_PUNCTUATION_REPLACEMENT).split()