fix(math): fix edge cases of ln and pow
This commit is contained in:
parent
0dd9bcdd7e
commit
2db41703b3
3 changed files with 136 additions and 9 deletions
|
|
@ -22,3 +22,4 @@ thiserror = "1.0.61"
|
|||
quickcheck = "1.0"
|
||||
quickcheck_macros = "1.0"
|
||||
rand = "0.8"
|
||||
proptest = "1.5"
|
||||
|
|
|
|||
|
|
@ -2,13 +2,17 @@
|
|||
# Cardano Math functions
|
||||
*/
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::ops::{Div, Mul, Neg, Sub};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
pub type FixedDecimal = crate::math_malachite::Decimal;
|
||||
|
||||
pub static ZERO: Lazy<FixedDecimal> = Lazy::new(|| FixedDecimal::from(0u64));
|
||||
pub static MINUS_ONE: Lazy<FixedDecimal> = Lazy::new(|| FixedDecimal::from(-1i64));
|
||||
pub static ONE: Lazy<FixedDecimal> = Lazy::new(|| FixedDecimal::from(1u64));
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error {
|
||||
#[error("error in regex")]
|
||||
|
|
@ -38,7 +42,7 @@ pub trait FixedPrecision:
|
|||
|
||||
/// Entry point for 'ln' approximation. First does the necessary scaling, and
|
||||
/// then calls the continued fraction calculation. For any value outside the
|
||||
/// domain, i.e., 'x in (-inf,0]', the function returns '-INFINITY'.
|
||||
/// domain, i.e., 'x in (-inf,0]', the function panics.
|
||||
fn ln(&self) -> Self;
|
||||
|
||||
/// Entry point for 'pow' function. x^y = exp(y * ln x)
|
||||
|
|
@ -91,6 +95,9 @@ pub struct ExpCmpOrdering {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use malachite_base::num::arithmetic::traits::Abs;
|
||||
use proptest::prelude::Strategy;
|
||||
use proptest::proptest;
|
||||
use std::fs::File;
|
||||
use std::io::BufRead;
|
||||
use std::path::PathBuf;
|
||||
|
|
@ -479,4 +486,60 @@ mod tests {
|
|||
assert_eq!(res.iterations.to_string(), expected_iterations);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "ln of a value in (-inf,0] is undefined")]
|
||||
fn ln_of_0_should_be_undefined() {
|
||||
ZERO.ln();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "ln of a value in (-inf,0] is undefined")]
|
||||
fn ln_of_negative_should_be_undefined() {
|
||||
MINUS_ONE.ln();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pow_of_zero_to_any_positive_power_should_be_zero() {
|
||||
proptest!(|(y in 1u64..=u64::MAX)| {
|
||||
assert_eq!(ZERO.pow(&FixedDecimal::from(y)), *ZERO);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "zero to a negative power is undefined")]
|
||||
fn pow_of_zero_to_neg_power_should_be_undefined() {
|
||||
let y = FixedDecimal::from(-1i64);
|
||||
ZERO.pow(&y);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pow_of_any_to_power_0_should_be_1() {
|
||||
proptest!(|(x in i64::MIN..=i64::MAX)| {
|
||||
assert_eq!(FixedDecimal::from(x).pow(&*ZERO), *ONE);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pow_of_any_to_power_1_should_be_same() {
|
||||
proptest!(|(x in i64::MIN..=i64::MAX)| {
|
||||
assert_eq!(FixedDecimal::from(x).pow(&*ONE), FixedDecimal::from(x));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pow_to_positive_times_pow_to_negative_should_be_1() {
|
||||
let epsilon = FixedDecimal::from_str("1000000000000000000", 34).unwrap();
|
||||
proptest!(|(x in (-5i64..=5i64).prop_filter("Exclude zero", |&x| x != 0), y in 1i64..=25i64)| {
|
||||
let x = FixedDecimal::from(x);
|
||||
let y = FixedDecimal::from(y);
|
||||
let minus_y = -&y;
|
||||
let x_to_y = x.pow(&y);
|
||||
let x_to_minus_y = x.pow(&minus_y);
|
||||
let result = &x_to_y * &x_to_minus_y;
|
||||
let diff = (&result - &*ONE).abs();
|
||||
// println!("x: {}, y: {}, x^y: {}, x^-y: {}, x^y * x^-y: {}, diff: {}", x, y, x_to_y, x_to_minus_y, result, diff);
|
||||
assert!(diff <= epsilon);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ use malachite::num::basic::traits::One;
|
|||
use malachite::platform_64::Limb;
|
||||
use malachite::rounding_modes::RoundingMode;
|
||||
use malachite::{Integer, Natural};
|
||||
use malachite_base::num::arithmetic::traits::Sign;
|
||||
use malachite_base::num::arithmetic::traits::{Parity, Sign};
|
||||
use once_cell::sync::Lazy;
|
||||
use regex::Regex;
|
||||
use std::cmp::Ordering;
|
||||
|
|
@ -135,6 +135,27 @@ impl<'a> Neg for &'a Decimal {
|
|||
}
|
||||
}
|
||||
|
||||
impl Abs for Decimal {
|
||||
type Output = Self;
|
||||
|
||||
fn abs(self) -> Self::Output {
|
||||
let mut result = Decimal::new(self.precision);
|
||||
result.data = self.data.abs();
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Abs for a reference to Decimal
|
||||
impl<'a> Abs for &'a Decimal {
|
||||
type Output = Decimal;
|
||||
|
||||
fn abs(self) -> Self::Output {
|
||||
let mut result = Decimal::new(self.precision);
|
||||
result.data = (&self.data).abs();
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul for Decimal {
|
||||
type Output = Self;
|
||||
|
||||
|
|
@ -310,10 +331,15 @@ impl FixedPrecision for Decimal {
|
|||
|
||||
fn ln(&self) -> Self {
|
||||
let mut ln_x = Decimal::new(self.precision);
|
||||
ref_ln(&mut ln_x.data, &self.data);
|
||||
if ref_ln(&mut ln_x.data, &self.data) {
|
||||
ln_x
|
||||
} else {
|
||||
panic!("ln of a value in (-inf,0] is undefined")
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the power of a Decimal approximation using x^y = exp(y * ln x) formula
|
||||
/// While not exact, this is a more performant way to compute the power of a Decimal
|
||||
fn pow(&self, rhs: &Self) -> Self {
|
||||
let mut pow_x = Decimal::new(self.precision);
|
||||
ref_pow(&mut pow_x.data, &self.data, &rhs.data);
|
||||
|
|
@ -677,11 +703,48 @@ fn ref_ln(rop: &mut Integer, x: &Integer) -> bool {
|
|||
fn ref_pow(rop: &mut Integer, base: &Integer, exponent: &Integer) {
|
||||
/* x^y = exp(y * ln x) */
|
||||
let mut tmp: Integer = Integer::from(0);
|
||||
ref_ln(&mut tmp, base);
|
||||
|
||||
if exponent == &ZERO.value || base == &ONE.value {
|
||||
// any base to the power of zero is one, or 1 to any power is 1
|
||||
*rop = ONE.value.clone();
|
||||
return;
|
||||
}
|
||||
if exponent == &ONE.value {
|
||||
// any base to the power of one is the base
|
||||
*rop = base.clone();
|
||||
return;
|
||||
}
|
||||
if base == &ZERO.value && exponent > &ZERO.value {
|
||||
// zero to any positive power is zero
|
||||
*rop = &ZERO.value * &PRECISION.value;
|
||||
return;
|
||||
}
|
||||
if base == &ZERO.value && exponent < &ZERO.value {
|
||||
panic!("zero to a negative power is undefined");
|
||||
}
|
||||
if base < &ZERO.value {
|
||||
// negate the base and calculate the power
|
||||
let neg_base = base.neg();
|
||||
let ref_ln = ref_ln(&mut tmp, &neg_base);
|
||||
debug_assert!(ref_ln);
|
||||
tmp *= exponent;
|
||||
scale(&mut tmp);
|
||||
let mut tmp_rop = Integer::from(0);
|
||||
ref_exp(&mut tmp_rop, &tmp);
|
||||
*rop = if (exponent / &PRECISION.value).even() {
|
||||
tmp_rop
|
||||
} else {
|
||||
-tmp_rop
|
||||
};
|
||||
} else {
|
||||
// base is positive, ref_ln result is valid
|
||||
let ref_ln = ref_ln(&mut tmp, base);
|
||||
debug_assert!(ref_ln);
|
||||
tmp *= exponent;
|
||||
scale(&mut tmp);
|
||||
ref_exp(rop, &tmp);
|
||||
}
|
||||
}
|
||||
|
||||
/// `bound_x` is the bound for exp in the interval x is chosen from
|
||||
/// `compare` the value to compare to
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue