Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions generator/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
parseutil "gopkg.in/src-d/go-parse-utils.v1"
)

func mkField(name, typ string, fields ...*Field) *Field {
f := NewField(name, typ, reflect.StructTag(""))
func mkField(name, typ, tag string, fields ...*Field) *Field {
f := NewField(name, typ, reflect.StructTag(tag))
f.SetFields(fields)
return f
}
Expand Down Expand Up @@ -46,13 +46,9 @@ func withNode(f *Field, name string, typ types.Type) *Field {
return f
}

func withTag(f *Field, tag string) *Field {
f.Tag = reflect.StructTag(tag)
return f
}

func inline(f *Field) *Field {
return withTag(f, `kallax:",inline"`)
f.Tag = reflect.StructTag(`kallax:",inline"`)
return f
}

func processorFixture(source string) (*Processor, error) {
Expand Down
2 changes: 1 addition & 1 deletion generator/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,11 @@ func (p *Processor) processModel(name string, s *types.Struct, t *types.Named) (
return nil, nil
}

p.processBaseField(m, fields[base])
if err := m.SetFields(fields); err != nil {
return nil, err
}

p.processBaseField(m, fields[base])
return m, nil
}

Expand Down
29 changes: 14 additions & 15 deletions generator/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,19 +384,23 @@ func (s *ProcessorSuite) TestIsEmbedded() {
type Bar struct {
kallax.Model
ID int64 ` + "`pk:\"autoincr\"`" + `
Bar string
Baz string
}

type Struct struct {
Bar Bar
Qux Bar
}

type Struct2 struct {
Mux string
}

type Foo struct {
kallax.Model
ID int64 ` + "`pk:\"autoincr\"`" + `
A Bar
B *Bar
Bar
Struct2
*Struct
C struct {
D int
Expand All @@ -405,21 +409,16 @@ func (s *ProcessorSuite) TestIsEmbedded() {
`
pkg := s.processFixture(src)
m := findModel(pkg, "Foo")
cases := []struct {
field string
embedded bool
}{
{"Model", true},
{"A", false},
{"B", false},
{"Bar", true},
{"Struct", true},
{"C", false},
expected := []string{
"ID", "Model", "A", "B", "Mux", "Qux", "C",
}

for _, c := range cases {
s.Equal(c.embedded, findField(m, c.field).IsEmbedded, c.field)
var names []string
for _, f := range m.Fields {
names = append(names, f.Name)
}

s.Equal(expected, names)
}

func TestProcessor(t *testing.T) {
Expand Down
19 changes: 10 additions & 9 deletions generator/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,14 +512,16 @@ func addTemplate(base *template.Template, name string, filename string) *templat
return template.Must(base.New(name).Parse(text))
}

var base *template.Template = makeTemplate("base", "templates/base.tgo")
var schema *template.Template = addTemplate(base, "schema", "templates/schema.tgo")
var model *template.Template = addTemplate(base, "model", "templates/model.tgo")
var query *template.Template = addTemplate(model, "query", "templates/query.tgo")
var resultset *template.Template = addTemplate(model, "resultset", "templates/resultset.tgo")
var (
base = makeTemplate("base", "templates/base.tgo")
schema = addTemplate(base, "schema", "templates/schema.tgo")
model = addTemplate(base, "model", "templates/model.tgo")
query = addTemplate(model, "query", "templates/query.tgo")
resultset = addTemplate(model, "resultset", "templates/resultset.tgo")
)

// Base is the default Template instance with all templates preloaded.
var Base *Template = &Template{template: base}
var Base = &Template{template: base}

const (
// tplFindByCollection is the template of the FindBy autogenerated for
Expand Down Expand Up @@ -709,10 +711,9 @@ func shortName(pkg *types.Package, typ types.Type) string {

if specialName, ok := specialTypeShortName(typ); ok {
return prefix + specialName
} else {
shortName := typeString(typ, pkg)
return prefix + strings.Replace(shortName, "*", "", -1)
}
shortName := typeString(typ, pkg)
return prefix + strings.Replace(shortName, "*", "", -1)
}

// isEqualizable returns true if the autogenerated FindBy will use an equal query
Expand Down
135 changes: 122 additions & 13 deletions generator/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,15 +470,20 @@ func (m *Model) CtorRetVars() string {
return strings.Join(ret, ", ")
}

// SetFields sets all the children fields and their model to the current model.
// SetFields sets all the children fields and their model to the current
// model.
// It also finds the primary key and sets it in the model.
// It will return an error if more than one primary key is found.
// SetFields always sets the primary key as the first field of the model.
// So, all models can expect to have the primary key in the position 0 of
// their field slice. This is because the Store will expect the ID in that
// position.
func (m *Model) SetFields(fields []*Field) error {
var fs []*Field
var id *Field
for _, f := range fields {
for _, f := range flattenFields(fields) {
f.Model = m
if f.IsPrimaryKey() {
if f.IsPrimaryKey() && f.Type != BaseModel {
if id != nil {
return fmt.Errorf(
"kallax: found more than one primary key in model %s: %s and %s",
Expand All @@ -489,14 +494,56 @@ func (m *Model) SetFields(fields []*Field) error {
}

id = f
m.ID = f
} else if f.IsPrimaryKey() {
if f.primaryKey == "" {
return fmt.Errorf(
"kallax: primary key defined in %s has no field name, but it must be specified",
f.Name,
)
}

// the pk is defined in the model, we need to collect the model
// and we'll look for the field afterwards, when we have collected
// all fields. The model is appended to the field set, though,
// because it will not act as a primary key.
id = f
fs = append(fs, f)
} else {
fs = append(fs, f)
}
}

// if the id is a Model we need to look for the specified field
if id != nil && id.Type == BaseModel {
for i, f := range fs {
if f.columnName == id.primaryKey {
f.isPrimaryKey = true
f.isAutoincrement = id.isAutoincrement
id = f

if len(fs)-1 == i {
fs = append(fs[:i])
} else {
fs = append(fs[:i], fs[i+1:]...)
}
break
}
}

// If the ID is still a base model, means we did not find the pk
// field.
if id.Type == BaseModel {
return fmt.Errorf(
"kallax: the primary key was supposed to be %s according to the pk definition in %s, but the field could not be found",
id.primaryKey,
id.Name,
)
}
}

if id != nil {
m.Fields = []*Field{id}
m.ID = id
}
m.Fields = append(m.Fields, fs...)
return nil
Expand Down Expand Up @@ -556,6 +603,8 @@ func relationshipsOnFields(fields []*Field) []*Field {
return result
}

// ImplicitFK is a foreign key that is defined on just one side of the
// relationship and needs to be added on the other side.
type ImplicitFK struct {
Name string
Type string
Expand Down Expand Up @@ -590,6 +639,11 @@ type Field struct {
// A struct is considered embedded if and only if the struct was embedded
// as defined in Go.
IsEmbedded bool

primaryKey string
isPrimaryKey bool
isAutoincrement bool
columnName string
}

// FieldKind is the kind of a field.
Expand Down Expand Up @@ -645,13 +699,49 @@ func (t FieldKind) String() string {

// NewField creates a new field with its name, type and struct tag.
func NewField(n, t string, tag reflect.StructTag) *Field {
pkName, autoincr, isPrimaryKey := pkProperties(tag)

return &Field{
Name: n,
Type: t,
Tag: tag,

primaryKey: pkName,
columnName: columnName(n, tag),
isPrimaryKey: isPrimaryKey,
isAutoincrement: autoincr,
}
}

// pkProperties returns the primary key properties from a struct tag.
// Valid primary key definitions are the following:
// - pk:"" -> non-autoincr primary key without a field name.
// - pk:"autoincr" -> autoincr primary key without a field name.
// - pk:"foobar" -> non-autoincr primary key with a field name.
// - pk:"foobar,autoincr" -> autoincr primary key with a field name.
func pkProperties(tag reflect.StructTag) (name string, autoincr, isPrimaryKey bool) {
val, ok := tag.Lookup("pk")
if !ok {
return
}

isPrimaryKey = true
if val == "autoincr" || val == "" {
if val == "autoincr" {
autoincr = true
}
return
}

parts := strings.Split(val, ",")
name = parts[0]
if len(parts) > 1 && parts[1] == "autoincr" {
autoincr = true
}

return
}

// SetFields sets all the children fields and the current field as a parent of
// the children.
func (f *Field) SetFields(sf []*Field) {
Expand All @@ -667,16 +757,20 @@ func (f *Field) SetFields(sf []*Field) {
// is the field name converted to lower snake case.
// If the resultant name is a reserved keyword a _ will be prepended to the name.
func (f *Field) ColumnName() string {
name := strings.TrimSpace(strings.Split(f.Tag.Get("kallax"), ",")[0])
if name == "" {
name = toLowerSnakeCase(f.Name)
return f.columnName
}

func columnName(name string, tag reflect.StructTag) string {
n := strings.TrimSpace(strings.Split(tag.Get("kallax"), ",")[0])
if n == "" {
n = toLowerSnakeCase(name)
}

if _, ok := reservedKeywords[strings.ToLower(name)]; ok {
name = "_" + name
if _, ok := reservedKeywords[strings.ToLower(n)]; ok {
n = "_" + n
}

return name
return n
}

// ForeignKey returns the name of the foreign keys as specified in the struct
Expand All @@ -699,13 +793,12 @@ func (f *Field) ForeignKey() string {

// IsPrimaryKey reports whether the field is the primary key.
func (f *Field) IsPrimaryKey() bool {
_, ok := f.Tag.Lookup("pk")
return ok
return f.isPrimaryKey
}

// IsAutoIncrement reports whether the field is an autoincrementable primary key.
func (f *Field) IsAutoIncrement() bool {
return f.Tag.Get("pk") == "autoincr"
return f.isAutoincrement
}

// IsInverse returns whether the field is an inverse relationship.
Expand Down Expand Up @@ -1003,6 +1096,22 @@ func toLowerSnakeCase(s string) string {
return buf.String()
}

// flattenFields will recursively flatten all fields removing the embedded ones
// from the field set.
func flattenFields(fields []*Field) []*Field {
var result = make([]*Field, 0, len(fields))

for _, f := range fields {
if f.IsEmbedded && f.Type != BaseModel {
result = append(result, flattenFields(f.Fields)...)
} else {
result = append(result, f)
}
}

return result
}

// Event is the name of an event.
type Event string

Expand Down
Loading