Skip to content

Commit 40e8e07

Browse files
Merge pull request #276 from spf13/improve-string-float
fix: float string to number parsing
2 parents cb5df5f + fa4ea64 commit 40e8e07

File tree

4 files changed

+172
-34
lines changed

4 files changed

+172
-34
lines changed

cast.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package cast
99
import "time"
1010

1111
const errorMsg = "unable to cast %#v of type %T to %T"
12+
const errorMsgWith = "unable to cast %#v of type %T to %T: %w"
1213

1314
// Basic is a type parameter constraint for functions accepting basic types.
1415
//

number.go

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"encoding/json"
1010
"errors"
1111
"fmt"
12+
"regexp"
1213
"strconv"
1314
"strings"
1415
"time"
@@ -149,22 +150,22 @@ func toNumberE[T Number](i any, parseFn func(string) (T, error)) (T, error) {
149150
}
150151

151152
v, err := parseFn(s)
152-
if err == nil {
153-
return v, nil
153+
if err != nil {
154+
return 0, fmt.Errorf(errorMsgWith, i, i, n, err)
154155
}
155156

156-
return 0, fmt.Errorf(errorMsg, i, i, n)
157+
return v, nil
157158
case json.Number:
158159
if s == "" {
159160
return 0, nil
160161
}
161162

162163
v, err := parseFn(string(s))
163-
if err == nil {
164-
return v, nil
164+
if err != nil {
165+
return 0, fmt.Errorf(errorMsgWith, i, i, n, err)
165166
}
166167

167-
return 0, fmt.Errorf(errorMsg, i, i, n)
168+
return v, nil
168169
case float64EProvider:
169170
if _, ok := any(n).(float64); !ok {
170171
return 0, fmt.Errorf(errorMsg, i, i, n)
@@ -293,22 +294,22 @@ func toUnsignedNumberE[T Number](i any, parseFn func(string) (T, error)) (T, err
293294
}
294295

295296
v, err := parseFn(s)
296-
if err == nil {
297-
return v, nil
297+
if err != nil {
298+
return 0, fmt.Errorf(errorMsgWith, i, i, n, err)
298299
}
299300

300-
return 0, fmt.Errorf(errorMsg, i, i, n)
301+
return v, nil
301302
case json.Number:
302303
if s == "" {
303304
return 0, nil
304305
}
305306

306307
v, err := parseFn(string(s))
307-
if err == nil {
308-
return v, nil
308+
if err != nil {
309+
return 0, fmt.Errorf(errorMsgWith, i, i, n, err)
309310
}
310311

311-
return 0, fmt.Errorf(errorMsg, i, i, n)
312+
return v, nil
312313
case float64EProvider:
313314
if _, ok := any(n).(float64); !ok {
314315
return 0, fmt.Errorf(errorMsg, i, i, n)
@@ -413,7 +414,7 @@ func parseInt[T integer](s string) (T, error) {
413414
}
414415

415416
func parseUint[T unsigned](s string) (T, error) {
416-
v, err := strconv.ParseUint(trimDecimal(s), 0, 0)
417+
v, err := strconv.ParseUint(strings.TrimLeft(trimDecimal(s), "+"), 0, 0)
417418
if err != nil {
418419
return 0, err
419420
}
@@ -520,13 +521,28 @@ func trimZeroDecimal(s string) string {
520521
return s
521522
}
522523

523-
// trimming decimals seems significantly faster than parsing to float first
524-
//
525-
// see BenchmarkDecimal
524+
var stringNumberRe = regexp.MustCompile(`^([-+]?\d*)(\.\d*)?$`)
525+
526+
// see [BenchmarkDecimal] for details about the implementation
526527
func trimDecimal(s string) string {
527-
// trim the decimal part (if any)
528-
if i := strings.Index(s, "."); i >= 0 {
529-
s = s[:i]
528+
if !strings.Contains(s, ".") {
529+
return s
530+
}
531+
532+
matches := stringNumberRe.FindStringSubmatch(s)
533+
if matches != nil {
534+
// matches[1] is the captured integer part with sign
535+
s = matches[1]
536+
537+
// handle special cases
538+
switch s {
539+
case "-", "+":
540+
s += "0"
541+
case "":
542+
s = "0"
543+
}
544+
545+
return s
530546
}
531547

532548
return s

number_internal_test.go

Lines changed: 111 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
package cast
77

88
import (
9+
"regexp"
910
"strconv"
11+
"strings"
1012
"testing"
1113

1214
qt "github.com/frankban/quicktest"
@@ -45,42 +47,136 @@ func TestTrimZeroDecimal(t *testing.T) {
4547
}
4648

4749
func TestTrimDecimal(t *testing.T) {
48-
c := qt.New(t)
50+
testCases := []struct {
51+
input string
52+
expected string
53+
}{
54+
{"10.0", "10"},
55+
{"10.010", "10"},
56+
{"00000.00001", "00000"},
57+
{"-0001.0", "-0001"},
58+
{".5", "0"},
59+
{"+12.", "+12"},
60+
{"+.25", "+0"},
61+
{"-.25", "-0"},
62+
{"0.0000000000", "0"},
63+
{"0.0000000001", "0"},
64+
{"10.0000000000", "10"},
65+
{"10.0000000001", "10"},
66+
{"10000000000000.0000000000", "10000000000000"},
67+
68+
{"10...17", "10...17"},
69+
{"10.foobar", "10.foobar"},
70+
{"10.0i", "10.0i"},
71+
{"10.0E9", "10.0E9"},
72+
}
4973

50-
c.Assert(trimDecimal("10.0"), qt.Equals, "10")
51-
c.Assert(trimDecimal("10.00"), qt.Equals, "10")
52-
c.Assert(trimDecimal("10.010"), qt.Equals, "10")
53-
c.Assert(trimDecimal("0.0000000000"), qt.Equals, "0")
54-
c.Assert(trimDecimal("0.00000000001"), qt.Equals, "0")
74+
for _, testCase := range testCases {
75+
// TODO: remove after minimum Go version is >=1.22
76+
testCase := testCase
77+
78+
t.Run(testCase.input, func(t *testing.T) {
79+
c := qt.New(t)
80+
81+
c.Assert(trimDecimal(testCase.input), qt.Equals, testCase.expected)
82+
})
83+
}
5584
}
5685

86+
// Analysis (in the order of performance):
87+
//
88+
// - Trimming decimals based on decimal point yields a lot of incorrectly parsed values.
89+
// - Parsing to float might be better, but we still need to cast the number, it might overflow, problematic.
90+
// - Regex parsing is an order of magnitude slower, but it yields correct results.
5791
func BenchmarkDecimal(b *testing.B) {
58-
testCases := []string{"10.0", "10.00", "10.010", "0.0000000000", "0.0000000001", "10.0000000000", "10.0000000001", "10000000000000.0000000000"}
92+
testCases := []struct {
93+
input string
94+
expectError bool
95+
}{
96+
{"10.0", false},
97+
{"10.00", false},
98+
{"10.010", false},
99+
{"0.0000000000", false},
100+
{"0.0000000001", false},
101+
{"10.0000000000", false},
102+
{"10.0000000001", false},
103+
{"10000000000000.0000000000", false},
104+
105+
// {"10...17", true},
106+
// {"10.foobar", true},
107+
// {"10.0i", true},
108+
// {"10.0E9", true},
109+
}
110+
111+
trimDecimalString := func(s string) string {
112+
// trim the decimal part (if any)
113+
if i := strings.Index(s, "."); i >= 0 {
114+
s = s[:i]
115+
}
116+
117+
return s
118+
}
119+
120+
re := regexp.MustCompile(`^([-+]?\d*)(\.\d*)?$`)
121+
122+
trimDecimalRegex := func(s string) string {
123+
matches := re.FindStringSubmatch(s)
124+
if matches != nil {
125+
// matches[1] is the captured integer part with sign
126+
return matches[1]
127+
}
128+
129+
return s
130+
}
59131

60132
for _, testCase := range testCases {
61133
// TODO: remove after minimum Go version is >=1.22
62134
testCase := testCase
63135

64-
b.Run(testCase, func(b *testing.B) {
136+
b.Run(testCase.input, func(b *testing.B) {
65137
b.Run("ParseFloat", func(b *testing.B) {
66138
// TODO: use b.Loop() once updated to Go 1.24
67139
for i := 0; i < b.N; i++ {
68-
v, err := strconv.ParseFloat(testCase, 64)
69-
if err != nil {
70-
b.Fatal(err)
140+
v, err := strconv.ParseFloat(testCase.input, 64)
141+
if (err != nil) != testCase.expectError {
142+
if err != nil {
143+
b.Fatal(err)
144+
}
145+
146+
b.Fatal("expected error, but got none")
71147
}
72148

73149
n := int64(v)
74150
_ = n
75151
}
76152
})
77153

78-
b.Run("TrimDecimal", func(b *testing.B) {
154+
b.Run("TrimDecimalString", func(b *testing.B) {
155+
// TODO: use b.Loop() once updated to Go 1.24
156+
for i := 0; i < b.N; i++ {
157+
v, err := strconv.ParseInt(trimDecimalString(testCase.input), 0, 0)
158+
if (err != nil) != testCase.expectError {
159+
if err != nil {
160+
b.Fatal(err)
161+
}
162+
163+
b.Fatal("expected error, but got none")
164+
}
165+
166+
_ = v
167+
}
168+
})
169+
170+
b.Run("TrimDecimalRegex", func(b *testing.B) {
79171
// TODO: use b.Loop() once updated to Go 1.24
80172
for i := 0; i < b.N; i++ {
81-
v, err := strconv.ParseInt(trimDecimal(testCase), 0, 0)
82-
if err != nil {
83-
b.Fatal(err)
173+
v, err := strconv.ParseInt(trimDecimalRegex(testCase.input), 0, 0)
174+
if (err != nil) != testCase.expectError {
175+
if err != nil {
176+
b.Fatal(err)
177+
}
178+
179+
b.Fatal("expected error, but got none")
84180
}
85181

86182
_ = v

number_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ var numberContexts = map[string]numberContext{
150150
},
151151
}
152152

153+
// TODO: separate test and failure cases?
154+
// Kinda hard to track cases right now.
153155
func generateNumberTestCases(samples []any) []testCase {
154156
zero := samples[0]
155157
one := samples[1]
@@ -169,7 +171,9 @@ func generateNumberTestCases(samples []any) []testCase {
169171
_ = overflowString
170172

171173
kind := reflect.TypeOf(zero).Kind()
174+
isSint := kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 || kind == reflect.Int32 || kind == reflect.Int64
172175
isUint := kind == reflect.Uint || kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 || kind == reflect.Uint64
176+
isInt := isSint || isUint
173177

174178
// Some precision is lost when converting from float64 to float32.
175179
eightPoint31_32 := eightPoint31
@@ -231,6 +235,27 @@ func generateNumberTestCases(samples []any) []testCase {
231235
// Failure cases
232236
{"test", zero, true},
233237
{testing.T{}, zero, true},
238+
239+
{"10...17", zero, true},
240+
{"10.foobar", zero, true},
241+
{"10.0i", zero, true},
242+
}
243+
244+
if isInt {
245+
testCases = append(
246+
testCases,
247+
248+
testCase{".5", zero, false},
249+
testCase{"+8.", eight, false},
250+
testCase{"+.25", zero, false},
251+
testCase{"-.25", zero, isUint},
252+
253+
testCase{"10.0E9", zero, true},
254+
)
255+
} else if kind == reflect.Float32 {
256+
testCases = append(testCases, testCase{"10.0E9", float32(10000000000.000000), false})
257+
} else if kind == reflect.Float64 {
258+
testCases = append(testCases, testCase{"10.0E9", float64(10000000000.000000), false})
234259
}
235260

236261
if isUint && underflowString != nil {

0 commit comments

Comments
 (0)