Skip to content

Commit 7fd2ca1

Browse files
Support reference cycles (#393)
1 parent ed98f50 commit 7fd2ca1

File tree

2 files changed

+87
-5
lines changed

2 files changed

+87
-5
lines changed

openapi3gen/openapi3gen.go

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package openapi3gen
33

44
import (
55
"encoding/json"
6+
"fmt"
67
"math"
78
"reflect"
89
"strings"
@@ -22,6 +23,7 @@ type Option func(*generatorOpt)
2223

2324
type generatorOpt struct {
2425
useAllExportedFields bool
26+
throwErrorOnCycle bool
2527
}
2628

2729
// UseAllExportedFields changes the default behavior of only
@@ -30,6 +32,12 @@ func UseAllExportedFields() Option {
3032
return func(x *generatorOpt) { x.useAllExportedFields = true }
3133
}
3234

35+
// ThrowErrorOnCycle changes the default behavior of creating cycle
36+
// refs to instead error if a cycle is detected.
37+
func ThrowErrorOnCycle() Option {
38+
return func(x *generatorOpt) { x.throwErrorOnCycle = true }
39+
}
40+
3341
// NewSchemaRefForValue uses reflection on the given value to produce a SchemaRef.
3442
func NewSchemaRefForValue(value interface{}, opts ...Option) (*openapi3.SchemaRef, map[*openapi3.SchemaRef]int, error) {
3543
g := NewGenerator(opts...)
@@ -104,6 +112,10 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
104112
if a && b {
105113
vs, err := g.generateSchemaRefFor(parents, v.Type)
106114
if err != nil {
115+
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
116+
g.SchemaRefs[vs]++
117+
return vs, nil
118+
}
107119
return nil, err
108120
}
109121
refSchemaRef := RefSchemaRef
@@ -185,7 +197,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
185197
schema.Type = "array"
186198
items, err := g.generateSchemaRefFor(parents, t.Elem())
187199
if err != nil {
188-
return nil, err
200+
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
201+
items = g.generateCycleSchemaRef(t.Elem(), schema)
202+
} else {
203+
return nil, err
204+
}
189205
}
190206
if items != nil {
191207
g.SchemaRefs[items]++
@@ -197,7 +213,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
197213
schema.Type = "object"
198214
additionalProperties, err := g.generateSchemaRefFor(parents, t.Elem())
199215
if err != nil {
200-
return nil, err
216+
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
217+
additionalProperties = g.generateCycleSchemaRef(t.Elem(), schema)
218+
} else {
219+
return nil, err
220+
}
201221
}
202222
if additionalProperties != nil {
203223
g.SchemaRefs[additionalProperties]++
@@ -221,7 +241,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
221241
if t.Field(fieldInfo.Index[0]).Anonymous {
222242
ref, err := g.generateSchemaRefFor(parents, fType)
223243
if err != nil {
224-
return nil, err
244+
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
245+
ref = g.generateCycleSchemaRef(fType, schema)
246+
} else {
247+
return nil, err
248+
}
225249
}
226250
if ref != nil {
227251
g.SchemaRefs[ref]++
@@ -237,7 +261,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
237261

238262
ref, err := g.generateSchemaRefFor(parents, fType)
239263
if err != nil {
240-
return nil, err
264+
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
265+
ref = g.generateCycleSchemaRef(fType, schema)
266+
} else {
267+
return nil, err
268+
}
241269
}
242270
if ref != nil {
243271
g.SchemaRefs[ref]++
@@ -255,6 +283,30 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
255283
return openapi3.NewSchemaRef(t.Name(), schema), nil
256284
}
257285

286+
func (g *Generator) generateCycleSchemaRef(t reflect.Type, schema *openapi3.Schema) *openapi3.SchemaRef {
287+
var typeName string
288+
switch t.Kind() {
289+
case reflect.Ptr:
290+
return g.generateCycleSchemaRef(t.Elem(), schema)
291+
case reflect.Slice:
292+
ref := g.generateCycleSchemaRef(t.Elem(), schema)
293+
sliceSchema := openapi3.NewSchema()
294+
sliceSchema.Type = "array"
295+
sliceSchema.Items = ref
296+
return openapi3.NewSchemaRef("", sliceSchema)
297+
case reflect.Map:
298+
ref := g.generateCycleSchemaRef(t.Elem(), schema)
299+
mapSchema := openapi3.NewSchema()
300+
mapSchema.Type = "object"
301+
mapSchema.AdditionalProperties = ref
302+
return openapi3.NewSchemaRef("", mapSchema)
303+
default:
304+
typeName = t.Name()
305+
}
306+
307+
return openapi3.NewSchemaRef(fmt.Sprintf("#/components/schemas/%s", typeName), schema)
308+
}
309+
258310
var RefSchemaRef = openapi3.NewSchemaRef("Ref",
259311
openapi3.NewObjectSchema().WithProperty("$ref", openapi3.NewStringSchema().WithMinLength(1)))
260312

openapi3gen/openapi3gen_test.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type CyclicType1 struct {
1616
}
1717

1818
func TestCyclic(t *testing.T) {
19-
schemaRef, refsMap, err := NewSchemaRefForValue(&CyclicType0{})
19+
schemaRef, refsMap, err := NewSchemaRefForValue(&CyclicType0{}, ThrowErrorOnCycle())
2020
require.IsType(t, &CycleError{}, err)
2121
require.Nil(t, schemaRef)
2222
require.Empty(t, refsMap)
@@ -84,3 +84,33 @@ func TestEmbeddedStructs(t *testing.T) {
8484
_, ok = schemaRef.Value.Properties["ID"]
8585
require.Equal(t, true, ok)
8686
}
87+
88+
func TestCyclicReferences(t *testing.T) {
89+
type ObjectDiff struct {
90+
FieldCycle *ObjectDiff
91+
SliceCycle []*ObjectDiff
92+
MapCycle map[*ObjectDiff]*ObjectDiff
93+
}
94+
95+
instance := &ObjectDiff{
96+
FieldCycle: nil,
97+
SliceCycle: nil,
98+
MapCycle: nil,
99+
}
100+
101+
generator := NewGenerator(UseAllExportedFields())
102+
103+
schemaRef, err := generator.GenerateSchemaRef(reflect.TypeOf(instance))
104+
require.NoError(t, err)
105+
106+
require.NotNil(t, schemaRef.Value.Properties["FieldCycle"])
107+
require.Equal(t, "#/components/schemas/ObjectDiff", schemaRef.Value.Properties["FieldCycle"].Ref)
108+
109+
require.NotNil(t, schemaRef.Value.Properties["SliceCycle"])
110+
require.Equal(t, "array", schemaRef.Value.Properties["SliceCycle"].Value.Type)
111+
require.Equal(t, "#/components/schemas/ObjectDiff", schemaRef.Value.Properties["SliceCycle"].Value.Items.Ref)
112+
113+
require.NotNil(t, schemaRef.Value.Properties["MapCycle"])
114+
require.Equal(t, "object", schemaRef.Value.Properties["MapCycle"].Value.Type)
115+
require.Equal(t, "#/components/schemas/ObjectDiff", schemaRef.Value.Properties["MapCycle"].Value.AdditionalProperties.Ref)
116+
}

0 commit comments

Comments
 (0)