Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions autolens/point/solver/shape_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,27 @@
from typing import Tuple, List, Iterator, Type, Optional

import autoarray as aa
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles

from autoarray.structures.triangles.shape import Shape
from autofit.jax_wrapper import jit, use_jax, numpy as np, register_pytree_node_class

try:
if use_jax:
from autoarray.structures.triangles.jax_array import (
ArrayTriangles,
from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
CoordinateArrayTriangles,
MAX_CONTAINING_SIZE,
)
else:
from autoarray.structures.triangles.array import ArrayTriangles
from autoarray.structures.triangles.coordinate_array.coordinate_array import (
CoordinateArrayTriangles,
)

MAX_CONTAINING_SIZE = None

except ImportError:
from autoarray.structures.triangles.array import ArrayTriangles
from autoarray.structures.triangles.coordinate_array.coordinate_array import (
CoordinateArrayTriangles,
)

MAX_CONTAINING_SIZE = None

Expand Down Expand Up @@ -295,12 +298,11 @@ def _filter_low_magnification(
mask = np.abs(magnifications.array) > self.magnification_threshold
return np.where(mask[:, None], points, np.nan)

def _filtered_triangles(
def _source_triangles(
self,
tracer: OperateDeflections,
triangles: aa.AbstractTriangles,
source_plane_redshift,
shape: Shape,
):
"""
Filter the triangles to keep only those that meet the solver condition
Expand All @@ -310,11 +312,7 @@ def _filtered_triangles(
grid=aa.Grid2DIrregular(triangles.vertices),
source_plane_redshift=source_plane_redshift,
)
source_triangles = triangles.with_vertices(source_plane_grid.array)

indexes = source_triangles.containing_indices(shape=shape)

return triangles.for_indexes(indexes=indexes)
return triangles.with_vertices(source_plane_grid.array)

def steps(
self,
Expand All @@ -340,12 +338,15 @@ def steps(
"""
initial_triangles = self.initial_triangles
for number in range(self.n_steps):
kept_triangles = self._filtered_triangles(
source_triangles = self._source_triangles(
tracer=tracer,
triangles=initial_triangles,
source_plane_redshift=source_plane_redshift,
shape=shape,
)

indexes = source_triangles.containing_indices(shape=shape)
kept_triangles = initial_triangles.for_indexes(indexes=indexes)

neighbourhood = kept_triangles
for _ in range(self.neighbor_degree):
neighbourhood = neighbourhood.neighborhood()
Expand All @@ -358,6 +359,7 @@ def steps(
filtered_triangles=kept_triangles,
neighbourhood=neighbourhood,
up_sampled=up_sampled,
source_triangles=source_triangles,
)

initial_triangles = up_sampled
Expand Down
3 changes: 2 additions & 1 deletion autolens/point/solver/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from autoarray.numpy_wrapper import register_pytree_node_class

try:
from autoarray.structures.triangles.jax_array import ArrayTriangles
from autoarray.structures.triangles.array.jax_array import ArrayTriangles
except ImportError:
from autoarray.structures.triangles.array import ArrayTriangles

Expand Down Expand Up @@ -38,6 +38,7 @@ class Step:
filtered_triangles: aa.AbstractTriangles
neighbourhood: aa.AbstractTriangles
up_sampled: aa.AbstractTriangles
source_triangles: aa.AbstractTriangles

def tree_flatten(self):
return (
Expand Down
24 changes: 22 additions & 2 deletions autolens/point/visualise.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,34 @@ def visualise(step: Step):
plt.show()


def plot_triangles(triangles, color="black"):
def plot_triangles(triangles, color="black", title="Triangles", point=None):
plt.figure(figsize=(8, 8))
for triangle in triangles:
triangle = np.append(triangle, [triangle[0]], axis=0)
plt.plot(triangle[:, 0], triangle[:, 1], "o-", color=color)

if point:
plt.plot(point[0], point[1], "x", color="red")

plt.xlabel("X")
plt.ylabel("Y")
plt.title(title)
plt.gca().set_aspect("equal", adjustable="box")
plt.show()


def plot_triangles_compare(triangles_a, triangles_b, number=None):
plt.figure(figsize=(8, 8))
for triangle in triangles_a:
triangle = np.append(triangle, [triangle[0]], axis=0)
plt.plot(triangle[:, 0], triangle[:, 1], "o-", color="red")

for triangle in triangles_b:
triangle = np.append(triangle, [triangle[0]], axis=0)
plt.plot(triangle[:, 0], triangle[:, 1], "o-", color="blue")

plt.xlabel("X")
plt.ylabel("Y")
plt.title(f"Triangles")
plt.title("Triangles" + f" {number}" if number is not None else "")
plt.gca().set_aspect("equal", adjustable="box")
plt.show()
38 changes: 35 additions & 3 deletions test_autolens/point/triangles/test_solver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import Tuple

import numpy as np
import pytest

import autolens as al
import autogalaxy as ag
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
CoordinateArrayTriangles as JAXTriangles,
)
from autolens.mock import NullTracer
from autolens.point.solver import PointSolver

Expand Down Expand Up @@ -70,12 +75,39 @@ def test_trivial(
assert coordinates[0] == pytest.approx(source_plane_coordinate, abs=1.0e-1)


def test_real_example(grid, tracer):
solver = PointSolver.for_grid(
def triangle_set(triangles):
return {
tuple(sorted([tuple(np.round(pair, 4)) for pair in triangle]))
for triangle in triangles.triangles.tolist()
if not np.isnan(triangle).any()
}


def test_real_example_jax(grid, tracer):
jax_solver = PointSolver.for_grid(
grid=grid,
pixel_scale_precision=0.001,
array_triangles_cls=JAXTriangles,
)

result = jax_solver.solve(
tracer=tracer,
source_plane_coordinate=(0.07, 0.07),
)

assert len(result) == 5


def test_real_example_normal(grid, tracer):
jax_solver = PointSolver.for_grid(
grid=grid,
pixel_scale_precision=0.001,
array_triangles_cls=CoordinateArrayTriangles,
)

result = solver.solve(tracer=tracer, source_plane_coordinate=(0.07, 0.07))
result = jax_solver.solve(
tracer=tracer,
source_plane_coordinate=(0.07, 0.07),
)

assert len(result) == 5
7 changes: 4 additions & 3 deletions test_autolens/point/triangles/test_solver_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from autolens import PointSolver, Tracer

try:
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
except ImportError:
from autoarray.structures.triangles.jax_coordinate_array import (
from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
CoordinateArrayTriangles,
)

except ImportError:
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles

from autolens.mock import NullTracer

pytest.importorskip("jax")
Expand Down