Skip to content

Commit 5800834

Browse files
mx-psievan-bradley
andauthored
[chore] Make mapstructure hooks safe against untyped nils (#13001)
<!--Ex. Fixing a bug - Describe the bug and how this fixes the issue. Ex. Adding a feature - Explain what this achieves.--> #### Description If we enable `DecodeNil` as true, we may have [untyped nils](https://go.dev/doc/faq#nil_error) being passed. Unfortunately, these are not valid values for `reflect`, which leads to surprising behavior such as golang/go/issues/51649. Unfortunately, the default hooks from mapstructure do not deal with this properly. To account for this, we: - Vendor and change the `ComposeDecodeHookFunc` function so that this case is accounted for the kinds of hooks that just won't work with untyped nils - Create a safe wrapper for the hooks that do work with untyped nils. This wrapper is used in all hooks, but in the interest of keeping as close to what I would imagine upstream will accept, I did not add this to the compose function. This should not have any end-user observable behavior. <!-- Issue number if applicable --> #### Link to tracking issue Attempt to work around #12996 (comment) --------- Co-authored-by: Evan Bradley <[email protected]>
1 parent 6bd77b3 commit 5800834

File tree

2 files changed

+132
-11
lines changed

2 files changed

+132
-11
lines changed

confmap/confmap.go

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/knadh/koanf/v2"
2020

2121
encoder "go.opentelemetry.io/collector/confmap/internal/mapstructure"
22+
"go.opentelemetry.io/collector/confmap/internal/third_party/composehook"
2223
)
2324

