diff --git a/README.md b/README.md index 34c4d9d..486c99b 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,10 @@ targets: {{range .ColumnsSeq}}{{$col := $table.GetColumn .}} {{ColumnMapper $col.Name}} {{Type $col}} `{{Tag $table $col}}` {{end}} } + + func (m *{{TableMapper .Name}}) TableName() string { + return "{{$table.Name}}" + } {{end}} template_path: ./template/goxorm.tmpl # template path for code file, it has higher perior than template field on language output_dir: ./models # code output directory diff --git a/cmd/reverse.go b/cmd/reverse.go index 9795f32..2313823 100644 --- a/cmd/reverse.go +++ b/cmd/reverse.go @@ -67,6 +67,7 @@ type ReverseTarget struct { OutputDir string `yaml:"output_dir"` TablePrefix string `yaml:"table_prefix"` Language string `yaml:"language"` + TableName bool `yaml:"table_name"` Funcs map[string]string `yaml:"funcs"` Formatter string `yaml:"formatter"` @@ -178,7 +179,7 @@ func runReverse(source *ReverseSource, target *ReverseTarget) error { tables = filterTables(tables, target) // load configuration from language - lang := language.GetLanguage(target.Language) + lang := language.GetLanguage(target.Language, target.TableName) funcs := newFuncs() formatter := formatters[target.Formatter] importter := importters[target.Importter] @@ -196,6 +197,7 @@ func runReverse(source *ReverseSource, target *ReverseTarget) error { if lang != nil { if bs == nil { + bs = []byte(lang.Template) } for k, v := range lang.Funcs { diff --git a/cmd/reverse_test.go b/cmd/reverse_test.go index 19fd61b..1acd73e 100644 --- a/cmd/reverse_test.go +++ b/cmd/reverse_test.go @@ -17,16 +17,26 @@ import ( "xorm.io/xorm" ) -var result = fmt.Sprintf(`package models +var ( + result = fmt.Sprintf(`package models type A struct { - Id int %sxorm:"integer"%s + Id int %sxorm:"'Id' integer"%s +} + +func (m *A) TableName() string { + return "a" } type B struct { - Id int %sxorm:"INTEGER"%s + Id int %sxorm:"'Id' INTEGER"%s +} + +func (m *B) TableName() string { + return "b" } `, "`", "`", "`", "`") +) func TestReverse(t *testing.T) { err := reverse("../example/goxorm.yml") diff --git a/example/custom.yml b/example/custom.yml index d37c3f2..a64d706 100644 --- a/example/custom.yml +++ b/example/custom.yml @@ -5,6 +5,7 @@ source: conn_str: ../testdata/test.db targets: - type: codes + language: golang include_tables: - a - b @@ -14,6 +15,7 @@ targets: column_mapper: snake table_prefix: "" multiple_files: true + table_name: true template: | package models diff --git a/example/goxorm.yml b/example/goxorm.yml index d16f424..d6ffe24 100644 --- a/example/goxorm.yml +++ b/example/goxorm.yml @@ -11,4 +11,5 @@ targets: exclude_tables: - c language: golang + table_name: true output_dir: ../models \ No newline at end of file diff --git a/example/template/goxorm.tmpl b/example/template/goxorm.tmpl index 867d653..7f36fa5 100644 --- a/example/template/goxorm.tmpl +++ b/example/template/goxorm.tmpl @@ -13,4 +13,5 @@ type {{TableMapper .Name}} struct { {{range .ColumnsSeq}}{{$col := $table.GetColumn .}} {{ColumnMapper $col.Name}} {{Type $col}} `{{Tag $table $col}}` {{end}} } + {{end}} \ No newline at end of file diff --git a/example/template/goxorm_table.tmpl b/example/template/goxorm_table.tmpl new file mode 100644 index 0000000..fc998c8 --- /dev/null +++ b/example/template/goxorm_table.tmpl @@ -0,0 +1,20 @@ +package models + +{{$ilen := len .Imports}} +{{if gt $ilen 0}} +import ( + {{range .Imports}}"{{.}}"{{end}} +) +{{end}} + +{{range .Tables}} +type {{TableMapper .Name}} struct { +{{$table := .}} +{{range .ColumnsSeq}}{{$col := $table.GetColumn .}} {{ColumnMapper $col.Name}} {{Type $col}} `{{Tag $table $col}}` +{{end}} +} + +func (m *{{TableMapper .Name}}) TableName() string { + return "{{$table.Name}}" +} +{{end}} \ No newline at end of file diff --git a/language/golang.go b/language/golang.go index 3a6c5f0..12be583 100644 --- a/language/golang.go +++ b/language/golang.go @@ -50,6 +50,23 @@ type {{TableMapper .Name}} struct { {{end}} } {{end}} +`, "`", "`") + defaultGolangTemplateTable = fmt.Sprintf(`package models + +{{$ilen := len .Imports}}{{if gt $ilen 0}}import ( + {{range .Imports}}"{{.}}"{{end}} +){{end}} + +{{range .Tables}} +type {{TableMapper .Name}} struct { +{{$table := .}}{{range .ColumnsSeq}}{{$col := $table.GetColumn .}} {{ColumnMapper $col.Name}} {{Type $col}} %s{{Tag $table $col}}%s +{{end}} +} + +func (m *{{TableMapper .Name}}) TableName() string { + return "{{$table.Name}}" +} +{{end}} `, "`", "`") ) @@ -127,6 +144,7 @@ func tag(table *schemas.Table, col *schemas.Column) template.HTML { isIdPk := isNameId && typestring(col) == "int64" var res []string + res = append(res, "'"+col.FieldName+"'") if !col.Nullable { if !isIdPk { res = append(res, "not null") diff --git a/language/language.go b/language/language.go index 414a9d6..f726bea 100644 --- a/language/language.go +++ b/language/language.go @@ -31,6 +31,12 @@ func RegisterLanguage(l *Language) { } // GetLanguage returns a language if exists -func GetLanguage(name string) *Language { - return languages[name] +func GetLanguage(name string, tableName bool) *Language { + language := languages[name] + if tableName { + language = languages[name] + language.Template = defaultGolangTemplateTable + return language + } + return language }