Skip to content

Commit ce7f6ab

Browse files
committed
BigMath.erf and BigMath.erfc with bit burst algorithm
1 parent 57cdef6 commit ce7f6ab

2 files changed

Lines changed: 235 additions & 32 deletions

File tree

lib/bigdecimal/math.rb

Lines changed: 226 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -608,16 +608,11 @@ def erf(x, prec)
608608
xf = x.to_f
609609
log10_erfc = -xf ** 2 / Math.log(10) - Math.log10(xf * Math::PI ** 0.5)
610610
erfc_prec = [prec + log10_erfc.ceil, 1].max
611-
erfc = _erfc_asymptotic(x, erfc_prec)
611+
erfc = _erfc_bit_burst(x, erfc_prec + BigDecimal::Internal::EXTRA_PREC)
612612
return BigDecimal(1).sub(erfc, prec) if erfc
613613
end
614614

615-
prec2 = prec + BigDecimal::Internal::EXTRA_PREC
616-
x_smallprec = x.mult(1, Integer.sqrt(prec2) / 2)
617-
# Taylor series of x with small precision is fast
618-
erf1 = _erf_taylor(x_smallprec, BigDecimal(0), BigDecimal(0), prec2)
619-
# Taylor series converges quickly for small x
620-
_erf_taylor(x - x_smallprec, x_smallprec, erf1, prec2).mult(1, prec)
615+
_erf_bit_burst(x, prec + BigDecimal::Internal::EXTRA_PREC).mult(1, prec)
621616
end
622617

623618
# call-seq:
@@ -640,20 +635,81 @@ def erfc(x, prec)
640635
return BigDecimal(0) if x > 5000000000 # erfc(5000000000) < 1e-10000000000000000000 (underflow)
641636

642637
if x >= 8
643-
y = _erfc_asymptotic(x, prec)
638+
y = _erfc_bit_burst(x, prec + BigDecimal::Internal::EXTRA_PREC)
644639
return y.mult(1, prec) if y
645640
end
646641

647642
# erfc(x) = 1 - erf(x) < exp(-x**2)/x/sqrt(pi)
648643
# Precision of erf(x) needs about log10(exp(-x**2)) extra digits
649644
log10 = 2.302585092994046
650645
high_prec = prec + BigDecimal::Internal::EXTRA_PREC + (x.ceil**2 / log10).ceil
651-
BigDecimal(1).sub(erf(x, high_prec), prec)
646+
BigDecimal(1).sub(_erf_bit_burst(x, high_prec), prec)
647+
end
648+
649+
# Calculates erf(x) using bit-burst algorithm.
650+
private_class_method def _erf_bit_burst(x, prec) # :nodoc:
651+
x = BigDecimal::Internal.coerce_to_bigdecimal(x, prec, :erf)
652+
prec = BigDecimal::Internal.coerce_validate_prec(prec, :erf)
653+
654+
return BigDecimal(0) if x > 5000000000 # erfc underflows
655+
x = x.mult(1, [prec - (x.ceil**2/Math.log(10)).floor, 1].max)
656+
657+
calculated_x = BigDecimal(0)
658+
erf_exp2 = BigDecimal(0)
659+
digits = 8
660+
scale = 2 * exp(-x.mult(x, prec), prec).div(PI(prec).sqrt(prec), prec)
661+
662+
until x.zero?
663+
partial = x.truncate(digits)
664+
digits *= 2
665+
next if partial.zero?
666+
667+
erf_exp2 = _erf_exp2_binary_splitting(partial, calculated_x, erf_exp2, prec)
668+
calculated_x += partial
669+
x -= partial
670+
end
671+
erf_exp2.mult(scale, prec)
672+
end
673+
674+
# Calculates erfc(x) using bit-burst algorithm.
675+
private_class_method def _erfc_bit_burst(x, prec) # :nodoc:
676+
digits = (x.exponent + 1) * 40
677+
678+
calculated_x = x.truncate(digits)
679+
f = _erfc_exp2_asymptotic_binary_splitting(calculated_x, prec)
680+
return unless f
681+
682+
scale = 2 * exp(-x.mult(x, prec), prec).div(PI(prec).sqrt(prec), prec)
683+
x -= calculated_x
684+
685+
until x.zero?
686+
digits *= 2
687+
partial = x.truncate(digits)
688+
next if partial.zero?
689+
690+
f = _erfc_exp2_inv_inv_binary_splitting(partial, calculated_x, f, prec)
691+
calculated_x += partial
692+
x -= partial
693+
end
694+
f.mult(scale, prec)
695+
end
696+
697+
# Matrix multiplication for binary splitting method in erf/erfc calculation
698+
private_class_method def _bs_matrix_mult(m1, m2, size, prec) # :nodoc:
699+
(size * size).times.map do |i|
700+
size.times.map do |k|
701+
m1[i / size * size + k].mult(m2[size * k + i % size], prec)
702+
end.reduce {|a, b| a.add(b, prec) }
703+
end
704+
end
705+
706+
# Matrix/Vector weighted sum for binary splitting method in erf/erfc calculation
707+
private_class_method def _bs_weighted_sum(m1, w1, m2, w2, prec) # :nodoc:
708+
m1.zip(m2).map {|v1, v2| (v1 * w1).add(v2 * w2, prec) }
652709
end
653710

