Skip to content

zeta_from made jax compatible, zeta_from and wofz in higher precision…#272

Merged
NiekWielders merged 3 commits intomainfrom
feature/gaussian_mass
Feb 6, 2026
Merged

zeta_from made jax compatible, zeta_from and wofz in higher precision…#272
NiekWielders merged 3 commits intomainfrom
feature/gaussian_mass

Conversation

@NiekWielders
Copy link
Copy Markdown
Collaborator

Remove wofz distinction np and jnp to use of custom wofz for both. zeta_from now works for jnp. zeta_from and wofz now have higher accuracy.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the Gaussian stellar mass profile’s special-function path to be JAX-compatible and to use a single custom wofz implementation (instead of branching to SciPy for NumPy), aiming for higher numerical precision.

Changes:

  • Refactored zeta_from to use xp-array logic throughout (intended to support jax.numpy) and removed the SciPy wofz branch.
  • Updated wofz to enforce higher-precision dtypes (float64 / complex128) and reorganized constants/region logic.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 18 to +52
deflections = mp.deflections_2d_via_analytic_from(
grid=ag.Grid2DIrregular([[1.0, 0.0]])
)

assert deflections[0, 0] == pytest.approx(1.024423, 1.0e-4)
assert deflections[0, 1] == pytest.approx(0.0, abs=1.0e-4)

deflections = mp.deflections_2d_via_analytic_from(
grid=ag.Grid2DIrregular([[-1.0, 0.0]])
)

assert deflections[0, 0] == pytest.approx(-1.024423, 1.0e-4)
assert deflections[0, 1] == pytest.approx(0.0, abs=1.0e-4)

mp = ag.mp.Gaussian(
centre=(0.0, 0.0),
ell_comps=(0.0, 0.111111),
intensity=1.0,
sigma=5.0,
mass_to_light_ratio=1.0,
)

deflections = mp.deflections_2d_via_analytic_from(
grid=ag.Grid2DIrregular([[0.5, 0.2]])
)

assert deflections[0, 0] == pytest.approx(0.554062, 1.0e-4)
assert deflections[0, 1] == pytest.approx(0.177336, 1.0e-4)

deflections = mp.deflections_2d_via_analytic_from(
grid=ag.Grid2DIrregular([[-0.5, -0.2]])
)

assert deflections[0, 0] == pytest.approx(-0.554062, 1.0e-4)
assert deflections[0, 1] == pytest.approx(-0.177336, 1.0e-4)
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description says zeta_from is now JAX-compatible, but this test file still only exercises the NumPy path (no jax.numpy / xp=jnp coverage). Consider adding a small regression test that evaluates deflections_2d_via_analytic_from(..., xp=jnp) and asserts it matches the NumPy result, to prevent future JAX-only breakages of zeta_from / wofz.

Copilot uses AI. Check for mistakes.
Comment on lines +233 to 266
sqrt_pi = xp.asarray(xp.sqrt(xp.pi), dtype=xp.float64)
inv_sqrt_pi = xp.asarray(1.0 / sqrt_pi, dtype=xp.float64)

# ---------- Large-|z| continued fraction ----------
r1_s1 = xp.asarray([2.5, 2.0, 1.5, 1.0, 0.5], dtype=xp.float64)

t = z
for coef in r1_s1:
t = z - coef / t
for c in r1_s1:
t = z - c / t

w_large = 1j / (t * sqrt_pi)
w_large = 1j * inv_sqrt_pi / t

# --- Region 5: special small-imaginary case ---
U5 = xp.array([1.320522, 35.7668, 219.031, 1540.787, 3321.990, 36183.31])
V5 = xp.array(
[1.841439, 61.57037, 364.2191, 2186.181, 9022.228, 24322.84, 32066.6]
)
# ---------- Region 5 ----------
U5 = xp.asarray([1.320522, 35.7668, 219.031,
1540.787, 3321.990, 36183.31], dtype=xp.float64)
V5 = xp.asarray([1.841439, 61.57037, 364.2191,
2186.181, 9022.228, 24322.84, 32066.6], dtype=xp.float64)

t = 1 / sqrt_pi
t = inv_sqrt_pi
for u in U5:
t = u + z2 * t

s = 1.0
s = xp.asarray(1.0, dtype=xp.float64)
for v in V5:
s = v + z2 * s

w5 = xp.exp(-z2) + 1j * z * t / s

# --- Region 6: remaining small-|z| region ---
U6 = xp.array([5.9126262, 30.180142, 93.15558, 181.92853, 214.38239, 122.60793])
V6 = xp.array(
[
10.479857,
53.992907,
170.35400,
348.70392,
457.33448,
352.73063,
122.60793,
]
)
# ---------- Region 6 ----------
U6 = xp.asarray([5.9126262, 30.180142, 93.15558,
181.92853, 214.38239, 122.60793], dtype=xp.float64)
V6 = xp.asarray([10.479857, 53.992907, 170.35400,
348.70392, 457.33448, 352.73063, 122.60793], dtype=xp.float64)

Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wofz() allocates multiple coefficient arrays (r1_s1, U5, V5, U6, V6) on every call. Since zeta_from() calls wofz() twice per evaluation, this can add avoidable overhead for large grids. Consider hoisting these coefficients to module- or class-level constants (e.g. NumPy arrays) and only casting to xp as needed, to reduce per-call allocations.

Copilot uses AI. Check for mistakes.
@NiekWielders NiekWielders merged commit b29a03d into main Feb 6, 2026
14 checks passed
@Jammy2211 Jammy2211 deleted the feature/gaussian_mass branch February 13, 2026 13:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants