Skip to content

Commit fecbcf2

Browse files
committed
Implement search caching for faster results
1 parent 90583c6 commit fecbcf2

1 file changed

Lines changed: 123 additions & 20 deletions

File tree

activity_browser/bwutils/searchengine/metadata_search.py

Lines changed: 123 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,22 @@
1212

1313

1414
class MetaDataSearchEngine(SearchEngine):
15+
16+
# caching for faster operation
1517
def database_id_manager(self, database):
    """Return (and cache) the set of identifiers belonging to *database*.

    Per-database id sets are cached in ``self.all_database_ids``; the active
    database is recorded in ``self.current_database`` (a sentinel string when
    no database is selected). Returns ``None`` when *database* is ``None``.
    """
    if not hasattr(self, "all_database_ids"):
        self.all_database_ids = {}

    # Explicit None check: a cached *empty* set is falsy, so the original
    # walrus-truthiness test recomputed (and re-stored) it on every call.
    database_ids = self.all_database_ids.get(database)
    if database_ids is not None:
        self.database_ids = database_ids
        self.current_database = database
    elif database is not None:
        self.database_ids = set(self.df[self.df["database"] == database].index.to_list())
        self.all_database_ids[database] = self.database_ids
        self.current_database = database
    else:
        self.database_ids = None
        self.current_database = "_@@NO_DB_"
    return self.database_ids
2732

2833
def reset_database_id_manager(self):
@@ -31,17 +36,62 @@ def reset_database_id_manager(self):
3136
if hasattr(self, "database_ids"):
3237
del self.database_ids
3338

34-
def add_identifier(self, data: pd.DataFrame) -> None:
35-
super().add_identifier(data)
39+
def database_word_manager(self, database):
    """Return (and cache) the word -> identifiers mapping for *database*.

    Built lazily from ``self.identifier_to_word`` restricted to the ids of
    *database*; cached in ``self.all_database_words``. Returns ``None`` when
    *database* is ``None``.
    """
    if not hasattr(self, "all_database_words"):
        self.all_database_words = {}

    # Explicit None check: a cached *empty* mapping is falsy and would
    # otherwise be rebuilt (and re-stored) on every call.
    database_words = self.all_database_words.get(database)
    if database_words is not None:
        self.database_words = database_words
    elif database is not None:
        ids = self.database_id_manager(database)
        self.database_words = self.reverse_dict_many_to_one(
            {_id: self.identifier_to_word[_id] for _id in ids}
        )
        self.all_database_words[database] = self.database_words
    else:
        self.database_words = None
    return self.database_words
52+
53+
def reset_database_word_manager(self, database):
    """Drop cached word data for *database* (no-op when nothing is cached)."""
    # Membership test instead of truthiness so an *empty* cached mapping is
    # still evicted (the original `.get(database)` truthiness check missed it).
    if hasattr(self, "all_database_words") and database in self.all_database_words:
        del self.all_database_words[database]
    if hasattr(self, "database_words"):
        del self.database_words
58+
59+
def database_search_cache(self, database, query, result=None):
    """Combined store/lookup for the per-database search-result cache.

    With *result* given: store it under (database, query) and return ``None``.
    Without: return the cached result for (database, query), or ``None`` on a
    cache miss.
    """
    if not hasattr(self, "search_cache"):
        self.search_cache = {}

    # Explicit None check: `if result:` silently turned a falsy (e.g. empty
    # Counter) store request into a lookup, so empty results were never cached.
    if result is not None:
        # setdefault creates the per-database sub-dict on first use.
        self.search_cache.setdefault(database, {})[query] = result
        return
    db_cache = self.search_cache.get(database)
    if db_cache is not None:
        return db_cache.get(query)
    return
73+
74+
def reset_search_cache(self, database):
    """Evict all cached search results for *database* (no-op when absent)."""
    # Membership test instead of truthiness: an empty per-database cache dict
    # is falsy and was never evicted by the original `.get(database)` check.
    if hasattr(self, "search_cache") and database in self.search_cache:
        del self.search_cache[database]
77+
78+
def reset_all_caches(self, databases):
    """Invalidate the id, word and search-result caches for *databases*."""
    self.reset_database_id_manager()
    for db in databases:
        # word and search caches are keyed per database
        self.reset_database_word_manager(db)
        self.reset_search_cache(db)
3783

84+
def add_identifier(self, data: pd.DataFrame) -> None:
    """Index *data*, then drop every cache for the databases it touches."""
    super().add_identifier(data)
    touched_databases = data["database"].unique()
    self.reset_all_caches(touched_databases)
3887

3988
def remove_identifiers(self, identifiers, logging=True) -> None:
4089
t = time()
4190

4291
identifiers = set(identifiers)
4392
current_identifiers = set(self.df.index.to_list())
4493
identifiers = identifiers | current_identifiers # only remove identifiers currently in the data
94+
databases = self.df.loc[identifiers, ["databases"]].unique() # extract databases for cache cleaning
4595
if len(identifiers) == 0:
4696
return
4797

@@ -51,11 +101,11 @@ def remove_identifiers(self, identifiers, logging=True) -> None:
51101
if logging:
52102
log.debug(f"Search index updated in {time() - t:.2f} seconds "
53103
f"for {len(identifiers)} removed items ({len(self.df)} items ({self.size_of_index()}) currently).")
54-
self.reset_database_id_manager()
104+
self.reset_all_caches(databases)
55105

56106
def change_identifier(self, identifier, data: pd.DataFrame) -> None:
    """Apply *data* to *identifier*, then invalidate the affected caches."""
    super().change_identifier(identifier, data)
    affected_databases = data["database"].unique()
    self.reset_all_caches(affected_databases)