654-
# Calculates erf(x + a)
655-
private_class_method def _erf_taylor(x, a, erf_a, prec) # :nodoc:
656-
return erf_a if x.zero?
711+
# Calculates Taylor expansion of erf(x+a)*exp((x+a)**2)*sqrt(pi)/2 with binary splitting method.
712+
private_class_method def _erf_exp2_binary_splitting(x, a, f_a, prec) # :nodoc:
657713
# Let f(x+a) = erf(x+a)*exp((x+a)**2)*sqrt(pi)/2
658714
# = c0 + c1*x + c2*x**2 + c3*x**3 + c4*x**4 + ...
659715
# f'(x+a) = 1+2*(x+a)*f(x+a)
@@ -668,22 +724,64 @@ def erfc(x, prec)
668724
#
669725
# All coefficients are positive when a >= 0
670726

671-
scale = BigDecimal(2).div(sqrt(PI(prec), prec), prec)
672-
c_prev = erf_a.div(scale.mult(exp(-a*a, prec), prec), prec)
673-
c_next = (2 * a * c_prev).add(1, prec).mult(x, prec)
674-
sum = c_prev.add(c_next, prec)
727+
log10f = Math.log(10)
728+
cexponent = Math.log10([2 * a, Math.sqrt(2)].max.to_f) + BigDecimal::Internal.float_log(x.abs) / log10f
675729

