Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions batcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,20 +216,18 @@ func (r *batchQueryRunner) getRecordThroughRelationships(ids []interface{}, rel
return nil, fmt.Errorf("kallax: cannot find foreign keys for through relationship on field %s of table %s", rel.Field, r.schema.Table())
}

filter := In(r.schema.ID(), ids...)
q := NewBaseQuery(rel.Schema)
lschema := r.schema.WithAlias(rel.Schema.Alias())
intSchema := rel.IntermediateSchema.WithAlias(rel.Schema.Alias())
q.joinThrough(lschema, intSchema, rel.Schema, lfk, rfk)
q.where(In(lschema.ID(), ids...), lschema)
if rel.Filter != nil {
filter = And(rel.Filter, filter)
q.Where(rel.Filter)
}

if rel.IntermediateFilter != nil {
filter = And(rel.IntermediateFilter, filter)
q.where(rel.IntermediateFilter, intSchema)
}

q := NewBaseQuery(rel.Schema)
lschema := r.schema.WithAlias(rel.Schema.Alias())
intSchema := rel.IntermediateSchema.WithAlias(rel.Schema.Alias())
q.joinThrough(lschema, intSchema, rel.Schema, lfk, rfk)
q.Where(filter)
cols, builder := q.compile()
// manually add the extra column to also select the parent id
builder = builder.Column(lschema.ID().QualifiedName(lschema))
Expand Down
5 changes: 5 additions & 0 deletions generator/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ func (p *Processor) processPackage() (*Package, error) {
if err := pkg.addMissingRelationships(); err != nil {
return nil, err
}

if err := pkg.addThroughModels(); err != nil {
return nil, err
}

for _, ctor := range ctors {
p.tryMatchConstructor(pkg, ctor)
}
Expand Down
32 changes: 27 additions & 5 deletions generator/templates/query.tgo
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

// {{.QueryName}} is the object used to create queries for the {{.Name}}
// {{.QueryName}} is the object used to create queries for the {{.Name}}
// entity.
type {{.QueryName}} struct {
*kallax.BaseQuery
Expand Down Expand Up @@ -68,15 +68,37 @@ func (q *{{.QueryName}}) Where(cond kallax.Condition) *{{.QueryName}} {
}

{{range .Relationships}}
{{if not .IsOneToManyRelationship}}
func (q *{{$.QueryName}}) With{{.Name}}() *{{$.QueryName}} {
q.AddRelation(Schema.{{.TypeSchemaName}}.BaseSchema, "{{.Name}}", kallax.OneToOne, nil)
{{if .IsManyToManyRelationship}}
// With{{.Name}} retrieves all the {{.Name}} records associated with each
// record. Two conditions can be passed, the first to filter the table used
// to join {{.Name}} and {{.Model.Name}} and the second one to filter
// {{.Name}} directly.
func (q *{{$.QueryName}}) With{{.Name}}(
filter{{.ThroughSchemaName}} kallax.Condition,
filter{{.TypeSchemaName}} kallax.Condition,
) *{{$.QueryName}} {
q.AddRelationThrough(
Schema.{{.TypeSchemaName}}.BaseSchema,
Schema.{{.ThroughSchemaName}}.BaseSchema,
"{{.Name}}",
filter{{.ThroughSchemaName}},
filter{{.TypeSchemaName}},
)
return q
}
{{else}}
{{else if .IsOneToManyRelationship}}
// With{{.Name}} retrieves all the {{.Name}} records associated with each
// record. A condition can be passed to filter the associated records.
func (q *{{$.QueryName}}) With{{.Name}}(cond kallax.Condition) *{{$.QueryName}} {
q.AddRelation(Schema.{{.TypeSchemaName}}.BaseSchema, "{{.Name}}", kallax.OneToMany, cond)
return q
}
{{else}}
// With{{.Name}} retrieves the {{.Name}} record associated with each
// record.
func (q *{{$.QueryName}}) With{{.Name}}() *{{$.QueryName}} {
q.AddRelation(Schema.{{.TypeSchemaName}}.BaseSchema, "{{.Name}}", kallax.OneToOne, nil)
return q
}
{{end}}
{{end}}
44 changes: 38 additions & 6 deletions generator/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ type Package struct {
// Models are all the models found in the package.
Models []*Model
indexedModels map[string]*Model
modelsByTable map[string]*Model
}

// NewPackage creates a new package.
Expand All @@ -137,13 +138,15 @@ func NewPackage(pkg *types.Package) *Package {
Name: pkg.Name(),
pkg: pkg,
indexedModels: make(map[string]*Model),
modelsByTable: make(map[string]*Model),
}
}

// SetModels sets the models of the packages and indexes them.
func (p *Package) SetModels(models []*Model) {
for _, m := range models {
p.indexedModels[m.Name] = m
p.modelsByTable[m.Table] = m
}
p.Models = models
}
Expand All @@ -153,20 +156,41 @@ func (p *Package) FindModel(name string) *Model {
return p.indexedModels[name]
}

func (p *Package) addMissingRelationships() error {
func (p *Package) forEachModelField(fn func(m *Model, f *Field) error) error {
for _, m := range p.Models {
for _, f := range m.Fields {
if f.Kind == Relationship && !f.IsInverse() {
if err := p.trySetFK(f.TypeSchemaName(), f); err != nil {
return err
}
if err := fn(m, f); err != nil {
return err
}
}
}

return nil
}

func (p *Package) addThroughModels() error {
return p.forEachModelField(func(m *Model, f *Field) error {
if f.IsManyToManyRelationship() {
model, ok := p.modelsByTable[f.ThroughTable()]
if !ok {
return fmt.Errorf("kallax: cannot find a model with table name %s to access field %s of model %s", f.ThroughTable(), f.Name, m.Name)
}
f.ThroughModel = model
}
return nil
})
}

func (p *Package) addMissingRelationships() error {
return p.forEachModelField(func(m *Model, f *Field) error {
if f.Kind == Relationship && !f.IsInverse() && !f.IsManyToManyRelationship() {
if err := p.trySetFK(f.TypeSchemaName(), f); err != nil {
return err
}
}
return nil
})
}

func (p *Package) trySetFK(model string, fk *Field) error {
m := p.FindModel(model)
if m == nil {
Expand Down Expand Up @@ -631,6 +655,9 @@ type Field struct {
Parent *Field
// Model is the reference to the model containing this field.
Model *Model
// ThroughModel is the reference to the model through which the field is
// accessed in a relationship.
ThroughModel *Model
// IsPtr reports whether the field is a pointer type or not.
IsPtr bool
// IsJSON reports whether the field has to be converted to JSON.
Expand Down Expand Up @@ -1018,6 +1045,11 @@ func (f *Field) TypeSchemaName() string {
return parts[len(parts)-1]
}

// ThroughSchemaName returns the name of the Schema for the through model type.
func (f *Field) ThroughSchemaName() string {
return f.ThroughModel.Name
}

func (f *Field) SQLType() string {
return f.Tag.Get("sqltype")
}
Expand Down
6 changes: 5 additions & 1 deletion query.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,11 @@ func (q *BaseQuery) GetOffset() uint64 {
// q.Where(Gt(AgeColumn, 18))
// // ... WHERE name = "foo" AND age > 18
func (q *BaseQuery) Where(cond Condition) {
q.builder = q.builder.Where(cond(q.schema))
q.where(cond, q.schema)
}

func (q *BaseQuery) where(cond Condition, schema Schema) {
q.builder = q.builder.Where(cond(schema))
}

// compile returns the selected column names and the select builder.
Expand Down
Loading