From 2db41703b356ac28be38ea3b40ee6023d85204b0 Mon Sep 17 00:00:00 2001 From: Andrew Westberg Date: Sat, 26 Oct 2024 19:09:33 +0000 Subject: [PATCH] fix(math): fix edge cases of ln and pow --- pallas-math/Cargo.toml | 1 + pallas-math/src/math.rs | 67 ++++++++++++++++++++++++++- pallas-math/src/math_malachite.rs | 77 ++++++++++++++++++++++++++++--- 3 files changed, 136 insertions(+), 9 deletions(-) diff --git a/pallas-math/Cargo.toml b/pallas-math/Cargo.toml index e6d14be..2b4861d 100644 --- a/pallas-math/Cargo.toml +++ b/pallas-math/Cargo.toml @@ -22,3 +22,4 @@ thiserror = "1.0.61" quickcheck = "1.0" quickcheck_macros = "1.0" rand = "0.8" +proptest = "1.5" diff --git a/pallas-math/src/math.rs b/pallas-math/src/math.rs index 8f103b8..23ce33b 100644 --- a/pallas-math/src/math.rs +++ b/pallas-math/src/math.rs @@ -2,13 +2,17 @@ # Cardano Math functions */ +use once_cell::sync::Lazy; use std::fmt::{Debug, Display}; use std::ops::{Div, Mul, Neg, Sub}; - use thiserror::Error; pub type FixedDecimal = crate::math_malachite::Decimal; +pub static ZERO: Lazy = Lazy::new(|| FixedDecimal::from(0u64)); +pub static MINUS_ONE: Lazy = Lazy::new(|| FixedDecimal::from(-1i64)); +pub static ONE: Lazy = Lazy::new(|| FixedDecimal::from(1u64)); + #[derive(Debug, Error)] pub enum Error { #[error("error in regex")] @@ -38,7 +42,7 @@ pub trait FixedPrecision: /// Entry point for 'ln' approximation. First does the necessary scaling, and /// then calls the continued fraction calculation. For any value outside the - /// domain, i.e., 'x in (-inf,0]', the function returns '-INFINITY'. + /// domain, i.e., 'x in (-inf,0]', the function panics. fn ln(&self) -> Self; /// Entry point for 'pow' function. x^y = exp(y * ln x) @@ -91,6 +95,9 @@ pub struct ExpCmpOrdering { #[cfg(test)] mod tests { use super::*; + use malachite_base::num::arithmetic::traits::Abs; + use proptest::prelude::Strategy; + use proptest::proptest; use std::fs::File; use std::io::BufRead; use std::path::PathBuf; @@ -479,4 +486,60 @@ mod tests { assert_eq!(res.iterations.to_string(), expected_iterations); } } + + #[test] + #[should_panic(expected = "ln of a value in (-inf,0] is undefined")] + fn ln_of_0_should_be_undefined() { + ZERO.ln(); + } + + #[test] + #[should_panic(expected = "ln of a value in (-inf,0] is undefined")] + fn ln_of_negative_should_be_undefined() { + MINUS_ONE.ln(); + } + + #[test] + fn pow_of_zero_to_any_positive_power_should_be_zero() { + proptest!(|(y in 1u64..=u64::MAX)| { + assert_eq!(ZERO.pow(&FixedDecimal::from(y)), *ZERO); + }); + } + + #[test] + #[should_panic(expected = "zero to a negative power is undefined")] + fn pow_of_zero_to_neg_power_should_be_undefined() { + let y = FixedDecimal::from(-1i64); + ZERO.pow(&y); + } + + #[test] + fn pow_of_any_to_power_0_should_be_1() { + proptest!(|(x in i64::MIN..=i64::MAX)| { + assert_eq!(FixedDecimal::from(x).pow(&*ZERO), *ONE); + }); + } + + #[test] + fn pow_of_any_to_power_1_should_be_same() { + proptest!(|(x in i64::MIN..=i64::MAX)| { + assert_eq!(FixedDecimal::from(x).pow(&*ONE), FixedDecimal::from(x)); + }); + } + + #[test] + fn pow_to_positive_times_pow_to_negative_should_be_1() { + let epsilon = FixedDecimal::from_str("1000000000000000000", 34).unwrap(); + proptest!(|(x in (-5i64..=5i64).prop_filter("Exclude zero", |&x| x != 0), y in 1i64..=25i64)| { + let x = FixedDecimal::from(x); + let y = FixedDecimal::from(y); + let minus_y = -&y; + let x_to_y = x.pow(&y); + let x_to_minus_y = x.pow(&minus_y); + let result = &x_to_y * &x_to_minus_y; + let diff = (&result - &*ONE).abs(); + // println!("x: {}, y: {}, x^y: {}, x^-y: {}, x^y * x^-y: {}, diff: {}", x, y, x_to_y, x_to_minus_y, result, diff); + assert!(diff <= epsilon); + }); + } } diff --git a/pallas-math/src/math_malachite.rs b/pallas-math/src/math_malachite.rs index 824c46e..545c62e 100644 --- a/pallas-math/src/math_malachite.rs +++ b/pallas-math/src/math_malachite.rs @@ -8,7 +8,7 @@ use malachite::num::basic::traits::One; use malachite::platform_64::Limb; use malachite::rounding_modes::RoundingMode; use malachite::{Integer, Natural}; -use malachite_base::num::arithmetic::traits::Sign; +use malachite_base::num::arithmetic::traits::{Parity, Sign}; use once_cell::sync::Lazy; use regex::Regex; use std::cmp::Ordering; @@ -135,6 +135,27 @@ impl<'a> Neg for &'a Decimal { } } +impl Abs for Decimal { + type Output = Self; + + fn abs(self) -> Self::Output { + let mut result = Decimal::new(self.precision); + result.data = self.data.abs(); + result + } +} + +// Implement Abs for a reference to Decimal +impl<'a> Abs for &'a Decimal { + type Output = Decimal; + + fn abs(self) -> Self::Output { + let mut result = Decimal::new(self.precision); + result.data = (&self.data).abs(); + result + } +} + impl Mul for Decimal { type Output = Self; @@ -310,10 +331,15 @@ impl FixedPrecision for Decimal { fn ln(&self) -> Self { let mut ln_x = Decimal::new(self.precision); - ref_ln(&mut ln_x.data, &self.data); - ln_x + if ref_ln(&mut ln_x.data, &self.data) { + ln_x + } else { + panic!("ln of a value in (-inf,0] is undefined") + } } + /// Compute the power of a Decimal approximation using x^y = exp(y * ln x) formula + /// While not exact, this is a more performant way to compute the power of a Decimal fn pow(&self, rhs: &Self) -> Self { let mut pow_x = Decimal::new(self.precision); ref_pow(&mut pow_x.data, &self.data, &rhs.data); @@ -677,10 +703,47 @@ fn ref_ln(rop: &mut Integer, x: &Integer) -> bool { fn ref_pow(rop: &mut Integer, base: &Integer, exponent: &Integer) { /* x^y = exp(y * ln x) */ let mut tmp: Integer = Integer::from(0); - ref_ln(&mut tmp, base); - tmp *= exponent; - scale(&mut tmp); - ref_exp(rop, &tmp); + + if exponent == &ZERO.value || base == &ONE.value { + // any base to the power of zero is one, or 1 to any power is 1 + *rop = ONE.value.clone(); + return; + } + if exponent == &ONE.value { + // any base to the power of one is the base + *rop = base.clone(); + return; + } + if base == &ZERO.value && exponent > &ZERO.value { + // zero to any positive power is zero + *rop = &ZERO.value * &PRECISION.value; + return; + } + if base == &ZERO.value && exponent < &ZERO.value { + panic!("zero to a negative power is undefined"); + } + if base < &ZERO.value { + // negate the base and calculate the power + let neg_base = base.neg(); + let ref_ln = ref_ln(&mut tmp, &neg_base); + debug_assert!(ref_ln); + tmp *= exponent; + scale(&mut tmp); + let mut tmp_rop = Integer::from(0); + ref_exp(&mut tmp_rop, &tmp); + *rop = if (exponent / &PRECISION.value).even() { + tmp_rop + } else { + -tmp_rop + }; + } else { + // base is positive, ref_ln result is valid + let ref_ln = ref_ln(&mut tmp, base); + debug_assert!(ref_ln); + tmp *= exponent; + scale(&mut tmp); + ref_exp(rop, &tmp); + } } /// `bound_x` is the bound for exp in the interval x is chosen from