676-
2.step do |k|
677-
cn = (c_prev.mult(x, prec) + a * c_next).mult(2, prec).mult(x, prec).div(k, prec)
678-
sum = sum.add(cn, prec)
679-
c_prev, c_next = c_next, cn
680-
break if [c_prev, c_next].all? { |c| c.zero? || (c.exponent < sum.exponent - prec) }
730+
steps = BigDecimal.save_exception_mode do
731+
BigDecimal.mode(BigDecimal::EXCEPTION_UNDERFLOW, false)
732+
(2..).bsearch do |n|
733+
x.to_f ** 2 < n && n * cexponent + Math.lgamma(n / 2)[0] / log10f + n * Math.log10(2) - Math.lgamma(n - 1)[0] / log10f < -prec + x.to_f**2 / log10f
734+
end
681735
end
682-
value = sum.mult(scale.mult(exp(-(x + a).mult(x + a, prec), prec), prec), prec)
683-
value > 1 ? BigDecimal(1) : value
736+
737+
if a == 0
738+
# Simple calculation for special case
739+
denominators = (steps / 2).times.map {|i| 2 * i + 3 }
740+
return x.mult(1 + BigDecimal::Internal.taylor_sum_binary_splitting(2 * x * x, denominators, prec), prec)
741+
end
742+
743+
# First, calculate a matrix that represents the sum of the Taylor series:
744+
# SumMatrix = (((((...+I)x*M4+I)*x*M3+I)*M2*x+I)*M1*x+I)
745+
# Where Mi is a 2x2 matrix that generates the next coefficients of Taylor series:
746+
# Vector(c4, c5) = M4*M3*M2*M1*Vector(c0, c1)
747+
# And then calculates:
748+
# SumMatrix * Vector(c0, c1) = Vector(c0+c1*x+c2*x**2+..., _)
749+
# In this binary splitting method, adjacent two operations are combined into one repeatedly.
750+
# ((...) * x * A + B) / C is the form of each operation. A and B are 2x2 matrices, C is a scalar.
751+
zero = BigDecimal(0)
752+
two = BigDecimal(2)
753+
two_a = two * a
754+
operations = steps.times.map do |i|
755+
n = BigDecimal(2 + i)
756+
[[zero, n, two, two_a], [n, zero, zero, n], n]
757+
end
758+
759+
while operations.size > 1
760+
xpow = xpow ? xpow.mult(xpow, prec) : x.mult(1, prec)
761+
operations = operations.each_slice(2).map do |op1, op2|
762+
# Combine two operations into one:
763+
# (((Remaining * x * A2 + B2) / C2) * x * A1 + B1) / C1
764+
# ((Remaining * (x*x) * (A2*A1) + (x*B2*A1+B1*C2)) / (C1*C2)
765+
# Therefore, combined operation can be represented as:
766+
# Anext = A2 * A1
767+
# Bnext = x * B2 * A1 + B1 * C2
768+
# Cnext = C1 * C2
769+
# xnext = x * x
770+
a1, b1, c1 = op1
771+
a2, b2, c2 = op2 || [[zero] * 4, [zero] * 4, BigDecimal(1)]
772+
[
773+
_bs_matrix_mult(a2, a1, 2, prec),
774+
_bs_weighted_sum(_bs_matrix_mult(b2, a1, 2, prec), xpow, b1, c2, prec),
775+
c1.mult(c2, prec),
776+
]
777+
end
778+
end
779+
_, sum_matrix, denominator = operations.first
780+
(sum_matrix[1] + f_a * (2 * a * sum_matrix[1] + sum_matrix[0])).div(denominator, prec)
684781
end
685782

686-
private_class_method def _erfc_asymptotic(x, prec) # :nodoc:
783+
# Calculates asymptotic expansion of erfc(x)*exp(x**2)*sqrt(pi)/2 with binary splitting method
784+
private_class_method def _erfc_exp2_asymptotic_binary_splitting(x, prec) # :nodoc:
687785
# Let f(x) = erfc(x)*sqrt(pi)*exp(x**2)/2
688786
# f(x) satisfies the following differential equation:
689787
# 2*x*f(x) = f'(x) + 1
@@ -696,21 +794,117 @@ def erfc(x, prec)
696794
# Using Stirling's approximation, we can simplify this condition to:
697795
# sqrt(2)/2 + k*log(k) - k - 2*k*log(x) < -prec*log(10)
698796
# and the left side is minimized when k = x**2.
699-
prec += BigDecimal::Internal::EXTRA_PREC
700797
xf = x.to_f
701798
kmax = (1..(xf ** 2).floor).bsearch do |k|
702799
Math.log(2) / 2 + k * Math.log(k) - k - 2 * k * Math.log(xf) < -prec * Math.log(10)
703800
end
704801
return unless kmax
705802

