diff --git a/main.go b/main.go index c074a87..9c358fd 100644 --- a/main.go +++ b/main.go @@ -2,8 +2,11 @@ package main import ( "database/sql" + "flag" "fmt" "log" + "os" + "path/filepath" "strconv" "strings" "unicode" @@ -48,6 +51,16 @@ var commonInitialisms = map[string]bool{ "DB": true, } +var flagDbFile = flag.String("db-file", "./example.db", "path to the DB") +var flagOut = flag.String("out", "./gen", "output file for generated files") +var flagIgnoreColumns = flag.String("skip", "rowid,_rowid_,_rid,rid", "list of columns to be excluded from struct generation") +var flagGenJson = flag.Bool("json", true, "generate JSON annotation") +var flagGenDb = flag.Bool("db", true, "generate DB annotation") +var flagGenGorm = flag.Bool("gorm", true, "generate GORM annotation") +var flagPkgName = flag.String("pkg", "def", "specify package name") + +var ignoreColumns []string + var intToWordMap = []string{ "zero", "one", @@ -62,16 +75,37 @@ var intToWordMap = []string{ } func main() { - db, err := sql.Open("sqlite3", "./example.db") + flag.Parse() + + db, err := sql.Open("sqlite3", *flagDbFile) + + ignoreColumns = strings.Split(strings.ToLower(*flagIgnoreColumns), ",") + if err != nil { log.Fatal(err) } + defer db.Close() + tableNames := getTableNames(db) + outPath := filepath.Clean(*flagOut) + errOs := os.MkdirAll(outPath, 0770) + + if errOs != nil { + log.Fatal(errOs) + } + + os.Create(outPath) + c := 0 + fmt.Printf("Generating code for the following tables (%d)\n", len(tableNames)) for _, tableName := range tableNames { - file := scanTableStructure(db, tableName) + c++ + fmt.Printf("[%d] %s\n", c, tableName) + + file := scanTableStructure(db, tableName, outPath, *flagPkgName) structureName := formatFieldName(tableName) - err = file.Save("gen/" + fmt.Sprintf("%s.go", structureName)) + fileName := filepath.Join(outPath, fmt.Sprintf("%s.go", structureName)) + err = file.Save(fileName) } if err != nil { @@ -79,8 +113,8 @@ func main() { } } -func scanTableStructure(db *sql.DB, tableName string) *jen.File { - file := jen.NewFilePathName("gen", "def") +func scanTableStructure(db *sql.DB, tableName string, outPath string, packageName string) *jen.File { + file := jen.NewFilePathName(outPath, packageName) structureName := formatFieldName(tableName) file.Comment(fmt.Sprintf("// %s represent database table (%s)", structureName, tableName)) file.Type().Id(structureName).Struct( @@ -120,8 +154,7 @@ func generateTableFields(db *sql.DB, tableName string) *[]jen.Code { } defer rows.Close() - fmt.Println(getTableNames(db)) - + //fmt.Println(getTableNames(db)) for rows.Next() { var cid int var name string @@ -129,40 +162,62 @@ func generateTableFields(db *sql.DB, tableName string) *[]jen.Code { var notnull string var dfltValue sql.NullString var pk string + var ignore bool err = rows.Scan(&cid, &name, &ctype, ¬null, &dfltValue, &pk) if err != nil { log.Fatal(err) } - fmt.Println(cid, name, ctype, notnull, dfltValue, pk) - field := jen.Id(formatFieldName(name)) + + ignore = isIgnoreField(name) + + //fmt.Println(cid, name, ctype, notnull, dfltValue, pk) + name2 := formatFieldName(name) + if ignore { + name2 = `// ` + name2 + } + field := jen.Id(name2) setFieldType(field, ctype) setFieldTags(field, name) fields = append(fields, field) } + return &fields } func setFieldTags(field *jen.Statement, name string) { - field.Tag( - map[string]string{ - "json": name, - "gorm": fmt.Sprintf("column:%s", name), - }, - ) + m := map[string]string{} + + if *flagGenDb { + m["db"] = name + } + + if *flagGenGorm { + m["gorm"] = fmt.Sprintf("column:%s", name) + } + + if *flagGenJson { + m["json"] = name + } + + field.Tag(m) } func setFieldType(field *jen.Statement, ctype string) { - dbType := strings.Split(ctype, "(")[0] + dbType := strings.ToUpper(strings.Split(ctype, "(")[0]) switch dbType { case "VARCHAR", "TEXT": field.String() - case "BOOL": + case "BOOL", "BOOLEAN": field.Bool() - case "INTEGER": + case "TINYINT", "SMALLINT": field.Int32() - case "FLOAT": + case "INTEGER", "INT", "INT2", "MEDIUMINT", "BIGINT", "UNSIGNED BIG INT", "INT8": + field.Int64() + case "REAL", "DOUBLE", "DOUBLE PRECISION", "FLOAT": field.Float32() + case "NUMERIC", "DECIMAL", "DECIMAL(10,5)": + field.Float64() default: field.String() } @@ -210,6 +265,7 @@ func lintFieldName(name string) string { break } } + if allLower { runes := []rune(name) if u := strings.ToUpper(name); commonInitialisms[u] { @@ -227,6 +283,7 @@ func lintFieldName(name string) string { break } } + if allUpperWithUnderscore { name = strings.ToLower(name) } @@ -291,3 +348,13 @@ func stringifyFirstChar(str string) string { return intToWordMap[i] + "_" + str[1:] } + +func isIgnoreField(fieldName string) bool { + lowerFieldName := strings.ToLower(fieldName) + for _, v := range ignoreColumns { + if v == lowerFieldName { + return true + } + } + return false +}