-
-
Notifications
You must be signed in to change notification settings - Fork 50.5k
Expand file tree
/
Copy pathsimons_algorithm.py
More file actions
78 lines (61 loc) · 2.15 KB
/
simons_algorithm.py
File metadata and controls
78 lines (61 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
Simon's Algorithm (Classical Simulation)
Simon's algorithm finds a hidden bitstring s such that
f(input_bits) = f(other_bits) if and only if input_bits XOR other_bits
is either the all-zero string or s.
Here we simulate the mapping behavior classically to
illustrate how the hidden period can be discovered by
analyzing collisions in f(input_bits).
References:
https://en.wikipedia.org/wiki/Simon's_problem
"""
from collections.abc import Callable
from itertools import product
def xor_bits(bits1: list[int], bits2: list[int]) -> list[int]:
    """
    Combine two bit lists position by position using exclusive-or.

    Both inputs must have the same length; a ValueError is raised otherwise.

    >>> xor_bits([1, 0, 1], [1, 1, 0])
    [0, 1, 1]
    """
    if len(bits1) == len(bits2):
        return [left ^ right for left, right in zip(bits1, bits2)]
    raise ValueError("Bit lists must be of equal length.")
def simons_algorithm(
    hidden_function: Callable[[list[int]], list[int]], num_bits: int
) -> list[int]:
    """
    Simulate Simon's algorithm classically to find the hidden bitstring s.

    Every num_bits-bit input is passed to ``hidden_function`` (in the order
    produced by ``itertools.product``) and each output is recorded. Under
    Simon's promise, two distinct inputs share an output exactly when they
    differ by s, so the XOR of the first colliding pair is s. If no two inputs
    collide, ``hidden_function`` is one-to-one, which under the promise means
    s is the all-zero string. In the worst case this makes 2**num_bits calls.

    Args:
        hidden_function: A function taking an n-bit input as a list of 0/1
            ints. Its return value is converted with ``tuple()`` and used as
            a dictionary key, so it must be iterable with hashable elements.
        num_bits: Number of bits in the input.

    Returns:
        The hidden bitstring s as a list of bits.

    Raises:
        ValueError: If num_bits is negative (raised by itertools.product).

    >>> # Example with hidden bitstring s = [1, 0, 1]
    >>> def hidden_function(input_bits):
    ...     mapping = {
    ...         (0, 0, 0): (1, 1, 0),
    ...         (1, 0, 1): (1, 1, 0),
    ...         (0, 0, 1): (0, 1, 1),
    ...         (1, 0, 0): (0, 1, 1),
    ...         (0, 1, 0): (1, 0, 1),
    ...         (1, 1, 1): (1, 0, 1),
    ...         (0, 1, 1): (0, 0, 0),
    ...         (1, 1, 0): (0, 0, 0),
    ...     }
    ...     return mapping[tuple(input_bits)]
    >>> simons_algorithm(hidden_function, 3)
    [1, 0, 1]
    >>> # A one-to-one function corresponds to s = 0
    >>> simons_algorithm(lambda input_bits: input_bits, 2)
    [0, 0]
    """
    # Maps each observed output to the first input that produced it.
    seen: dict[tuple[int, ...], tuple[int, ...]] = {}
    # Iterate lazily: building the full list of 2**num_bits inputs up front
    # wastes memory when a collision is found early.
    for bits in product([0, 1], repeat=num_bits):
        output = tuple(hidden_function(list(bits)))
        if output in seen:
            # Two distinct inputs with equal outputs differ exactly by s.
            return xor_bits(list(bits), list(seen[output]))
        seen[output] = bits
    # No collision at all: hidden_function is one-to-one, so s is all zeros.
    return [0] * num_bits
if __name__ == "__main__":
    # Run the examples embedded in this module's docstrings.
    from doctest import testmod

    testmod()