-
Notifications
You must be signed in to change notification settings - Fork 36
Expand file tree
/
Copy pathtest_solver.py
More file actions
113 lines (91 loc) · 2.47 KB
/
test_solver.py
File metadata and controls
113 lines (91 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from typing import Tuple
import numpy as np
import pytest
import autolens as al
import autogalaxy as ag
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
CoordinateArrayTriangles as JAXTriangles,
)
from autolens.mock import NullTracer
from autolens.point.solver import PointSolver
@pytest.fixture
def solver(grid):
    """Provide a ``PointSolver`` built from the ``grid`` fixture.

    The solver is created via ``PointSolver.for_grid`` with a
    ``pixel_scale_precision`` of 0.01. The ``grid`` fixture is not defined
    in the visible part of this file (presumably a conftest; verify).
    """
    precision = 0.01
    return PointSolver.for_grid(grid=grid, pixel_scale_precision=precision)
def test_solver_basic(solver):
    """Solving for an on-axis source behind an isothermal lens gives a truthy result.

    The tracer holds a lens galaxy at redshift 0.5 with an ``Isothermal``
    mass profile (centre at the origin, Einstein radius 1.0) and a bare
    source galaxy at redshift 1.0.
    """
    lens_galaxy = al.Galaxy(
        redshift=0.5,
        mass=ag.mp.Isothermal(centre=(0.0, 0.0), einstein_radius=1.0),
    )
    source_galaxy = al.Galaxy(redshift=1.0)
    tracer = al.Tracer(galaxies=[lens_galaxy, source_galaxy])

    solutions = solver.solve(
        tracer=tracer,
        source_plane_coordinate=(0.0, 0.0),
    )

    # Only truthiness of the result is checked, not the specific positions.
    assert solutions
def test_steps(solver):
    """The fixture solver (precision 0.01) should report exactly 7 steps."""
    expected_steps = 7
    assert solver.n_steps == expected_steps
@pytest.mark.parametrize(
    "source_plane_coordinate",
    [
        (0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0),
        (0.5, 0.5), (0.1, 0.1), (-1.0, -1.0),
    ],
)
def test_trivial(source_plane_coordinate: Tuple[float, float], grid):
    """With a ``NullTracer`` the first solution should sit at the source coordinate.

    The first returned coordinate is compared against the requested
    source-plane coordinate with an absolute tolerance of 0.1.
    NOTE(review): this presumes ``NullTracer`` applies no deflection, so
    image-plane and source-plane positions coincide; confirm against the mock.
    """
    null_solver = PointSolver.for_grid(grid=grid, pixel_scale_precision=0.01)
    solved = null_solver.solve(
        tracer=NullTracer(),
        source_plane_coordinate=source_plane_coordinate,
    )

    first_solution = solved[0]
    assert first_solution == pytest.approx(source_plane_coordinate, abs=1.0e-1)
def triangle_set(triangles):
    """Return an order-independent set representation of a triangle collection.

    ``triangles.triangles`` is converted with ``.tolist()`` and iterated one
    triangle at a time. Any triangle containing a NaN is skipped. For the
    rest, each vertex pair is rounded to 4 decimal places and turned into a
    tuple, the vertices are sorted, and the resulting tuple is added to the
    set. Two triangles with the same vertices in a different order therefore
    map to the same element.

    NOTE(review): presumably ``triangles.triangles`` has shape (n, 3, 2);
    the code itself only requires a nested array of numeric pairs.
    """
    canonical = set()
    for vertices in triangles.triangles.tolist():
        # Triangles with any NaN coordinate are dropped entirely.
        if np.isnan(vertices).any():
            continue
        rounded = [tuple(np.round(vertex, 4)) for vertex in vertices]
        rounded.sort()
        canonical.add(tuple(rounded))
    return canonical
def test_real_example_jax(grid, tracer):
    """The JAX triangle backend should find 5 solutions for a source at (0.07, 0.07).

    The solver is built at a ``pixel_scale_precision`` of 0.001 with
    ``JAXTriangles`` as the triangle array class. The ``grid`` and ``tracer``
    fixtures are supplied from outside the visible part of this file.
    """
    source_position = (0.07, 0.07)
    solver_with_jax = PointSolver.for_grid(
        grid=grid,
        pixel_scale_precision=0.001,
        array_triangles_cls=JAXTriangles,
    )

    solutions = solver_with_jax.solve(
        tracer=tracer,
        source_plane_coordinate=source_position,
    )

    assert len(solutions) == 5
def test_real_example_normal(grid, tracer):
    """The non-JAX triangle backend should find 5 solutions for a source at (0.07, 0.07).

    The solver is built at a ``pixel_scale_precision`` of 0.001 with
    ``CoordinateArrayTriangles`` as the triangle array class. The ``grid``
    and ``tracer`` fixtures are supplied from outside the visible part of
    this file.
    """
    source_position = (0.07, 0.07)
    # Named for the backend actually in use (this test does not use JAX).
    array_solver = PointSolver.for_grid(
        grid=grid,
        pixel_scale_precision=0.001,
        array_triangles_cls=CoordinateArrayTriangles,
    )

    solutions = array_solver.solve(
        tracer=tracer,
        source_plane_coordinate=source_position,
    )

    assert len(solutions) == 5