|
16 | 16 | import itertools |
17 | 17 | import json |
18 | 18 | import logging |
19 | | -from typing import Dict, Iterable, Mapping, Optional, Tuple |
| 19 | +from typing import Dict, Iterable, Optional, Tuple |
20 | 20 |
|
| 21 | +from canonicaljson import encode_canonical_json |
21 | 22 | from signedjson.key import decode_verify_key_bytes |
22 | 23 | from unpaddedbase64 import decode_base64 |
23 | 24 |
|
| 25 | +from synapse.storage.database import LoggingTransaction |
24 | 26 | from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore |
25 | 27 | from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote |
26 | 28 | from synapse.storage.types import Cursor |
| 29 | +from synapse.types import JsonDict |
27 | 30 | from synapse.util.caches.descriptors import cached, cachedList |
28 | 31 | from synapse.util.iterutils import batch_iter |
29 | 32 |
|
|
36 | 39 | class KeyStore(CacheInvalidationWorkerStore): |
37 | 40 | """Persistence for signature verification keys""" |
38 | 41 |
|
39 | | - @cached() |
40 | | - def _get_server_signature_key( |
41 | | - self, server_name_and_key_id: Tuple[str, str] |
42 | | - ) -> FetchKeyResult: |
43 | | - raise NotImplementedError() |
44 | | - |
45 | | - @cachedList( |
46 | | - cached_method_name="_get_server_signature_key", |
47 | | - list_name="server_name_and_key_ids", |
48 | | - ) |
49 | | - async def get_server_signature_keys( |
50 | | - self, server_name_and_key_ids: Iterable[Tuple[str, str]] |
51 | | - ) -> Dict[Tuple[str, str], FetchKeyResult]: |
52 | | - """ |
53 | | - Args: |
54 | | - server_name_and_key_ids: |
55 | | - iterable of (server_name, key-id) tuples to fetch keys for |
56 | | -
|
57 | | - Returns: |
58 | | - A map from (server_name, key_id) -> FetchKeyResult, or None if the |
59 | | - key is unknown |
60 | | - """ |
61 | | - keys = {} |
62 | | - |
63 | | - def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: |
64 | | - """Processes a batch of keys to fetch, and adds the result to `keys`.""" |
65 | | - |
66 | | - # batch_iter always returns tuples so it's safe to do len(batch) |
67 | | - sql = """ |
68 | | - SELECT server_name, key_id, verify_key, ts_valid_until_ms |
69 | | - FROM server_signature_keys WHERE 1=0 |
70 | | - """ + " OR (server_name=? AND key_id=?)" * len( |
71 | | - batch |
72 | | - ) |
73 | | - |
74 | | - txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) |
75 | | - |
76 | | - for row in txn: |
77 | | - server_name, key_id, key_bytes, ts_valid_until_ms = row |
78 | | - |
79 | | - if ts_valid_until_ms is None: |
80 | | - # Old keys may be stored with a ts_valid_until_ms of null, |
81 | | - # in which case we treat this as if it was set to `0`, i.e. |
82 | | - # it won't match key requests that define a minimum |
83 | | - # `ts_valid_until_ms`. |
84 | | - ts_valid_until_ms = 0 |
85 | | - |
86 | | - keys[(server_name, key_id)] = FetchKeyResult( |
87 | | - verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), |
88 | | - valid_until_ts=ts_valid_until_ms, |
89 | | - ) |
90 | | - |
91 | | - def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: |
92 | | - for batch in batch_iter(server_name_and_key_ids, 50): |
93 | | - _get_keys(txn, batch) |
94 | | - return keys |
95 | | - |
96 | | - return await self.db_pool.runInteraction("get_server_signature_keys", _txn) |
97 | | - |
98 | | - async def store_server_signature_keys( |
| 42 | + async def store_server_keys_response( |
99 | 43 | self, |
| 44 | + server_name: str, |
100 | 45 | from_server: str, |
101 | 46 | ts_added_ms: int, |
102 | | - verify_keys: Mapping[Tuple[str, str], FetchKeyResult], |
| 47 | + verify_keys: Dict[str, FetchKeyResult], |
| 48 | + response_json: JsonDict, |
103 | 49 | ) -> None: |
104 | | - """Stores NACL verification keys for remote servers. |
| 50 | + """Stores the keys for the given server that we got from `from_server`. |
| 51 | +
|
105 | 52 | Args: |
106 | | - from_server: Where the verification keys were looked up |
107 | | - ts_added_ms: The time to record that the key was added |
108 | | - verify_keys: |
109 | | - keys to be stored. Each entry is a triplet of |
110 | | - (server_name, key_id, key). |
| 53 | + server_name: The owner of the keys |
| 54 | + from_server: Which server we got the keys from |
| 55 | + ts_added_ms: When we're adding the keys |
| 56 | + verify_keys: The decoded keys |
| 57 | + response_json: The full *signed* response JSON that contains the keys. |
111 | 58 | """ |
112 | | - key_values = [] |
113 | | - value_values = [] |
114 | | - invalidations = [] |
115 | | - for (server_name, key_id), fetch_result in verify_keys.items(): |
116 | | - key_values.append((server_name, key_id)) |
117 | | - value_values.append( |
118 | | - ( |
119 | | - from_server, |
120 | | - ts_added_ms, |
121 | | - fetch_result.valid_until_ts, |
122 | | - db_binary_type(fetch_result.verify_key.encode()), |
123 | | - ) |
124 | | - ) |
125 | | - # invalidate takes a tuple corresponding to the params of |
126 | | - # _get_server_signature_key. _get_server_signature_key only takes one |
127 | | - # param, which is itself the 2-tuple (server_name, key_id). |
128 | | - invalidations.append((server_name, key_id)) |
129 | 59 |
|
130 | | - await self.db_pool.simple_upsert_many( |
131 | | - table="server_signature_keys", |
132 | | - key_names=("server_name", "key_id"), |
133 | | - key_values=key_values, |
134 | | - value_names=( |
135 | | - "from_server", |
136 | | - "ts_added_ms", |
137 | | - "ts_valid_until_ms", |
138 | | - "verify_key", |
139 | | - ), |
140 | | - value_values=value_values, |
141 | | - desc="store_server_signature_keys", |
142 | | - ) |
| 60 | + key_json_bytes = encode_canonical_json(response_json) |
| 61 | + |
| 62 | + def store_server_keys_response_txn(txn: LoggingTransaction) -> None: |
| 63 | + self.db_pool.simple_upsert_many_txn( |
| 64 | + txn, |
| 65 | + table="server_signature_keys", |
| 66 | + key_names=("server_name", "key_id"), |
| 67 | + key_values=[(server_name, key_id) for key_id in verify_keys], |
| 68 | + value_names=( |
| 69 | + "from_server", |
| 70 | + "ts_added_ms", |
| 71 | + "ts_valid_until_ms", |
| 72 | + "verify_key", |
| 73 | + ), |
| 74 | + value_values=[ |
| 75 | + ( |
| 76 | + from_server, |
| 77 | + ts_added_ms, |
| 78 | + fetch_result.valid_until_ts, |
| 79 | + db_binary_type(fetch_result.verify_key.encode()), |
| 80 | + ) |
| 81 | + for fetch_result in verify_keys.values() |
| 82 | + ], |
| 83 | + ) |
143 | 84 |
|
144 | | - invalidate = self._get_server_signature_key.invalidate |
145 | | - for i in invalidations: |
146 | | - invalidate((i,)) |
| 85 | + self.db_pool.simple_upsert_many_txn( |
| 86 | + txn, |
| 87 | + table="server_keys_json", |
| 88 | + key_names=("server_name", "key_id", "from_server"), |
| 89 | + key_values=[ |
| 90 | + (server_name, key_id, from_server) for key_id in verify_keys |
| 91 | + ], |
| 92 | + value_names=( |
| 93 | + "ts_added_ms", |
| 94 | + "ts_valid_until_ms", |
| 95 | + "key_json", |
| 96 | + ), |
| 97 | + value_values=[ |
| 98 | + ( |
| 99 | + ts_added_ms, |
| 100 | + fetch_result.valid_until_ts, |
| 101 | + db_binary_type(key_json_bytes), |
| 102 | + ) |
| 103 | + for fetch_result in verify_keys.values() |
| 104 | + ], |
| 105 | + ) |
147 | 106 |
|
148 | | - async def store_server_keys_json( |
149 | | - self, |
150 | | - server_name: str, |
151 | | - key_id: str, |
152 | | - from_server: str, |
153 | | - ts_now_ms: int, |
154 | | - ts_expires_ms: int, |
155 | | - key_json_bytes: bytes, |
156 | | - ) -> None: |
157 | | - """Stores the JSON bytes for a set of keys from a server |
158 | | - The JSON should be signed by the originating server, the intermediate |
159 | | - server, and by this server. Updates the value for the |
160 | | - (server_name, key_id, from_server) triplet if one already existed. |
161 | | - Args: |
162 | | - server_name: The name of the server. |
163 | | - key_id: The identifier of the key this JSON is for. |
164 | | - from_server: The server this JSON was fetched from. |
165 | | - ts_now_ms: The time now in milliseconds. |
166 | | - ts_valid_until_ms: The time when this json stops being valid. |
167 | | - key_json_bytes: The encoded JSON. |
168 | | - """ |
169 | | - await self.db_pool.simple_upsert( |
170 | | - table="server_keys_json", |
171 | | - keyvalues={ |
172 | | - "server_name": server_name, |
173 | | - "key_id": key_id, |
174 | | - "from_server": from_server, |
175 | | - }, |
176 | | - values={ |
177 | | - "server_name": server_name, |
178 | | - "key_id": key_id, |
179 | | - "from_server": from_server, |
180 | | - "ts_added_ms": ts_now_ms, |
181 | | - "ts_valid_until_ms": ts_expires_ms, |
182 | | - "key_json": db_binary_type(key_json_bytes), |
183 | | - }, |
184 | | - desc="store_server_keys_json", |
185 | | - ) |
| 107 | + # invalidate takes a tuple corresponding to the params of |
| 108 | + # _get_server_keys_json. _get_server_keys_json only takes one |
| 109 | + # param, which is itself the 2-tuple (server_name, key_id). |
| 110 | + for key_id in verify_keys: |
| 111 | + self._invalidate_cache_and_stream( |
| 112 | + txn, self._get_server_keys_json, ((server_name, key_id),) |
| 113 | + ) |
| 114 | + self._invalidate_cache_and_stream( |
| 115 | + txn, self.get_server_key_json_for_remote, (server_name, key_id) |
| 116 | + ) |
186 | 117 |
|
187 | | - # invalidate takes a tuple corresponding to the params of |
188 | | - # _get_server_keys_json. _get_server_keys_json only takes one |
189 | | - # param, which is itself the 2-tuple (server_name, key_id). |
190 | | - await self.invalidate_cache_and_stream( |
191 | | - "_get_server_keys_json", ((server_name, key_id),) |
192 | | - ) |
193 | | - await self.invalidate_cache_and_stream( |
194 | | - "get_server_key_json_for_remote", (server_name, key_id) |
| 118 | + await self.db_pool.runInteraction( |
| 119 | + "store_server_keys_response", store_server_keys_response_txn |
195 | 120 | ) |
196 | 121 |
|
197 | 122 | @cached() |
|
0 commit comments