zeta_from made jax compatible, zeta_from and wofz in higher precision…#272
zeta_from made jax compatible, zeta_from and wofz in higher precision…#272NiekWielders merged 3 commits intomainfrom
Conversation
…, remove import of scipy wofz
There was a problem hiding this comment.
Pull request overview
This PR updates the Gaussian stellar mass profile’s special-function path to be JAX-compatible and to use a single custom wofz implementation (instead of branching to SciPy for NumPy), aiming for higher numerical precision.
Changes:
- Refactored
zeta_fromto usexp-array logic throughout (intended to supportjax.numpy) and removed the SciPywofzbranch. - Updated
wofzto enforce higher-precision dtypes (float64/complex128) and reorganized constants/region logic.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| deflections = mp.deflections_2d_via_analytic_from( | ||
| grid=ag.Grid2DIrregular([[1.0, 0.0]]) | ||
| ) | ||
|
|
||
| assert deflections[0, 0] == pytest.approx(1.024423, 1.0e-4) | ||
| assert deflections[0, 1] == pytest.approx(0.0, abs=1.0e-4) | ||
|
|
||
| deflections = mp.deflections_2d_via_analytic_from( | ||
| grid=ag.Grid2DIrregular([[-1.0, 0.0]]) | ||
| ) | ||
|
|
||
| assert deflections[0, 0] == pytest.approx(-1.024423, 1.0e-4) | ||
| assert deflections[0, 1] == pytest.approx(0.0, abs=1.0e-4) | ||
|
|
||
| mp = ag.mp.Gaussian( | ||
| centre=(0.0, 0.0), | ||
| ell_comps=(0.0, 0.111111), | ||
| intensity=1.0, | ||
| sigma=5.0, | ||
| mass_to_light_ratio=1.0, | ||
| ) | ||
|
|
||
| deflections = mp.deflections_2d_via_analytic_from( | ||
| grid=ag.Grid2DIrregular([[0.5, 0.2]]) | ||
| ) | ||
|
|
||
| assert deflections[0, 0] == pytest.approx(0.554062, 1.0e-4) | ||
| assert deflections[0, 1] == pytest.approx(0.177336, 1.0e-4) | ||
|
|
||
| deflections = mp.deflections_2d_via_analytic_from( | ||
| grid=ag.Grid2DIrregular([[-0.5, -0.2]]) | ||
| ) | ||
|
|
||
| assert deflections[0, 0] == pytest.approx(-0.554062, 1.0e-4) | ||
| assert deflections[0, 1] == pytest.approx(-0.177336, 1.0e-4) |
There was a problem hiding this comment.
The PR description says zeta_from is now JAX-compatible, but this test file still only exercises the NumPy path (no jax.numpy / xp=jnp coverage). Consider adding a small regression test that evaluates deflections_2d_via_analytic_from(..., xp=jnp) and asserts it matches the NumPy result, to prevent future JAX-only breakages of zeta_from / wofz.
| sqrt_pi = xp.asarray(xp.sqrt(xp.pi), dtype=xp.float64) | ||
| inv_sqrt_pi = xp.asarray(1.0 / sqrt_pi, dtype=xp.float64) | ||
|
|
||
| # ---------- Large-|z| continued fraction ---------- | ||
| r1_s1 = xp.asarray([2.5, 2.0, 1.5, 1.0, 0.5], dtype=xp.float64) | ||
|
|
||
| t = z | ||
| for coef in r1_s1: | ||
| t = z - coef / t | ||
| for c in r1_s1: | ||
| t = z - c / t | ||
|
|
||
| w_large = 1j / (t * sqrt_pi) | ||
| w_large = 1j * inv_sqrt_pi / t | ||
|
|
||
| # --- Region 5: special small-imaginary case --- | ||
| U5 = xp.array([1.320522, 35.7668, 219.031, 1540.787, 3321.990, 36183.31]) | ||
| V5 = xp.array( | ||
| [1.841439, 61.57037, 364.2191, 2186.181, 9022.228, 24322.84, 32066.6] | ||
| ) | ||
| # ---------- Region 5 ---------- | ||
| U5 = xp.asarray([1.320522, 35.7668, 219.031, | ||
| 1540.787, 3321.990, 36183.31], dtype=xp.float64) | ||
| V5 = xp.asarray([1.841439, 61.57037, 364.2191, | ||
| 2186.181, 9022.228, 24322.84, 32066.6], dtype=xp.float64) | ||
|
|
||
| t = 1 / sqrt_pi | ||
| t = inv_sqrt_pi | ||
| for u in U5: | ||
| t = u + z2 * t | ||
|
|
||
| s = 1.0 | ||
| s = xp.asarray(1.0, dtype=xp.float64) | ||
| for v in V5: | ||
| s = v + z2 * s | ||
|
|
||
| w5 = xp.exp(-z2) + 1j * z * t / s | ||
|
|
||
| # --- Region 6: remaining small-|z| region --- | ||
| U6 = xp.array([5.9126262, 30.180142, 93.15558, 181.92853, 214.38239, 122.60793]) | ||
| V6 = xp.array( | ||
| [ | ||
| 10.479857, | ||
| 53.992907, | ||
| 170.35400, | ||
| 348.70392, | ||
| 457.33448, | ||
| 352.73063, | ||
| 122.60793, | ||
| ] | ||
| ) | ||
| # ---------- Region 6 ---------- | ||
| U6 = xp.asarray([5.9126262, 30.180142, 93.15558, | ||
| 181.92853, 214.38239, 122.60793], dtype=xp.float64) | ||
| V6 = xp.asarray([10.479857, 53.992907, 170.35400, | ||
| 348.70392, 457.33448, 352.73063, 122.60793], dtype=xp.float64) | ||
|
|
There was a problem hiding this comment.
wofz() allocates multiple coefficient arrays (r1_s1, U5, V5, U6, V6) on every call. Since zeta_from() calls wofz() twice per evaluation, this can add avoidable overhead for large grids. Consider hoisting these coefficients to module- or class-level constants (e.g. NumPy arrays) and only casting to xp as needed, to reduce per-call allocations.
Remove wofz distinction np and jnp to use of custom wofz for both. zeta_from now works for jnp. zeta_from and wofz now have higher accuracy.