Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions autolens/point/fit/positions/image/pair_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,134 @@ def residual_map(self) -> aa.ArrayIrregular:
residual_map.append(np.sqrt(min(distances)))

return aa.ArrayIrregular(values=residual_map)


class Fit:
def __init__(
self,
data: aa.Grid2DIrregular,
noise_map: aa.ArrayIrregular,
model_positions: np.ndarray,
):
"""
Compare the multiple image points observed to those produced by a model.

Parameters
----------
data
Observed multiple image coordinates
noise_map
The noise associated with each observed image coordinate
model_positions
The multiple image coordinates produced by the model
"""
self.data = data
self.noise_map = noise_map
self.model_positions = model_positions

@staticmethod
def square_distance(
coord1: np.array,
coord2: np.array,
) -> float:
"""
Calculate the square distance between two points.

Parameters
----------
coord1
coord2
The two points to calculate the distance between

Returns
-------
The square distance between the two points
"""
return (coord1[0] - coord2[0]) ** 2 + (coord1[1] - coord2[1]) ** 2

def log_p(
self,
data_position: np.array,
model_position: np.array,
sigma: float,
) -> float:
"""
Compute the log probability of a given model coordinate explaining
a given observed coordinate. Accounts for noise, with noiser image
coordinates having a comparatively lower log probability.

Parameters
----------
data_position
The observed coordinate
model_position
The model coordinate
sigma
The noise associated with the observed coordinate

Returns
-------
The log probability of the model coordinate explaining the observed coordinate
"""
chi2 = self.square_distance(data_position, model_position) / sigma**2
return -np.log(np.sqrt(2 * np.pi * sigma**2)) - 0.5 * chi2

def log_likelihood(self) -> float:
"""
Compute the log likelihood of the model image coordinates explaining the observed image coordinates.

This is the sum across all permutations of the observed image coordinates of the log probability of each
model image coordinate explaining the observed image coordinate.

For example, if there are two observed image coordinates and two model image coordinates, the log likelihood
is the sum of the log probabilities:

P(data_0 | model_0) * P(data_1 | model_1)
+ P(data_0 | model_1) * P(data_1 | model_0)
+ P(data_0 | model_0) * P(data_1 | model_0)
+ P(data_0 | model_1) * P(data_1 | model_1)

This is every way in which the coordinates generated by the model can explain the observed coordinates.
"""
n_non_nan_model_positions = np.count_nonzero(
~np.isnan(
self.model_positions,
).any(axis=1)
)
n_permutations = n_non_nan_model_positions ** len(self.data)
return -np.log(n_permutations) + np.sum(self.all_permutations_log_likelihoods())

def all_permutations_log_likelihoods(self) -> np.array:
"""
Compute the log likelihood for each permutation whereby the model could explain the observed image coordinates.

For example, if there are two observed image coordinates and two model image coordinates, the log likelihood
for each permutation is:

P(data_0 | model_0) * P(data_1 | model_1)
P(data_0 | model_1) * P(data_1 | model_0)
P(data_0 | model_0) * P(data_1 | model_0)
P(data_0 | model_1) * P(data_1 | model_1)

This is every way in which the coordinates generated by the model can explain the observed coordinates.
"""
return np.array(
[
np.log(
np.sum(
[
np.exp(
self.log_p(
data_position,
model_position,
sigma,
)
)
for model_position in self.model_positions
if not np.isnan(model_position).any()
]
)
)
for data_position, sigma in zip(self.data, self.noise_map)
]
)
2 changes: 1 addition & 1 deletion test_autolens/point/model/test_analysis_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
directory = path.dirname(path.realpath(__file__))


def test__make_result__result_imaging_is_returned(point_dataset):
def _test__make_result__result_imaging_is_returned(point_dataset):
model = af.Collection(
galaxies=af.Collection(
lens=al.Galaxy(redshift=0.5, point_0=al.ps.Point(centre=(0.0, 0.0)))
Expand Down
104 changes: 104 additions & 0 deletions test_autolens/point/model/test_andrew_implementation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
try:
import jax

JAX_INSTALLED = True
except ImportError:
JAX_INSTALLED = False

import numpy as np
import pytest

from autolens.point.fit.positions.image.pair_repeat import Fit


@pytest.fixture
def data():
return np.array([(0.0, 0.0), (1.0, 0.0)])


@pytest.fixture
def noise_map():
return np.array([1.0, 1.0])


@pytest.fixture
def fit(data, noise_map):
model_positions = np.array(
[
(-1.0749, -1.1),
(1.19117, 1.175),
]
)

return Fit(
data=data,
noise_map=noise_map,
model_positions=model_positions,
)


def test_andrew_implementation(fit):
assert np.allclose(
fit.all_permutations_log_likelihoods(),
[
-1.51114426,
-1.50631469,
],
)
assert fit.log_likelihood() == -4.40375330990644


@pytest.mark.skipif(not JAX_INSTALLED, reason="JAX is not installed")
def test_jax(fit):
assert jax.jit(fit.log_likelihood)() == -4.40375330990644


def test_nan_model_positions(
data,
noise_map,
):
model_positions = np.array(
[
(-1.0749, -1.1),
(1.19117, 1.175),
(np.nan, np.nan),
]
)
fit = Fit(
data=data,
noise_map=noise_map,
model_positions=model_positions,
)

assert np.allclose(
fit.all_permutations_log_likelihoods(),
[
-1.51114426,
-1.50631469,
],
)
assert fit.log_likelihood() == -4.40375330990644


def test_duplicate_model_position(
data,
noise_map,
):
model_positions = np.array(
[
(-1.0749, -1.1),
(1.19117, 1.175),
(1.19117, 1.175),
]
)
fit = Fit(
data=data,
noise_map=noise_map,
model_positions=model_positions,
)

assert np.allclose(
fit.all_permutations_log_likelihoods(),
[-1.14237812, -0.87193683],
)
assert fit.log_likelihood() == -4.211539531047171