11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- import enum
14
+
15
15
import logging
16
16
import re
17
- from collections import deque
18
- from dataclasses import dataclass
19
- from typing import (
20
- TYPE_CHECKING ,
21
- Any ,
22
- Collection ,
23
- Iterable ,
24
- List ,
25
- Optional ,
26
- Set ,
27
- Tuple ,
28
- Union ,
29
- )
17
+ from typing import TYPE_CHECKING , Any , Collection , Iterable , List , Optional , Set , Tuple
30
18
31
19
import attr
32
20
39
27
LoggingTransaction ,
40
28
)
41
29
from synapse .storage .databases .main .events_worker import EventRedactBehaviour
42
- from synapse .storage .engines import PostgresEngine , Sqlite3Engine
30
+ from synapse .storage .engines import BaseDatabaseEngine , PostgresEngine , Sqlite3Engine
43
31
from synapse .types import JsonDict
44
32
45
33
if TYPE_CHECKING :
@@ -433,6 +421,8 @@ async def search_msgs(
433
421
"""
434
422
clauses = []
435
423
424
+ search_query = _parse_query (self .database_engine , search_term )
425
+
436
426
args : List [Any ] = []
437
427
438
428
# Make sure we don't explode because the person is in too many rooms.
@@ -454,24 +444,20 @@ async def search_msgs(
454
444
count_clauses = clauses
455
445
456
446
if isinstance (self .database_engine , PostgresEngine ):
457
- search_query = search_term
458
- tsquery_func = self .database_engine .tsquery_func
459
447
sql = (
460
- f "SELECT ts_rank_cd(vector, { tsquery_func } ('english', ?)) AS rank,"
448
+ "SELECT ts_rank_cd(vector, to_tsquery ('english', ?)) AS rank,"
461
449
" room_id, event_id"
462
450
" FROM event_search"
463
- f " WHERE vector @@ { tsquery_func } ('english', ?)"
451
+ " WHERE vector @@ to_tsquery ('english', ?)"
464
452
)
465
453
args = [search_query , search_query ] + args
466
454
467
455
count_sql = (
468
456
"SELECT room_id, count(*) as count FROM event_search"
469
- f " WHERE vector @@ { tsquery_func } ('english', ?)"
457
+ " WHERE vector @@ to_tsquery ('english', ?)"
470
458
)
471
459
count_args = [search_query ] + count_args
472
460
elif isinstance (self .database_engine , Sqlite3Engine ):
473
- search_query = _parse_query_for_sqlite (search_term )
474
-
475
461
sql = (
476
462
"SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
477
463
" FROM event_search"
@@ -483,7 +469,7 @@ async def search_msgs(
483
469
"SELECT room_id, count(*) as count FROM event_search"
484
470
" WHERE value MATCH ?"
485
471
)
486
- count_args = [search_query ] + count_args
472
+ count_args = [search_term ] + count_args
487
473
else :
488
474
# This should be unreachable.
489
475
raise Exception ("Unrecognized database engine" )
@@ -515,9 +501,7 @@ async def search_msgs(
515
501
516
502
highlights = None
517
503
if isinstance (self .database_engine , PostgresEngine ):
518
- highlights = await self ._find_highlights_in_postgres (
519
- search_query , events , tsquery_func
520
- )
504
+ highlights = await self ._find_highlights_in_postgres (search_query , events )
521
505
522
506
count_sql += " GROUP BY room_id"
523
507
@@ -526,6 +510,7 @@ async def search_msgs(
526
510
)
527
511
528
512
count = sum (row ["count" ] for row in count_results if row ["room_id" ] in room_ids )
513
+
529
514
return {
530
515
"results" : [
531
516
{"event" : event_map [r ["event_id" ]], "rank" : r ["rank" ]}
@@ -557,6 +542,9 @@ async def search_rooms(
557
542
Each match as a dictionary.
558
543
"""
559
544
clauses = []
545
+
546
+ search_query = _parse_query (self .database_engine , search_term )
547
+
560
548
args : List [Any ] = []
561
549
562
550
# Make sure we don't explode because the person is in too many rooms.
@@ -594,23 +582,20 @@ async def search_rooms(
594
582
args .extend ([origin_server_ts , origin_server_ts , stream ])
595
583
596
584
if isinstance (self .database_engine , PostgresEngine ):
597
- search_query = search_term
598
- tsquery_func = self .database_engine .tsquery_func
599
585
sql = (
600
- f "SELECT ts_rank_cd(vector, { tsquery_func } ('english', ?)) as rank,"
586
+ "SELECT ts_rank_cd(vector, to_tsquery ('english', ?)) as rank,"
601
587
" origin_server_ts, stream_ordering, room_id, event_id"
602
588
" FROM event_search"
603
- f " WHERE vector @@ { tsquery_func } ('english', ?) AND "
589
+ " WHERE vector @@ to_tsquery ('english', ?) AND "
604
590
)
605
591
args = [search_query , search_query ] + args
606
592
607
593
count_sql = (
608
594
"SELECT room_id, count(*) as count FROM event_search"
609
- f " WHERE vector @@ { tsquery_func } ('english', ?) AND "
595
+ " WHERE vector @@ to_tsquery ('english', ?) AND "
610
596
)
611
597
count_args = [search_query ] + count_args
612
598
elif isinstance (self .database_engine , Sqlite3Engine ):
613
-
614
599
# We use CROSS JOIN here to ensure we use the right indexes.
615
600
# https://sqlite.org/optoverview.html#crossjoin
616
601
#
@@ -629,14 +614,13 @@ async def search_rooms(
629
614
" CROSS JOIN events USING (event_id)"
630
615
" WHERE "
631
616
)
632
- search_query = _parse_query_for_sqlite (search_term )
633
617
args = [search_query ] + args
634
618
635
619
count_sql = (
636
620
"SELECT room_id, count(*) as count FROM event_search"
637
621
" WHERE value MATCH ? AND "
638
622
)
639
- count_args = [search_query ] + count_args
623
+ count_args = [search_term ] + count_args
640
624
else :
641
625
# This should be unreachable.
642
626
raise Exception ("Unrecognized database engine" )
@@ -676,9 +660,7 @@ async def search_rooms(
676
660
677
661
highlights = None
678
662
if isinstance (self .database_engine , PostgresEngine ):
679
- highlights = await self ._find_highlights_in_postgres (
680
- search_query , events , tsquery_func
681
- )
663
+ highlights = await self ._find_highlights_in_postgres (search_query , events )
682
664
683
665
count_sql += " GROUP BY room_id"
684
666
@@ -704,7 +686,7 @@ async def search_rooms(
704
686
}
705
687
706
688
async def _find_highlights_in_postgres (
707
- self , search_query : str , events : List [EventBase ], tsquery_func : str
689
+ self , search_query : str , events : List [EventBase ]
708
690
) -> Set [str ]:
709
691
"""Given a list of events and a search term, return a list of words
710
692
that match from the content of the event.
@@ -715,7 +697,6 @@ async def _find_highlights_in_postgres(
715
697
Args:
716
698
search_query
717
699
events: A list of events
718
- tsquery_func: The tsquery_* function to use when making queries
719
700
720
701
Returns:
721
702
A set of strings.
@@ -748,7 +729,7 @@ def f(txn: LoggingTransaction) -> Set[str]:
748
729
while stop_sel in value :
749
730
stop_sel += ">"
750
731
751
- query = f "SELECT ts_headline(?, { tsquery_func } ('english', ?), %s)" % (
732
+ query = "SELECT ts_headline(?, to_tsquery ('english', ?), %s)" % (
752
733
_to_postgres_options (
753
734
{
754
735
"StartSel" : start_sel ,
@@ -779,128 +760,20 @@ def _to_postgres_options(options_dict: JsonDict) -> str:
779
760
return "'%s'" % ("," .join ("%s=%s" % (k , v ) for k , v in options_dict .items ()),)
780
761
781
762
782
- @dataclass
783
- class Phrase :
784
- phrase : List [str ]
785
-
786
-
787
- class SearchToken (enum .Enum ):
788
- Not = enum .auto ()
789
- Or = enum .auto ()
790
- And = enum .auto ()
791
-
792
-
793
- Token = Union [str , Phrase , SearchToken ]
794
- TokenList = List [Token ]
795
-
796
-
797
- def _is_stop_word (word : str ) -> bool :
798
- # TODO Pull these out of the dictionary:
799
- # https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop
800
- return word in {"the" , "a" , "you" , "me" , "and" , "but" }
801
-
802
-
803
- def _tokenize_query (query : str ) -> TokenList :
804
- """
805
- Convert the user-supplied `query` into a TokenList, which can be translated into
806
- some DB-specific syntax.
807
-
808
- The following constructs are supported:
809
-
810
- - phrase queries using "double quotes"
811
- - case-insensitive `or` and `and` operators
812
- - negation of a keyword via unary `-`
813
- - unary hyphen to denote NOT e.g. 'include -exclude'
814
-
815
- The following differs from websearch_to_tsquery:
816
-
817
- - Stop words are not removed.
818
- - Unclosed phrases are treated differently.
819
-
820
- """
821
- tokens : TokenList = []
822
-
823
- # Find phrases.
824
- in_phrase = False
825
- parts = deque (query .split ('"' ))
826
- for i , part in enumerate (parts ):
827
- # The contents inside double quotes is treated as a phrase, a trailing
828
- # double quote is not implied.
829
- in_phrase = bool (i % 2 ) and i != (len (parts ) - 1 )
830
-
831
- # Pull out the individual words, discarding any non-word characters.
832
- words = deque (re .findall (r"([\w\-]+)" , part , re .UNICODE ))
833
-
834
- # Phrases have simplified handling of words.
835
- if in_phrase :
836
- # Skip stop words.
837
- phrase = [word for word in words if not _is_stop_word (word )]
838
-
839
- # Consecutive words are implicitly ANDed together.
840
- if tokens and tokens [- 1 ] not in (SearchToken .Not , SearchToken .Or ):
841
- tokens .append (SearchToken .And )
842
-
843
- # Add the phrase.
844
- tokens .append (Phrase (phrase ))
845
- continue
846
-
847
- # Otherwise, not in a phrase.
848
- while words :
849
- word = words .popleft ()
850
-
851
- if word .startswith ("-" ):
852
- tokens .append (SearchToken .Not )
853
-
854
- # If there's more word, put it back to be processed again.
855
- word = word [1 :]
856
- if word :
857
- words .appendleft (word )
858
- elif word .lower () == "or" :
859
- tokens .append (SearchToken .Or )
860
- else :
861
- # Skip stop words.
862
- if _is_stop_word (word ):
863
- continue
864
-
865
- # Consecutive words are implicitly ANDed together.
866
- if tokens and tokens [- 1 ] not in (SearchToken .Not , SearchToken .Or ):
867
- tokens .append (SearchToken .And )
868
-
869
- # Add the search term.
870
- tokens .append (word )
871
-
872
- return tokens
873
-
874
-
875
- def _tokens_to_sqlite_match_query (tokens : TokenList ) -> str :
876
- """
877
- Convert the list of tokens to a string suitable for passing to sqlite's MATCH.
878
- Assume sqlite was compiled with enhanced query syntax.
879
-
880
- Ref: https://www.sqlite.org/fts3.html#full_text_index_queries
763
+ def _parse_query (database_engine : BaseDatabaseEngine , search_term : str ) -> str :
764
+ """Takes a plain unicode string from the user and converts it into a form
765
+ that can be passed to the database.
766
+ We use this so that we can add prefix matching, which isn't something
767
+ that is supported by default.
881
768
"""
882
- match_query = []
883
- for token in tokens :
884
- if isinstance (token , str ):
885
- match_query .append (token )
886
- elif isinstance (token , Phrase ):
887
- match_query .append ('"' + " " .join (token .phrase ) + '"' )
888
- elif token == SearchToken .Not :
889
- # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search
890
- # term has already been added before this.
891
- match_query .append (" NOT " )
892
- elif token == SearchToken .Or :
893
- match_query .append (" OR " )
894
- elif token == SearchToken .And :
895
- match_query .append (" AND " )
896
- else :
897
- raise ValueError (f"unknown token { token } " )
898
-
899
- return "" .join (match_query )
900
769
770
+ # Pull out the individual words, discarding any non-word characters.
771
+ results = re .findall (r"([\w\-]+)" , search_term , re .UNICODE )
901
772
902
- def _parse_query_for_sqlite (search_term : str ) -> str :
903
- """Takes a plain unicode string from the user and converts it into a form
904
- that can be passed to sqllite's matchinfo().
905
- """
906
- return _tokens_to_sqlite_match_query (_tokenize_query (search_term ))
773
+ if isinstance (database_engine , PostgresEngine ):
774
+ return " & " .join (result + ":*" for result in results )
775
+ elif isinstance (database_engine , Sqlite3Engine ):
776
+ return " & " .join (result + "*" for result in results )
777
+ else :
778
+ # This should be unreachable.
779
+ raise Exception ("Unrecognized database engine" )
0 commit comments