Skip to content

Commit 9c197bf

Browse files
fix: Columns and Values should recognize pointer values too (#67)
1 parent cbcbcfc commit 9c197bf

File tree

4 files changed

+132
-10
lines changed

4 files changed

+132
-10
lines changed

columns.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func columnNames(model reflect.Value, strict bool, excluded ...string) []string
121121
continue
122122
}
123123

124-
if supportedColumnType(valField.Kind()) || isValidSqlValue(valField) {
124+
if supportedColumnType(valField) || isValidSqlValue(valField) {
125125
names = append(names, fieldName)
126126
}
127127
}
@@ -152,13 +152,16 @@ func reflectValue(v interface{}) (reflect.Value, error) {
152152
return vVal, nil
153153
}
154154

155-
func supportedColumnType(k reflect.Kind) bool {
156-
switch k {
155+
func supportedColumnType(v reflect.Value) bool {
156+
switch v.Kind() {
157157
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
158158
reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
159159
reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Interface,
160160
reflect.String:
161161
return true
162+
case reflect.Ptr:
163+
ptrVal := reflect.New(v.Type().Elem())
164+
return supportedColumnType(ptrVal.Elem())
162165
default:
163166
return false
164167
}
@@ -169,6 +172,11 @@ func isValidSqlValue(v reflect.Value) bool {
169172
// 1. It returns true for sql.driver's type check for types like time.Time
170173
// 2. It implements the driver.Valuer interface allowing conversion directly
171174
// into sql statements
175+
if v.Kind() == reflect.Ptr {
176+
ptrVal := reflect.New(v.Type().Elem())
177+
return isValidSqlValue(ptrVal.Elem())
178+
}
179+
172180
if driver.IsValue(v.Interface()) {
173181
return true
174182
}

columns_test.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,17 @@ func TestColumnsStoresOneCacheEntryPerInstance(t *testing.T) {
272272
assert.Equal(t, 1, after-before, "Cache size grew unexpectedly")
273273
}
274274

275-
func TestValuesWorkWithValidSqlValueTypes(t *testing.T) {
275+
func TestColumnsReturnsStructTagsWithPointers(t *testing.T) {
276+
type personUpdate struct {
277+
Name *string `db:"name"`
278+
}
279+
280+
cols, err := Columns(&personUpdate{})
281+
assert.NoError(t, err)
282+
assert.EqualValues(t, []string{"name"}, cols)
283+
}
284+
285+
func TestColumnsWorkWithValidSqlValueTypes(t *testing.T) {
276286
type coupon struct {
277287
Value int `db:"value"`
278288
Expires time.Time `db:"expires"`
@@ -284,6 +294,18 @@ func TestValuesWorkWithValidSqlValueTypes(t *testing.T) {
284294
assert.EqualValues(t, []string{"value", "expires"}, cols)
285295
}
286296

297+
func TestColumnsWorkWithPointerValidSqlTypes(t *testing.T) {
298+
type coupon struct {
299+
Value int `db:"value"`
300+
Expires *time.Time `db:"expires"`
301+
}
302+
303+
c := &coupon{}
304+
cols, err := Columns(c)
305+
assert.NoError(t, err)
306+
assert.EqualValues(t, []string{"value", "expires"}, cols)
307+
}
308+
287309
type Pet struct {
288310
Species string
289311
Name string
@@ -310,3 +332,7 @@ func BenchmarkColumnsLargeStruct(b *testing.B) {
310332
Columns(ls)
311333
}
312334
}
335+
336+
func ptr(s string) *string {
337+
return &s
338+
}

example_scanner_test.go

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ func exampleDB() *sql.DB {
3333
);`,
3434
`INSERT INTO person (id, name) VALUES (1, 'brett', 1);`,
3535
`INSERT INTO person (id, name) VALUES (2, 'fred', 1);`,
36+
`INSERT INTO person (id) VALUES (3);`,
3637
)
3738
}
3839

@@ -136,6 +137,53 @@ func ExampleRowStrict() {
136137
// {"ID":0,"Name":"brett"}
137138
}
138139

140+
func ExampleRowPtr() {
141+
db := exampleDB()
142+
defer db.Close()
143+
rows, err := db.Query("SELECT id,name FROM person where id = 3 LIMIT 1")
144+
if err != nil {
145+
panic(err)
146+
}
147+
148+
var person struct {
149+
ID int
150+
Name *string `db:"name"`
151+
}
152+
153+
err = scan.RowStrict(&person, rows)
154+
if err != nil {
155+
panic(err)
156+
}
157+
158+
json.NewEncoder(os.Stdout).Encode(&person)
159+
// Output:
160+
// {"ID":0,"Name":null}
161+
}
162+
163+
func ExampleRowPtrType() {
164+
db := exampleDB()
165+
defer db.Close()
166+
rows, err := db.Query("SELECT id,name FROM person where id = 3 LIMIT 1")
167+
if err != nil {
168+
panic(err)
169+
}
170+
171+
type NullableString *string
172+
var person struct {
173+
ID int
174+
Name NullableString `db:"name"`
175+
}
176+
177+
err = scan.RowStrict(&person, rows)
178+
if err != nil {
179+
panic(err)
180+
}
181+
182+
json.NewEncoder(os.Stdout).Encode(&person)
183+
// Output:
184+
// {"ID":0,"Name":null}
185+
}
186+
139187
func ExampleRow_scalar() {
140188
db := exampleDB()
141189
defer db.Close()
@@ -165,8 +213,8 @@ func ExampleRows() {
165213
}
166214

167215
var persons []struct {
168-
ID int `db:"id"`
169-
Name string `db:"name"`
216+
ID int `db:"id"`
217+
Name *string `db:"name"`
170218
}
171219

172220
err = scan.Rows(&persons, rows)
@@ -176,7 +224,7 @@ func ExampleRows() {
176224

177225
json.NewEncoder(os.Stdout).Encode(&persons)
178226
// Output:
179-
// [{"ID":1,"Name":"brett"},{"ID":2,"Name":"fred"}]
227+
// [{"ID":1,"Name":"brett"},{"ID":2,"Name":"fred"},{"ID":3,"Name":null}]
180228
}
181229

182230
func ExampleRowsStrict() {
@@ -189,7 +237,7 @@ func ExampleRowsStrict() {
189237

190238
var persons []struct {
191239
ID int
192-
Name string `db:"name"`
240+
Name *string `db:"name"`
193241
}
194242

195243
err = scan.Rows(&persons, rows)
@@ -199,13 +247,13 @@ func ExampleRowsStrict() {
199247

200248
json.NewEncoder(os.Stdout).Encode(&persons)
201249
// Output:
202-
// [{"ID":0,"Name":"brett"},{"ID":0,"Name":"fred"}]
250+
// [{"ID":0,"Name":"brett"},{"ID":0,"Name":"fred"},{"ID":0,"Name":null}]
203251
}
204252

205253
func ExampleRows_primitive() {
206254
db := exampleDB()
207255
defer db.Close()
208-
rows, err := db.Query("SELECT name FROM person ORDER BY id ASC")
256+
rows, err := db.Query("SELECT name FROM person WHERE name IS NOT NULL ORDER BY id ASC")
209257
if err != nil {
210258
panic(err)
211259
}

values_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,30 @@ func TestValuesScansDBTags(t *testing.T) {
3434
assert.EqualValues(t, []interface{}{"Brett"}, vals)
3535
}
3636

37+
func TestValuesScansPointerDBTags(t *testing.T) {
38+
type person struct {
39+
Name *string `db:"n"`
40+
}
41+
42+
p := &person{Name: ptr("Brett")}
43+
vals, err := Values([]string{"n"}, p)
44+
require.NoError(t, err)
45+
46+
assert.EqualValues(t, []interface{}{ptr("Brett")}, vals)
47+
}
48+
49+
func TestValuesReturnsNilPointers(t *testing.T) {
50+
type person struct {
51+
Name *string `db:"n"`
52+
}
53+
54+
p := &person{Name: nil}
55+
vals, err := Values([]string{"n"}, p)
56+
require.NoError(t, err)
57+
58+
assert.EqualValues(t, []interface{}{(*string)(nil)}, vals)
59+
}
60+
3761
func TestValuesScansNestedFields(t *testing.T) {
3862
type Address struct {
3963
Street string
@@ -124,6 +148,22 @@ func TestValuesValidSqlTypes(t *testing.T) {
124148
assert.EqualValues(t, []interface{}{25, tNow}, vals)
125149
}
126150

151+
func TestValuesValidPointerSqlTypes(t *testing.T) {
152+
tNow := time.Now()
153+
type coupon struct {
154+
Value int
155+
Expires *time.Time
156+
}
157+
c := &coupon{
158+
Value: 25,
159+
Expires: &tNow,
160+
}
161+
162+
vals, err := Values([]string{"Value", "Expires"}, c)
163+
require.NoError(t, err)
164+
assert.EqualValues(t, []interface{}{25, &tNow}, vals)
165+
}
166+
127167
func TestValuesDriverValuerImplementers(t *testing.T) {
128168
type person struct {
129169
Name string `db:"name"`

0 commit comments

Comments
 (0)