2425
const (
@@ -234,7 +235,7 @@ func decodeConfig(m *Conf, result any, errorUnused bool, skipTopLevelUnmarshaler
234235
TagName: MapstructureTag,
235236
WeaklyTypedInput: false,
236237
MatchName: caseSensitiveMatchName,
237-
DecodeHook: mapstructure.ComposeDecodeHookFunc(
238+
DecodeHook: composehook.ComposeDecodeHookFunc(
238239
useExpandValue(),
239240
expandNilStructPointersHookFunc(),
240241
mapstructure.StringToSliceHookFunc(","),
@@ -306,6 +307,23 @@ func isStringyStructure(t reflect.Type) bool {
306307
return false
307308
}
308309

310+
// safeWrapDecodeHookFunc wraps a DecodeHookFuncValue to ensure fromVal is a valid `reflect.Value`
311+
// object and therefore it is safe to call `reflect.Value` methods on fromVal.
312+
//
313+
// Use this only if the hook does not need to be called on untyped nil values.
314+
// Typed nil values are safe to call and will be passed to the hook.
315+
// See https://github.com/golang/go/issues/51649
316+
func safeWrapDecodeHookFunc(
317+
f mapstructure.DecodeHookFuncValue,
318+
) mapstructure.DecodeHookFuncValue {
319+
return func(fromVal reflect.Value, toVal reflect.Value) (any, error) {
320+
if !fromVal.IsValid() {
321+
return nil, nil
322+
}
323+
return f(fromVal, toVal)
324+
}
325+
}
326+
309327
// When a value has been loaded from an external source via a provider, we keep both the
310328
// parsed value and the original string value. This allows us to expand the value to its
311329
// original string representation when decoding into a string field, and use the original otherwise.
@@ -355,7 +373,7 @@ func useExpandValue() mapstructure.DecodeHookFuncType {
355373
// we want an unmarshaled Config to be equivalent to
356374
// Config{Thing: &SomeStruct{}} instead of Config{Thing: nil}
357375
func expandNilStructPointersHookFunc() mapstructure.DecodeHookFuncValue {
358-
return func(from reflect.Value, to reflect.Value) (any, error) {
376+
return safeWrapDecodeHookFunc(func(from reflect.Value, to reflect.Value) (any, error) {
359377
// ensure we are dealing with map to map comparison
360378
if from.Kind() == reflect.Map && to.Kind() == reflect.Map {
361379
toElem := to.Type().Elem()
@@ -375,7 +393,7 @@ func expandNilStructPointersHookFunc() mapstructure.DecodeHookFuncValue {
375393
}
376394
}
377395
return from.Interface(), nil
378-
}
396+
})
379397
}
380398

381399
// mapKeyStringToMapKeyTextUnmarshalerHookFunc returns a DecodeHookFuncType that checks that a conversion from
@@ -422,7 +440,7 @@ func mapKeyStringToMapKeyTextUnmarshalerHookFunc() mapstructure.DecodeHookFuncTy
422440
// unmarshalerEmbeddedStructsHookFunc provides a mechanism for embedded structs to define their own unmarshal logic,
423441
// by implementing the Unmarshaler interface.
424442
func unmarshalerEmbeddedStructsHookFunc() mapstructure.DecodeHookFuncValue {
425-
return func(from reflect.Value, to reflect.Value) (any, error) {
443+
return safeWrapDecodeHookFunc(func(from reflect.Value, to reflect.Value) (any, error) {
426444
if to.Type().Kind() != reflect.Struct {
427445
return from.Interface(), nil
428446
}
@@ -455,14 +473,14 @@ func unmarshalerEmbeddedStructsHookFunc() mapstructure.DecodeHookFuncValue {
455473
}
456474
}
457475
return fromAsMap, nil
458-
}
476+
})
459477
}
460478

461479
// Provides a mechanism for individual structs to define their own unmarshal logic,
462480
// by implementing the Unmarshaler interface, unless skipTopLevelUnmarshaler is
463481
// true and the struct matches the top level object being unmarshaled.
464482
func unmarshalerHookFunc(result any, skipTopLevelUnmarshaler bool) mapstructure.DecodeHookFuncValue {
465-
return func(from reflect.Value, to reflect.Value) (any, error) {
483+
return safeWrapDecodeHookFunc(func(from reflect.Value, to reflect.Value) (any, error) {
466484
if !to.CanAddr() {
467485
return from.Interface(), nil
468486
}
@@ -495,14 +513,14 @@ func unmarshalerHookFunc(result any, skipTopLevelUnmarshaler bool) mapstructure.
495513
}
496514

497515
return unmarshaler, nil
498-
}
516+
})
499517
}
500518

501519
// marshalerHookFunc returns a DecodeHookFuncValue that checks structs that aren't
502520
// the original to see if they implement the Marshaler interface.
503521
func marshalerHookFunc(orig any) mapstructure.DecodeHookFuncValue {
504522
origType := reflect.TypeOf(orig)
505-
return func(from reflect.Value, _ reflect.Value) (any, error) {
523+
return safeWrapDecodeHookFunc(func(from reflect.Value, _ reflect.Value) (any, error) {
506524
if from.Kind() != reflect.Struct {
507525
return from.Interface(), nil
508526
}
@@ -520,7 +538,7 @@ func marshalerHookFunc(orig any) mapstructure.DecodeHookFuncValue {
520538
return nil, err
521539
}
522540
return conf.ToStringMap(), nil
523-
}
541+
})
524542
}
525543

526544
// Unmarshaler interface may be implemented by types to customize their behavior when being unmarshaled from a Conf.
@@ -562,7 +580,7 @@ type Marshaler interface {
562580
// 4. configuration have no `keys` field specified, the output should be default config
563581
// - for example, input is {}, then output is Config{ Keys: ["a", "b"]}
564582
func zeroSliceHookFunc() mapstructure.DecodeHookFuncValue {
565-
return func(from reflect.Value, to reflect.Value) (any, error) {
583+
return safeWrapDecodeHookFunc(func(from reflect.Value, to reflect.Value) (any, error) {
566584
if to.CanSet() && to.Kind() == reflect.Slice && from.Kind() == reflect.Slice {
567585
if from.IsNil() {
568586
// input slice is nil, set output slice to nil.
@@ -574,7 +592,7 @@ func zeroSliceHookFunc() mapstructure.DecodeHookFuncValue {
574592
}
575593

576594
return from.Interface(), nil
577-
}
595+
})
578596
}
579597

580598
type moduleFactory[T any, S any] interface {
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Copyright (c) 2013 Mitchell Hashimoto
2+
// SPDX-License-Identifier: MIT
3+
// This code is a modified version of https://github.com/go-viper/mapstructure
4+
5+
package composehook // import "go.opentelemetry.io/collector/confmap/internal/third_party/composehook"
6+
7+
import (
8+
"errors"
9+
"reflect"
10+
11+
"github.com/go-viper/mapstructure/v2"
12+
)
13+
14+
// typedDecodeHook takes a raw DecodeHookFunc (an any) and turns
15+
// it into the proper DecodeHookFunc type, such as DecodeHookFuncType.
16+
func typedDecodeHook(h mapstructure.DecodeHookFunc) mapstructure.DecodeHookFunc {
17+
// Create variables here so we can reference them with the reflect pkg
18+
var f1 mapstructure.DecodeHookFuncType
19+
var f2 mapstructure.DecodeHookFuncKind
20+
var f3 mapstructure.DecodeHookFuncValue
21+
22+
// Fill in the variables into this interface and the rest is done
23+
// automatically using the reflect package.
24+
potential := []any{f3, f1, f2}
25+
26+
v := reflect.ValueOf(h)
27+
vt := v.Type()
28+
for _, raw := range potential {
29+
pt := reflect.ValueOf(raw).Type()
30+
if vt.ConvertibleTo(pt) {
31+
return v.Convert(pt).Interface()
32+
}
33+
}
34+
35+
return nil
36+
}
37+
38+
// cachedDecodeHook takes a raw DecodeHookFunc (an any) and turns
39+
// it into a closure to be used directly
40+
// if the type fails to convert we return a closure always erroring to keep the previous behavior
41+
func cachedDecodeHook(raw mapstructure.DecodeHookFunc) func(reflect.Value, reflect.Value) (any, error) {
42+
switch f := typedDecodeHook(raw).(type) {
43+
case mapstructure.DecodeHookFuncType:
44+
return func(from reflect.Value, to reflect.Value) (any, error) {
45+
// CHANGE FROM UPSTREAM: check if from is valid and return nil if not
46+
if !from.IsValid() {
47+
return nil, nil
48+
}
49+
return f(from.Type(), to.Type(), from.Interface())
50+
}
51+
case mapstructure.DecodeHookFuncKind:
52+
return func(from reflect.Value, to reflect.Value) (any, error) {
53+
// CHANGE FROM UPSTREAM: check if from is valid and return nil if not
54+
if !from.IsValid() {
55+
return nil, nil
56+
}
57+
return f(from.Kind(), to.Kind(), from.Interface())
58+
}
59+
case mapstructure.DecodeHookFuncValue:
60+
return func(from reflect.Value, to reflect.Value) (any, error) {
61+
return f(from, to)
62+
}
63+
default:
64+
return func(reflect.Value, reflect.Value) (any, error) {
65+
return nil, errors.New("invalid decode hook signature")
66+
}
67+
}
68+
}
69+
70+
// ComposeDecodeHookFunc creates a single DecodeHookFunc that
71+
// automatically composes multiple DecodeHookFuncs.
72+
//
73+
// The composed funcs are called in order, with the result of the
74+
// previous transformation.
75+
//
76+
// This is a copy of [mapstructure.ComposeDecodeHookFunc] but with
77+
// validation added.
78+
func ComposeDecodeHookFunc(fs ...mapstructure.DecodeHookFunc) mapstructure.DecodeHookFunc {
79+
cached := make([]func(reflect.Value, reflect.Value) (any, error), 0, len(fs))
80+
for _, f := range fs {
81+
cached = append(cached, cachedDecodeHook(f))
82+
}
83+
return func(f reflect.Value, t reflect.Value) (any, error) {
84+
var err error
85+
86+
// CHANGE FROM UPSTREAM: check if f is valid before calling f.Interface()
87+
var data any
88+
if f.IsValid() {
89+
data = f.Interface()
90+
}
91+
92+
newFrom := f
93+
for _, c := range cached {
94+
data, err = c(newFrom, t)
95+
if err != nil {
96+
return nil, err
97+
}
98+
newFrom = reflect.ValueOf(data)
99+
}
100+
101+
return data, nil
102+
}
103+
}

0 commit comments

Comments
 (0)