59109

60110
def auto_complete(self, word: str, context: Optional[set] = set(), database: Optional[str] = None) -> list:
61111
"""Based on spellchecker, make more useful for autocompletions
@@ -188,6 +238,53 @@ def find_q_gram_matches(self, q_grams: set, return_all: bool = False) -> pd.Data
188238

189239
return matches.iloc[:min(len(matches), 2500), :] # return at most this many results
190240

241+
def search_size_1(self, queries: list, original_words: set, orig_word_weight=5, exact_word_weight=1) -> dict:
    """Return a dict of {query_word: Counter(identifier)}.

    queries: a list of len-1 tuples/lists, each holding a searched word or a
        'spell checked' similar word.
    original_words: the words actually searched for (not including
        spell-checked alternatives).
    orig_word_weight: additional weight to add to original words.
    exact_word_weight: additional weight to add to exact word matches (as
        opposed to substring matches).

    First, all matching indexed words are collected per query word; then the
    matched words are converted to weighted identifier Counters, consulting
    the per-database search cache where possible.
    """
    # hoist the single word out of each len-1 query once, not per indexed word
    query_words = [query[0] for query in queries]

    # map each query word to every indexed word containing it
    matches = {}
    for word in self.database_words:
        for query_word in query_words:
            if query_word in word:
                matches.setdefault(query_word, []).append(word)

    # now convert matched words to matched identifiers
    matched_identifiers = {}
    for word, matching_words in matches.items():
        if result := self.database_search_cache(self.current_database, word):
            matched_identifiers[word] = result
            continue
        # `word` keys are unique in `matches`, so a fresh Counter suffices
        id_counter = Counter()
        for matched_word in matching_words:
            weight = self.base_weight

            # the word is weighted n times; the original search word is
            # weighted higher than spell-checked alternatives
            if matched_word in original_words:
                weight += orig_word_weight  # increase weight for original word
            if matched_word == word:
                weight += exact_word_weight  # increase weight for exact matching word

            id_counter = self.weigh_identifiers(self.database_words[matched_word], weight, id_counter)
        matched_identifiers[word] = id_counter
        self.database_search_cache(self.current_database, word, matched_identifiers[word])

    return matched_identifiers
287+
191288
def fuzzy_search(self, text: str, database: Optional[str] = None, return_counter: bool = False, logging: bool = True) -> list:
192289
"""Overwritten for extra database specific reduction of results.
193290
"""
@@ -200,6 +297,7 @@ def fuzzy_search(self, text: str, database: Optional[str] = None, return_counter
200297

201298
# DATABASE SPECIFIC get the set of ids that is in this database
202299
self.database_id_manager(database)
300+
self.database_word_manager(database)
203301

204302
queries = self.build_queries(text)
205303

@@ -279,17 +377,21 @@ def fuzzy_search(self, text: str, database: Optional[str] = None, return_counter
279377
# now search for all permutations of this query combined with a space
280378
query_df = search_df[search_df[self.identifier_name].isin(query_identifiers)]
281379
for query_perm in permutations(query):
282-
mask = self.filter_dataframe(query_df, " ".join(query_perm), search_columns=["query_col"])
283-
new_df = query_df.loc[mask].reset_index(drop=True)
284-
if len(new_df) == 0:
285-
# there is no match for this permutation of words, skip
286-
continue
287-
new_id_list = new_df[self.identifier_name]
288-
289-
new_ids = Counter()
290-
for new_id in new_id_list:
291-
new_ids[new_id] = query_identifiers[new_id]
292-
380+
query_perm_str = " ".join(query_perm)
381+
if result := self.database_search_cache(self.current_database, query_perm_str):
382+
new_ids = result
383+
else:
384+
mask = self.filter_dataframe(query_df, query_perm_str, search_columns=["query_col"])
385+
new_df = query_df.loc[mask].reset_index(drop=True)
386+
if len(new_df) == 0:
387+
# there is no match for this permutation of words, skip
388+
continue
389+
new_id_list = new_df[self.identifier_name]
390+
391+
new_ids = Counter()
392+
for new_id in new_id_list:
393+
new_ids[new_id] = query_identifiers[new_id]
394+
self.database_search_cache(self.current_database, query_perm_str, new_ids)
293395
# we weigh a combination of words that is next also to each other even higher than just the words separately
294396
query_to_identifier[query_name] = self.weigh_identifiers(new_ids, weight,
295397
query_to_identifier[query_name])
@@ -298,14 +400,15 @@ def fuzzy_search(self, text: str, database: Optional[str] = None, return_counter
298400
for identifiers in query_to_identifier.values():
299401
all_identifiers += identifiers
300402

403+
if return_counter:
404+
return_this = all_identifiers
405+
else:
406+
# now sort on highest weights and make list type
407+
return_this = [identifier[0] for identifier in all_identifiers.most_common()]
301408
if logging:
302409
log.debug(
303410
f"Found {len(all_identifiers)} search results for '{text}' in {len(self.df)} items in {time() - t:.2f} seconds")
304-
if return_counter:
305-
return all_identifiers
306-
# now sort on highest weights and make list type
307-
sorted_identifiers = [identifier[0] for identifier in all_identifiers.most_common()]
308-
return sorted_identifiers
411+
return return_this
309412

310413
def search(self, text, database: Optional[str] = None) -> list:
311414
"""Search the dataframe on this text, return a sorted list of identifiers."""

0 commit comments

Comments
 (0)