Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 89 additions & 4 deletions decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ var (
// ErrInvalidBinaryData is returned when unmarshalling invalid binary data
// The binary data should follow the format as described in MarshalBinary
ErrInvalidBinaryData = fmt.Errorf("invalid binary data")

// ErrZeroPowNegative is returned when raising zero to a negative power
ErrZeroPowNegative = fmt.Errorf("can't raise zero to a negative power")
)

var (
Expand Down Expand Up @@ -1141,11 +1144,23 @@ func trailingZerosU128(n u128) uint8 {
return zeros
}

// PowInt returns d^e where e is an integer.
// Deprecated: Use PowInt32 instead for correct handling of 0^0 and negative exponents.
// This function treats 0 raised to any power as 0, which may not align with mathematical conventions
// but is practical in certain cases. See: https://github.com/quagmt/udecimal/issues/25.
//
// PowInt raises the decimal d to the integer power e (d^e).
//
// Special cases:
// - 0^e = 0 for any integer e
// - d^0 = 1 for any decimal d ≠ 0
//
// Examples:
//
// PowInt(2.5, 2) = 6.25
// PowInt(0, 0) = 0
// PowInt(0, 1) = 0
// PowInt(0, -1) = 0
// PowInt(2, 0) = 1
// PowInt(2.5, 2) = 6.25
// PowInt(2.5, -2) = 0.16
func (d Decimal) PowInt(e int) Decimal {
// check 0 first to avoid 0^0 = 1
Expand Down Expand Up @@ -1196,6 +1211,76 @@ func (d Decimal) PowInt(e int) Decimal {
return newDecimal(neg, bintFromBigInt(qBig), uint8(powPrecision))
}

// PowInt32 returns d raised to the power of e, where e is an int32.
//
// Returns:
//
// The result of d raised to the power of e.
// An error if d is zero and e is a negative integer.
//
// Special cases:
//
// 0^0 = 1
// 0^(any negative integer) results in an error
//
// Examples:
//
// PowInt32(0, 0) = 1
// PowInt32(2, 0) = 1
// PowInt32(0, 1) = 0
// PowInt32(0, -1) results in an error
// PowInt32(2.5, 2) = 6.25
// PowInt32(2.5, -2) = 0.16
func (d Decimal) PowInt32(e int32) (Decimal, error) {
// special case: 0 raised to a negative power
if d.coef.IsZero() && e < 0 {
return Decimal{}, ErrZeroPowNegative
}

if e == 0 {
return One, nil
}

if e == 1 {
return d, nil
}

// Rescale first to remove trailing zeros
dTrim := d.trimTrailingZeros()

if e < 0 {
return dTrim.powIntInverse(int(-e)), nil
}

// e > 1 && d != 0
q, err := dTrim.tryPowIntU128(int(e))
if err == nil {
return q, nil
}

// overflow, fallback to big.Int
dBig := dTrim.coef.GetBig()

var factor int32
powPrecision := int32(dTrim.prec) * e
if powPrecision >= int32(defaultPrec) {
factor = powPrecision - int32(defaultPrec)
powPrecision = int32(defaultPrec)
}

m := new(big.Int).Exp(bigTen, big.NewInt(int64(factor)), nil)
dBig = new(big.Int).Exp(dBig, big.NewInt(int64(e)), nil)
qBig := dBig.Quo(dBig, m)

neg := d.neg
if e%2 == 0 {
neg = false
}

//nolint:gosec
return newDecimal(neg, bintFromBigInt(qBig), uint8(powPrecision)), nil
}

// powIntInverse returns d^(-e), with e > 0
func (d Decimal) powIntInverse(e int) Decimal {
q, err := d.tryInversePowIntU128(e)
Expand Down Expand Up @@ -1279,8 +1364,8 @@ func (d Decimal) tryInversePowIntU128(e int) (Decimal, error) {
return Decimal{}, errOverflow
}

if d.coef.u128.hi != 0 && e >= 3 {
// e > 3 and u128.hi != 0 means the result will >= 2^192,
if d.coef.u128.hi != 0 && e >= 4 {
// e >= 4 and u128.hi != 0 means the result will >= 2^256,
// which we can't use fast division. So we need to use big.Int instead
return Decimal{}, errOverflow
}
Expand Down
127 changes: 127 additions & 0 deletions decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2243,6 +2243,133 @@ func TestRandomPow(t *testing.T) {
}
}

func TestPowInt32(t *testing.T) {
testcases := []struct {
a string
b int32
want string
wantErr error
}{
{"123456789012345678901234567890123456789.9999999999999999999", 2, "15241578753238836750495351562566681945252248135650053345652796829976527968319.753086421975308642", nil},
{"0.5", -14, "16384", nil},
{"5", -18, "0.000000000000262144", nil},
{"-96", 384, "155651563400161893689540829251750532876602528021691915200061141022544075854496838643052295888420136905906567539126502582243693732125449523059780613380755061052491943449381255863820131332142779769865996188291542971996702478765598563482106934995948481892528830806840727897892513634949541154348143236794203399068607458789100280733156671481421737413484548654754828937861442964361485155011834501441449057827522043722520499866143913624005535732240536689495728164138830318329923569260213567200238743687906030695515032990022513102670332644203639546984105586335760789206424524917450457774575904047665710191104154700220406574406611422191187238002842748820651406984670104474060413271629299557918370269495849383625416400964818595369246834495413046931303826618633216386400256", nil},
{"-70", -8, "0.0000000000000017346", nil},
{"0.12", 100, "0", nil},
{"0", 0, "1", nil},
{"0", -1, "0", ErrZeroPowNegative},
{"0", 1, "0", nil},
{"0", 10, "0", nil},
{"1.12345", 4, "1.5929971334827095062", nil},
{"123456789012345678901234567890123456789.9999999999999999999", 0, "1", nil},
{"123456789012345678901234567890123456789.9999999999999999999", 1, "123456789012345678901234567890123456789.9999999999999999999", nil},
{"1.5", 3, "3.375", nil},
{"1.12345", 1, "1.12345", nil},
{"1.12345", 2, "1.2621399025", nil},
{"1.12345", 3, "1.417951073463625", nil},
{"1.12345", 4, "1.5929971334827095062", nil},
{"1.12345", 5, "1.7896526296111499947", nil},
{"1.12345", 6, "2.0105852467366464616", nil},
{"1.12345", 7, "2.2587919954462854673", nil},
{"-1.12345", 4, "1.5929971334827095062", nil},
}

for _, tc := range testcases {
t.Run(fmt.Sprintf("%s.pow(%d)", tc.a, tc.b), func(t *testing.T) {
a, err := Parse(tc.a)
require.NoError(t, err)

aStr := a.String()

b, err := a.PowInt32(tc.b)
if tc.wantErr != nil {
require.Equal(t, tc.wantErr, err)
return
}

require.Equal(t, tc.want, b.String())

// make sure a is immutable
require.Equal(t, aStr, a.String())

// cross check with shopspring/decimal

aa := decimal.RequireFromString(tc.a)
aa, err = aa.PowWithPrecision(decimal.New(int64(tc.b), 0), int32(b.prec)+4)

// special case for 0^0
// udecimal: 0^0 = 1
// shopspring/decimal: 0^0 is undefined and will return an error
if tc.a == "0" && tc.b == 0 {
require.EqualError(t, err, "cannot represent undefined value of 0**0")
return
}

require.NoError(t, err)

aa = aa.Truncate(int32(b.prec))

require.Equal(t, aa.String(), b.String())
})
}
}

func TestRandomPowInt32(t *testing.T) {
inputs := []string{
"0.1234",
"-0.1234",
"1.123456789012345679",
"-1.123456789012345679",
"1.12345",
"-1.12345",
"123456789012345678901234567890123456789.9999999999999999999",
"123456789012345678901234567890123456789.9999999999999999999",
"1.5",
"123456.789",
"123.4",
"1234567890123456789.1234567890123456789",
"-1234567890123456789.1234567890123456789",
}

for _, input := range inputs {
t.Run(fmt.Sprintf("pow(%s)", input), func(t *testing.T) {
a := MustParse(input)

for i := 0; i <= 1000; i++ {
b, err := a.PowInt32(int32(i))
require.NoError(t, err)

aa := decimal.RequireFromString(input)
aa, err = aa.PowWithPrecision(decimal.New(int64(i), 0), int32(b.prec)+4)
require.NoError(t, err)

aa = aa.Truncate(int32(b.prec))

require.Equal(t, aa.String(), b.String(), "%s.pow(%d)", input, i)
}
})
}

for _, input := range inputs {
t.Run(fmt.Sprintf("powInverse(%s)", input), func(t *testing.T) {
a := MustParse(input)

for i := 0; i >= -100; i-- {
b, err := a.PowInt32(int32(i))
require.NoError(t, err)

aa := decimal.RequireFromString(input)
aa, err = aa.PowWithPrecision(decimal.New(int64(i), 0), int32(b.prec)+4)
require.NoError(t, err)

aa = aa.Truncate(int32(b.prec))

require.Equal(t, aa.String(), b.String(), "%s.pow(%d)", input, i)
}
})
}
}

func TestSqrt(t *testing.T) {
testcases := []struct {
a string
Expand Down
14 changes: 14 additions & 0 deletions doc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,20 @@ func ExampleDecimal_PowInt() {
// 0.6609822195782933439
}

func ExampleDecimal_PowInt32() {
fmt.Println(MustParse("1.23").PowInt32(2))
fmt.Println(MustParse("1.23").PowInt32(0))
fmt.Println(MustParse("1.23").PowInt32(-2))
fmt.Println(MustParse("0").PowInt32(0))
fmt.Println(MustParse("0").PowInt32(-2))
// Output:
// 1.5129 <nil>
// 1 <nil>
// 0.6609822195782933439 <nil>
// 1 <nil>
// 0 can't raise zero to a negative power
}

func ExampleDecimal_Prec() {
fmt.Println(MustParse("1.23").Prec())
// Output:
Expand Down
51 changes: 50 additions & 1 deletion fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ func FuzzTrunc(f *testing.F) {
})
}

func FuzzPowInt(f *testing.F) {
func FuzzDepcrecatedPowInt(f *testing.F) {
for _, c := range corpus {
f.Add(c.neg, c.hi, c.lo, c.prec, rand.Int())
}
Expand Down Expand Up @@ -605,6 +605,55 @@ func FuzzPowInt(f *testing.F) {
})
}

func FuzzPowInt32(f *testing.F) {
for _, c := range corpus {
f.Add(c.neg, c.hi, c.lo, c.prec, rand.Int())
}

f.Fuzz(func(t *testing.T, aneg bool, ahi uint64, alo uint64, aprec uint8, pow int) {
a, err := NewFromHiLo(aneg, ahi, alo, aprec)
if err == ErrPrecOutOfRange {
t.Skip()
} else {
require.NoError(t, err)
}

// use pow less than 10000
p := pow % 10000

c, err := a.PowInt32(int32(p))
if a.IsZero() && p < 0 {
require.Equal(t, err, ErrDivideByZero)
return
}

if c.coef.overflow() {
require.NotNil(t, c.coef.bigInt)
require.Equal(t, u128{}, c.coef.u128)
} else {
require.Nil(t, c.coef.bigInt)
}

// compare with shopspring/decimal
aa := ssDecimal(aneg, ahi, alo, aprec)
aa, err = aa.PowWithPrecision(ss.New(int64(p), 0), int32(c.prec)+4)

// special case for 0^0
// udecimal: 0^0 = 1
// shopspring/decimal: 0^0 is undefined and will return an error
if a.IsZero() && p == 0 {
require.EqualError(t, err, "cannot represent undefined value of 0**0")
require.Equal(t, "1", c.String())
return
}

require.NoError(t, err)
aa = aa.Truncate(int32(c.prec))

require.Equal(t, aa.String(), c.String(), "powInt %s %d", a, p)
})
}

func FuzzPowNegative(f *testing.F) {
for _, c := range corpus {
f.Add(c.neg, c.hi, c.lo, c.prec, rand.Int64())
Expand Down
Loading