706-
sum = BigDecimal(1)
707-
x2 = x.mult(x, prec)
708-
d = BigDecimal(1)
709-
(1..kmax).each do |k|
710-
d = d.div(x2, prec).mult(1 - 2 * k, prec).div(2, prec)
711-
sum = sum.add(d, prec)
803+
# Convert asymptotic expansion to nested form:
804+
# 1 + a/x + a*b/x/x + a*b*c/x/x/x + a*b*c/x/x/x*rest
805+
# = 1 + (a/x) * (1 + (b/x) * (1 + (c/x) * (1 + rest)))
806+
#
807+
# And calculate it with binary splitting:
808+
# (a1/d + b1/d * (a2/d + b2/d * (rest)))
809+
# = ((a1*d+b1*a2)/(d*d) + b1*b2/(d*denominator) * (rest)))
810+
denominator = x.mult(x, prec).mult(2, prec)
811+
fractions = (1..kmax).map do |k|
812+
[denominator, BigDecimal(1 - 2 * k)]
813+
end
814+
while fractions.size > 1
815+
fractions = fractions.each_slice(2).map do |fraction1, fraction2|
816+
a1, b1 = fraction1
817+
a2, b2 = fraction2 || [BigDecimal(0), denominator]
818+
[
819+
a1.mult(denominator, prec).add(b1.mult(a2, prec), prec),
820+
b1.mult(b2, prec),
821+
]
822+
end
823+
denominator = denominator.mult(denominator, prec)
824+
end
825+
sum = fractions[0][0].add(fractions[0][1], prec).div(denominator, prec)
826+
sum.div(x, prec) / 2
827+
end
828+
829+
# Calculates f(1/(a+x)) where f(x) = (sqrt(pi)/2) * exp(1/x**2) * erfc(1/x)
830+
# Parameter f_inva is f(1/a)
831+
private_class_method def _erfc_exp2_inv_inv_binary_splitting(x, a, f_inva, prec) # :nodoc:
832+
return f_inva if x.zero?
833+
834+
# Performs taylor expansion using f(1/(a+x)) = f(1/a - x/(a*(a+x)))
835+
836+
# f(x) satisfies the following differential equation:
837+
# (1/a+w)**3*f'(1/a+w) + 2*f(1/a+w) = 1/a + w
838+
# From the above equation, we can derive the following Taylor expansion of f around 1/a:
839+
# Coefficients: f(1/a + w) = c0 + c1*w + c2*w**2 + c3*w**3 + ...
840+
# Constraints:
841+
# (w**3 + 3*w**2/a + 3*w/a**2 + 1/a**3) * (c1 + 2*c2*w + 3*c3*w**2 + 4*c4*w**3 + ...)
842+
# + 2 * (c0 + c1*w + c2*w**2 + c3*w**3 + ...) = 1/a + w
843+
# Recurrence relations:
844+
# c0 = f(1/a)
845+
# c1 = a**2 - 2*c0*a**3
846+
# c2 = (a**3 - 3*c1*a - 2*c1*a**3) / 2
847+
# c3 = -(3*c1*a**2 + 6*c2*a + 2*c2*a**3) / 3
848+
# c(n) = -((n-3)*c(n-3)*a**3 + 3*(n-2)*c(n-2)*a**2 + 3*(n-1)*c(n-1)*a + 2*c(n-1)*a**3) / n
849+
850+
aa = a.mult(a, prec)
851+
aaa = aa.mult(a, prec)
852+
c0 = f_inva
853+
c1 = (aa - 2 * c0 * aaa).mult(1, prec)
854+
c2 = (aaa - 3 * c1 * a - 2 * c1 * aaa).div(2, prec)
855+
856+
# Estimate the number of steps needed to achieve the required precision
857+
low_prec = 16
858+
w = x.div(a.mult(a + x, low_prec), low_prec)
859+
wpow = w.mult(w, low_prec)
860+
cm3, cm2, cm1 = [c0, c1, c2].map {|v| v.mult(1, low_prec) }
861+
a_low, aa_low, aaa_low = [a, aa, aaa].map {|v| v.mult(1, low_prec) }
862+
step = (3..).find do |n|
863+
wpow = wpow.mult(w, low_prec)
864+
cn = -((n - 3) * cm3 * aaa_low + 3 * aa_low * (n - 2) * cm2 + 3 * a_low * (n - 1) * cm1 + 2 * cm1 * aaa_low).div(n, low_prec)
865+
cm3, cm2, cm1 = cm2, cm1, cn
866+
cn.mult(wpow, low_prec).exponent < -prec
867+
end
868+
869+
# Let M(n) be a 3x3 matrix that transforms (c(n),c(n+1),c(n+2)) to (c(n-1),c(n),c(n+1))
870+
# Mn = | 0 1 0 |
871+
# | 0 0 1 |
872+
# | -(n-3)*aaa/n -3*(n-2)*aa/n -2*aaa-3*(n-1)*a/n |
873+
# Vector(c6,c7,c8) = M6*M5*M4*M3*M2*M1 * Vector(c0,c1,c2)
874+
# Vector(c0+c1*y/z+c2*(y/z)**2+..., _, _) = (((... + I)*M3*y/z + I)*M2*y/z + I)*M1*y/z + I) * Vector(c2, c1, c0)
875+
# Perform binary splitting on this nested parenthesized calculation by using the following formula:
876+
# (((...)*A2*y/z + B2)/D2 * A1*y/z + B1)/D1 = (((...)*(A2*A1)*(y*y)/z + (B2*A1*y+z*D2*B1)) / (D1*D2*z)
877+
# where A_n, Bn are matrices and Dn are scalars
878+
879+
zero = BigDecimal(0)
880+
operations = (3..step + 2).map do |n|
881+
bign = BigDecimal(n)
882+
[
883+
[
884+
zero, bign, zero,
885+
zero, zero, bign,
886+
BigDecimal(-(n - 3) * aaa), -3 * (n - 2) * aa, -2 * aaa - 3 * (n - 1) * a
887+
],
888+
[bign, zero, zero, zero, bign, zero, zero, zero, bign],
889+
bign
890+
]
891+
end
892+
893+
z = a.mult(a + x, prec)
894+
while operations.size > 1
895+
y = y ? y.mult(y, prec) : -x.mult(1, prec)
896+
operations = operations.each_slice(2).map do |op1, op2|
897+
a1, b1, d1 = op1
898+
a2, b2, d2 = op2 || [[zero] * 9, [zero] * 9, BigDecimal(1)]
899+
[
900+
_bs_matrix_mult(a2, a1, 3, prec),
901+
_bs_weighted_sum(_bs_matrix_mult(b2, a1, 3, prec), y, b1, d2.mult(z, prec), prec),
902+
d1.mult(d2, prec).mult(z, prec),
903+
]
904+
end
712905
end
713-
sum.div(exp(x2, prec).mult(PI(prec).sqrt(prec), prec), prec).div(x, prec)
906+
_, sum_matrix, denominator = operations[0]
907+
(sum_matrix[0] * c0 + sum_matrix[1] * c1 + sum_matrix[2] * c2).div(denominator, prec)
714908
end
715909

716910
# call-seq:

test/bigdecimal/test_bigmath.rb

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,15 @@ def test_erfc
547547
assert_converge_in_precision {|n| BigMath.erfc(BigDecimal(20.5), n) }
548548
end
549549

550+
def test_erf_erfc_consistency_large_prec
551+
[BigDecimal(34.5), 34 + BigDecimal(4).div(7, 1200)].each do |x|
552+
erf = BigMath.erf(x, 1200) # Calculated with taylor series of erf
553+
erfc = BigMath.erfc(x, 400) # Calculated with asymptotic expansion
554+
erfc2 = 1 - erf
555+
assert_equal(erfc, erfc2.mult(1, 400))
556+
end
557+
end
558+
550559
def test_gamma
551560
[-1.8, -0.7, 0.6, 1.5, 2.4].each do |x|
552561
assert_in_epsilon(Math.gamma(x), gamma(BigDecimal(x.to_s), N))

0 commit comments

Comments
 (0)