From 33ccec6d686a4ca7c377e8cdfacae770faff84f5 Mon Sep 17 00:00:00 2001 From: John Guo Date: Sat, 14 Mar 2026 15:09:16 +0800 Subject: [PATCH] add sql file mode support to gendao command --- .gitignore | 3 +- Makefile | 22 + cmd/gf/internal/cmd/gendao/gendao.go | 276 ++++- cmd/gf/internal/cmd/gendao/gendao_clear.go | 8 + cmd/gf/internal/cmd/gendao/gendao_dao.go | 49 +- cmd/gf/internal/cmd/gendao/gendao_do.go | 9 +- cmd/gf/internal/cmd/gendao/gendao_entity.go | 8 +- cmd/gf/internal/cmd/gendao/gendao_gen_item.go | 23 +- .../internal/cmd/gendao/gendao_sql_parser.go | 1030 +++++++++++++++++ .../cmd/gendao/gendao_sql_parser_mssql.go | 211 ++++ .../gendao/gendao_sql_parser_mssql_test.go | 72 ++ .../cmd/gendao/gendao_sql_parser_mysql.go | 199 ++++ .../gendao/gendao_sql_parser_mysql_test.go | 300 +++++ .../cmd/gendao/gendao_sql_parser_oracle.go | 209 ++++ .../gendao/gendao_sql_parser_oracle_test.go | 97 ++ .../cmd/gendao/gendao_sql_parser_pgsql.go | 268 +++++ .../gendao/gendao_sql_parser_pgsql_test.go | 232 ++++ .../cmd/gendao/gendao_sql_parser_sqlite.go | 159 +++ .../gendao/gendao_sql_parser_sqlite_test.go | 112 ++ .../cmd/gendao/gendao_sql_parser_test.go | 302 +++++ .../internal/cmd/gendao/gendao_structure.go | 37 +- cmd/gf/internal/cmd/gendao/gendao_table.go | 2 +- cmd/gf/internal/cmd/gendao/gendao_tag.go | 36 +- database/gdb/gdb.go | 18 + database/gdb/gdb_core_structure.go | 48 +- 25 files changed, 3626 insertions(+), 104 deletions(-) create mode 100644 cmd/gf/internal/cmd/gendao/gendao_sql_parser.go create mode 100644 cmd/gf/internal/cmd/gendao/gendao_sql_parser_mssql.go create mode 100644 cmd/gf/internal/cmd/gendao/gendao_sql_parser_mssql_test.go create mode 100644 cmd/gf/internal/cmd/gendao/gendao_sql_parser_mysql.go create mode 100644 cmd/gf/internal/cmd/gendao/gendao_sql_parser_mysql_test.go create mode 100644 cmd/gf/internal/cmd/gendao/gendao_sql_parser_oracle.go create mode 100644 cmd/gf/internal/cmd/gendao/gendao_sql_parser_oracle_test.go create mode 100644 cmd/gf/internal/cmd/gendao/gendao_sql_parser_pgsql.go create mode 100644 cmd/gf/internal/cmd/gendao/gendao_sql_parser_pgsql_test.go create mode 100644 cmd/gf/internal/cmd/gendao/gendao_sql_parser_sqlite.go create mode 100644 cmd/gf/internal/cmd/gendao/gendao_sql_parser_sqlite_test.go create mode 100644 cmd/gf/internal/cmd/gendao/gendao_sql_parser_test.go diff --git a/.gitignore b/.gitignore index be381afce..e434e969d 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,5 @@ node_modules output .example/ .golangci.bck.yml -*.exe \ No newline at end of file +*.exe +.aiprompt.zh.md \ No newline at end of file diff --git a/Makefile b/Makefile index 2866f3a97..d9666d2e9 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,27 @@ SHELL := /bin/bash +# commit changes with AI-generated commit message +.PHONY: up +up: + @if git diff --quiet HEAD && git diff --cached --quiet && [ -z "$$(git ls-files --others --exclude-standard)" ]; then \ + echo "No changes to commit"; \ + exit 0; \ + fi + @git add -A + @echo "Analyzing changes and generating commit message via AI..." + @set -e; \ + MSG=$$(git diff --cached --stat && echo "---" && git diff --cached | head -2000 | \ + claude -p "Analyze the git diff above and generate a concise commit message (single line, max 72 chars, lowercase, no quotes). Output only the commit message itself, nothing else." \ + --model haiku) || { echo "Error: Claude command failed"; exit 1; }; \ + COMMIT_MSG=$$(echo "$$MSG" | tail -1); \ + if [ -z "$$COMMIT_MSG" ]; then \ + echo "Error: Failed to generate commit message"; \ + exit 1; \ + fi; \ + echo "Commit: $$COMMIT_MSG"; \ + git commit -m "$$COMMIT_MSG" && \ + git push origin + # execute "go mod tidy" on all folders that have go.mod file .PHONY: tidy tidy: diff --git a/cmd/gf/internal/cmd/gendao/gendao.go b/cmd/gf/internal/cmd/gendao/gendao.go index 056d91dcc..c13b3e7db 100644 --- a/cmd/gf/internal/cmd/gendao/gendao.go +++ b/cmd/gf/internal/cmd/gendao/gendao.go @@ -33,65 +33,88 @@ import ( ) type ( - CGenDao struct{} + // CGenDao is the command handler struct for "gen dao" command. + CGenDao struct{} + + // CGenDaoInput defines all input parameters for the "gen dao" command. + // It supports both command-line arguments and configuration file options. CGenDaoInput struct { g.Meta `name:"dao" config:"{CGenDaoConfig}" usage:"{CGenDaoUsage}" brief:"{CGenDaoBrief}" eg:"{CGenDaoEg}" ad:"{CGenDaoAd}"` - Path string `name:"path" short:"p" brief:"{CGenDaoBriefPath}" d:"internal"` - Link string `name:"link" short:"l" brief:"{CGenDaoBriefLink}"` - Tables string `name:"tables" short:"t" brief:"{CGenDaoBriefTables}"` - TablesEx string `name:"tablesEx" short:"x" brief:"{CGenDaoBriefTablesEx}"` - ShardingPattern []string `name:"shardingPattern" short:"sp" brief:"{CGenDaoBriefShardingPattern}"` - Group string `name:"group" short:"g" brief:"{CGenDaoBriefGroup}" d:"default"` - Prefix string `name:"prefix" short:"f" brief:"{CGenDaoBriefPrefix}"` - RemovePrefix string `name:"removePrefix" short:"r" brief:"{CGenDaoBriefRemovePrefix}"` - RemoveFieldPrefix string `name:"removeFieldPrefix" short:"rf" brief:"{CGenDaoBriefRemoveFieldPrefix}"` - JsonCase string `name:"jsonCase" short:"j" brief:"{CGenDaoBriefJsonCase}" d:"CamelLower"` - ImportPrefix string `name:"importPrefix" short:"i" brief:"{CGenDaoBriefImportPrefix}"` - DaoPath string `name:"daoPath" short:"d" brief:"{CGenDaoBriefDaoPath}" d:"dao"` - TablePath string `name:"tablePath" short:"tp" brief:"{CGenDaoBriefTablePath}" d:"table"` - DoPath string `name:"doPath" short:"o" brief:"{CGenDaoBriefDoPath}" d:"model/do"` - EntityPath string `name:"entityPath" short:"e" brief:"{CGenDaoBriefEntityPath}" d:"model/entity"` - TplDaoTablePath string `name:"tplDaoTablePath" short:"t0" brief:"{CGenDaoBriefTplDaoTablePath}"` - TplDaoIndexPath string `name:"tplDaoIndexPath" short:"t1" brief:"{CGenDaoBriefTplDaoIndexPath}"` - TplDaoInternalPath string `name:"tplDaoInternalPath" short:"t2" brief:"{CGenDaoBriefTplDaoInternalPath}"` - TplDaoDoPath string `name:"tplDaoDoPath" short:"t3" brief:"{CGenDaoBriefTplDaoDoPathPath}"` - TplDaoEntityPath string `name:"tplDaoEntityPath" short:"t4" brief:"{CGenDaoBriefTplDaoEntityPath}"` - StdTime bool `name:"stdTime" short:"s" brief:"{CGenDaoBriefStdTime}" orphan:"true"` - WithTime bool `name:"withTime" short:"w" brief:"{CGenDaoBriefWithTime}" orphan:"true"` - GJsonSupport bool `name:"gJsonSupport" short:"n" brief:"{CGenDaoBriefGJsonSupport}" orphan:"true"` - OverwriteDao bool `name:"overwriteDao" short:"v" brief:"{CGenDaoBriefOverwriteDao}" orphan:"true"` - DescriptionTag bool `name:"descriptionTag" short:"c" brief:"{CGenDaoBriefDescriptionTag}" orphan:"true"` - NoJsonTag bool `name:"noJsonTag" short:"k" brief:"{CGenDaoBriefNoJsonTag}" orphan:"true"` - NoModelComment bool `name:"noModelComment" short:"m" brief:"{CGenDaoBriefNoModelComment}" orphan:"true"` - Clear bool `name:"clear" short:"a" brief:"{CGenDaoBriefClear}" orphan:"true"` - GenTable bool `name:"genTable" short:"gt" brief:"{CGenDaoBriefGenTable}" orphan:"true"` + Path string `name:"path" short:"p" brief:"{CGenDaoBriefPath}" d:"internal"` // Base directory path for generated files. + Link string `name:"link" short:"l" brief:"{CGenDaoBriefLink}"` // Database connection string (e.g., "mysql:root:pass@tcp(127.0.0.1:3306)/db"). + Tables string `name:"tables" short:"t" brief:"{CGenDaoBriefTables}"` // Comma-separated table names or wildcard patterns to include. + TablesEx string `name:"tablesEx" short:"x" brief:"{CGenDaoBriefTablesEx}"` // Comma-separated table names or wildcard patterns to exclude. + ShardingPattern []string `name:"shardingPattern" short:"sp" brief:"{CGenDaoBriefShardingPattern}"` // Patterns for sharding tables (e.g., "users_?" merges users_001, users_002 into one dao). + Group string `name:"group" short:"g" brief:"{CGenDaoBriefGroup}" d:"default"` // Database configuration group name for ORM instance. + Prefix string `name:"prefix" short:"f" brief:"{CGenDaoBriefPrefix}"` // Prefix to add to all generated table names. + RemovePrefix string `name:"removePrefix" short:"r" brief:"{CGenDaoBriefRemovePrefix}"` // Comma-separated prefixes to remove from table names. + RemoveFieldPrefix string `name:"removeFieldPrefix" short:"rf" brief:"{CGenDaoBriefRemoveFieldPrefix}"` // Comma-separated prefixes to remove from field names. + JsonCase string `name:"jsonCase" short:"j" brief:"{CGenDaoBriefJsonCase}" d:"CamelLower"` // Naming convention for JSON tags (e.g., CamelLower, Snake). + ImportPrefix string `name:"importPrefix" short:"i" brief:"{CGenDaoBriefImportPrefix}"` // Custom Go import path prefix for generated files. + DaoPath string `name:"daoPath" short:"d" brief:"{CGenDaoBriefDaoPath}" d:"dao"` // Sub-directory under Path for dao files. + TablePath string `name:"tablePath" short:"tp" brief:"{CGenDaoBriefTablePath}" d:"table"` // Sub-directory under Path for table field definition files. + DoPath string `name:"doPath" short:"o" brief:"{CGenDaoBriefDoPath}" d:"model/do"` // Sub-directory under Path for DO (Data Object) files. + EntityPath string `name:"entityPath" short:"e" brief:"{CGenDaoBriefEntityPath}" d:"model/entity"` // Sub-directory under Path for entity struct files. + TplDaoTablePath string `name:"tplDaoTablePath" short:"t0" brief:"{CGenDaoBriefTplDaoTablePath}"` // Custom template file for dao table generation. + TplDaoIndexPath string `name:"tplDaoIndexPath" short:"t1" brief:"{CGenDaoBriefTplDaoIndexPath}"` // Custom template file for dao index generation. + TplDaoInternalPath string `name:"tplDaoInternalPath" short:"t2" brief:"{CGenDaoBriefTplDaoInternalPath}"` // Custom template file for dao internal generation. + TplDaoDoPath string `name:"tplDaoDoPath" short:"t3" brief:"{CGenDaoBriefTplDaoDoPathPath}"` // Custom template file for DO generation. + TplDaoEntityPath string `name:"tplDaoEntityPath" short:"t4" brief:"{CGenDaoBriefTplDaoEntityPath}"` // Custom template file for entity generation. + StdTime bool `name:"stdTime" short:"s" brief:"{CGenDaoBriefStdTime}" orphan:"true"` // Use stdlib time.Time instead of gtime.Time for time fields. + WithTime bool `name:"withTime" short:"w" brief:"{CGenDaoBriefWithTime}" orphan:"true"` // Add creation timestamp to generated file headers. + GJsonSupport bool `name:"gJsonSupport" short:"n" brief:"{CGenDaoBriefGJsonSupport}" orphan:"true"` // Use *gjson.Json instead of string for JSON fields. + OverwriteDao bool `name:"overwriteDao" short:"v" brief:"{CGenDaoBriefOverwriteDao}" orphan:"true"` // Overwrite existing dao files (both index and internal). + DescriptionTag bool `name:"descriptionTag" short:"c" brief:"{CGenDaoBriefDescriptionTag}" orphan:"true"`// Add description struct tag with field comment. + NoJsonTag bool `name:"noJsonTag" short:"k" brief:"{CGenDaoBriefNoJsonTag}" orphan:"true"` // Omit json struct tags from generated structs. + NoModelComment bool `name:"noModelComment" short:"m" brief:"{CGenDaoBriefNoModelComment}" orphan:"true"`// Omit inline comments from generated struct fields. + Clear bool `name:"clear" short:"a" brief:"{CGenDaoBriefClear}" orphan:"true"` // Delete generated files that no longer correspond to database tables. + GenTable bool `name:"genTable" short:"gt" brief:"{CGenDaoBriefGenTable}" orphan:"true"` // Enable generation of table field definition files. + SqlDir string `name:"sqlDir" short:"sd" brief:"{CGenDaoBriefSqlDir}"` // Directory of SQL DDL files for offline generation (no DB connection needed). + SqlType string `name:"sqlType" short:"st" brief:"{CGenDaoBriefSqlType}" d:"mysql"` // SQL dialect when using SqlDir (mysql, pgsql, mssql, oracle, sqlite). - TypeMapping map[DBFieldTypeName]CustomAttributeType `name:"typeMapping" short:"y" brief:"{CGenDaoBriefTypeMapping}" orphan:"true"` + // TypeMapping maps database field type names to custom Go types. + // For example, mapping "decimal" to "float64" or "uuid" to "uuid.UUID". + TypeMapping map[DBFieldTypeName]CustomAttributeType `name:"typeMapping" short:"y" brief:"{CGenDaoBriefTypeMapping}" orphan:"true"` + // FieldMapping maps specific table.field combinations to custom Go types. + // For example, mapping "user.balance" to "decimal.Decimal". FieldMapping map[DBTableFieldName]CustomAttributeType `name:"fieldMapping" short:"fm" brief:"{CGenDaoBriefFieldMapping}" orphan:"true"` - // internal usage purpose. + // genItems tracks all generated file paths and directories for cleanup purposes. genItems *CGenDaoInternalGenItems } + + // CGenDaoOutput is the output of the "gen dao" command (currently empty). CGenDaoOutput struct{} + // CGenDaoInternalInput extends CGenDaoInput with runtime-resolved fields + // used during the actual generation process. CGenDaoInternalInput struct { CGenDaoInput - DB gdb.DB - TableNames []string - NewTableNames []string - ShardingTableSet *gset.StrSet + DB gdb.DB // Database connection instance (nil in SQL file mode). + TableNames []string // Original table names from database or SQL files. + NewTableNames []string // Processed table names after prefix removal and sharding. + ShardingTableSet *gset.StrSet // Set of table names identified as sharding tables. + // TableFieldsMap stores pre-parsed table fields from SQL files. + // When this is set (SQL file mode), DB may be nil. + TableFieldsMap map[string]map[string]*gdb.TableField } - DBTableFieldName = string - DBFieldTypeName = string + + // DBTableFieldName is the fully-qualified field name in "table.field" format. + DBTableFieldName = string + // DBFieldTypeName is the database column type name (e.g., "varchar", "decimal"). + DBFieldTypeName = string + // CustomAttributeType defines a custom Go type mapping with its import path. CustomAttributeType struct { - Type string `brief:"custom attribute type name"` - Import string `brief:"custom import for this type"` + Type string `brief:"custom attribute type name"` // Go type name (e.g., "decimal.Decimal"). + Import string `brief:"custom import for this type"` // Go import path (e.g., "github.com/shopspring/decimal"). } ) var ( - createdAt = gtime.Now() - tplView = gview.New() + createdAt = gtime.Now() // Timestamp captured at program start, used in generated file headers. + tplView = gview.New() // Shared template view instance for rendering all Go file templates. + // defaultTypeMapping provides built-in type mappings from database types to Go types. + // User-provided TypeMapping takes precedence over these defaults. defaultTypeMapping = map[DBFieldTypeName]CustomAttributeType{ "decimal": { Type: "float64", @@ -111,7 +134,8 @@ var ( }, } - // tablewriter Options + // twRenderer configures the tablewriter to render without borders or separators, + // producing clean aligned text output for generated Go source code. twRenderer = tablewriter.WithRenderer(renderer.NewBlueprint(tw.Rendition{ Borders: tw.Border{Top: tw.Off, Bottom: tw.Off, Left: tw.Off, Right: tw.Off}, Settings: tw.Settings{ @@ -126,9 +150,17 @@ var ( }) ) +// Dao is the main entry point for the "gen dao" command. +// It dispatches to the appropriate generation mode based on input: +// - SQL file mode (SqlDir is set): generates from DDL files without database connection. +// - Link mode (Link is set): uses a direct database connection string. +// - Config mode: reads database configuration from the application config file. func (c CGenDao) Dao(ctx context.Context, in CGenDaoInput) (out *CGenDaoOutput, err error) { in.genItems = newCGenDaoInternalGenItems() - if in.Link != "" { + if in.SqlDir != "" { + // SQL file mode: generate from SQL DDL files without database connection. + doGenDaoFromSQLFiles(ctx, in) + } else if in.Link != "" { doGenDaoForArray(ctx, -1, in) } else if g.Cfg().Available(ctx) { v := g.Cfg().MustGet(ctx, CGenDaoConfig) @@ -147,7 +179,11 @@ func (c CGenDao) Dao(ctx context.Context, in CGenDaoInput) (out *CGenDaoOutput, return } -// doGenDaoForArray implements the "gen dao" command for configuration array. +// doGenDaoForArray implements the "gen dao" command for a single configuration entry. +// When index >= 0, it reads configuration from the array at that index. +// When index < 0, it uses the input as-is (for Link mode or single config mode). +// It performs the full generation pipeline: connect to DB, resolve tables, +// apply sharding patterns, and generate dao/table/do/entity files. func doGenDaoForArray(ctx context.Context, index int, in CGenDaoInput) { var ( err error @@ -332,6 +368,10 @@ func doGenDaoForArray(ctx context.Context, index int, in CGenDaoInput) { in.genItems.SetClear(in.Clear) } +// getImportPartContent analyzes the generated Go source code and builds the import block. +// It automatically detects usage of gtime.Time, time.Time, and gjson.Json in the source, +// and includes the corresponding import paths. Additional custom imports (from TypeMapping +// or FieldMapping) are appended and their dependencies are resolved via "go get" if needed. func getImportPartContent(ctx context.Context, source string, isDo bool, appendImports []string) string { var packageImportsArray = garray.NewStrArray() if isDo { @@ -385,6 +425,9 @@ func getImportPartContent(ctx context.Context, source string, isDo bool, appendI return packageImportsStr } +// assignDefaultVar sets the default template variables for datetime strings +// used in generated file headers. The creation timestamp is only included +// when WithTime is enabled in the input configuration. func assignDefaultVar(view *gview.View, in CGenDaoInternalInput) { var ( tplCreatedAtDatetimeStr string @@ -399,6 +442,8 @@ func assignDefaultVar(view *gview.View, in CGenDaoInternalInput) { }) } +// sortFieldKeyForDao returns field names sorted by their Index in the TableField map. +// This preserves the original column order as defined in the database table schema. func sortFieldKeyForDao(fieldMap map[string]*gdb.TableField) []string { names := make(map[int]string) for _, field := range fieldMap { @@ -423,6 +468,20 @@ func sortFieldKeyForDao(fieldMap map[string]*gdb.TableField) []string { return result } +// getTableFields retrieves table fields either from the pre-parsed TableFieldsMap (SQL file mode) +// or from the database connection. This abstracts the data source for generation functions. +func getTableFields(ctx context.Context, in CGenDaoInternalInput, tableName string) (map[string]*gdb.TableField, error) { + if in.TableFieldsMap != nil { + if fields, ok := in.TableFieldsMap[tableName]; ok { + return fields, nil + } + return nil, fmt.Errorf("table '%s' not found in SQL files", tableName) + } + return in.DB.TableFields(ctx, tableName) +} + +// getTemplateFromPathOrDefault returns the template content from the given file path. +// If the file path is empty or the file has no content, it falls back to the default template. func getTemplateFromPathOrDefault(filePath string, def string) string { if filePath != "" { if contents := gfile.GetContents(filePath); contents != "" { @@ -489,3 +548,130 @@ func filterTablesByPatterns(allTables []string, patterns []string) []string { } return result } + +// doGenDaoFromSQLFiles implements the "gen dao" command for SQL file mode. +// It parses DDL SQL files to obtain table structures without requiring a database connection. +func doGenDaoFromSQLFiles(ctx context.Context, in CGenDaoInput) { + if dirRealPath := gfile.RealPath(in.Path); dirRealPath == "" { + mlog.Fatalf(`path "%s" does not exist`, in.Path) + } + if dirRealPath := gfile.RealPath(in.SqlDir); dirRealPath == "" { + mlog.Fatalf(`SQL directory "%s" does not exist`, in.SqlDir) + } + + dialect := SQLDialect(strings.ToLower(in.SqlType)) + tableNames, tableFieldsMap := ParseSQLFilesFromDir(in.SqlDir, dialect) + + removePrefixArray := gstr.SplitAndTrim(in.RemovePrefix, ",") + + // Table filtering by name patterns. + if in.Tables != "" { + inputTables := gstr.SplitAndTrim(in.Tables, ",") + var hasPattern bool + for _, t := range inputTables { + if containsWildcard(t) { + hasPattern = true + break + } + } + if hasPattern { + tableNames = filterTablesByPatterns(tableNames, inputTables) + } else { + tableNames = inputTables + } + } + + // Table excluding. + if in.TablesEx != "" { + array := garray.NewStrArrayFrom(tableNames) + for _, p := range gstr.SplitAndTrim(in.TablesEx, ",") { + if containsWildcard(p) { + regPattern := "^" + patternToRegex(p) + "$" + for _, v := range array.Clone().Slice() { + if gregex.IsMatchString(regPattern, v) { + array.RemoveValue(v) + } + } + } else { + array.RemoveValue(p) + } + } + tableNames = array.Slice() + } + + // merge default typeMapping. + if in.TypeMapping == nil { + in.TypeMapping = defaultTypeMapping + } else { + for key, typeMapping := range defaultTypeMapping { + if _, ok := in.TypeMapping[key]; !ok { + in.TypeMapping[key] = typeMapping + } + } + } + + // Process table names (prefix removal, sharding, etc.) + var ( + newTableNames = make([]string, len(tableNames)) + shardingNewTableSet = gset.NewStrSet() + ) + sortedShardingPatterns := make([]string, len(in.ShardingPattern)) + copy(sortedShardingPatterns, in.ShardingPattern) + sort.Slice(sortedShardingPatterns, func(i, j int) bool { + return len(sortedShardingPatterns[i]) > len(sortedShardingPatterns[j]) + }) + for i, tableName := range tableNames { + newTableName := tableName + for _, v := range removePrefixArray { + newTableName = gstr.TrimLeftStr(newTableName, v, 1) + } + if len(sortedShardingPatterns) > 0 { + for _, pattern := range sortedShardingPatterns { + var ( + match []string + regPattern = gstr.Replace(pattern, "?", `(.+)`) + err error + ) + match, err = gregex.MatchString(regPattern, newTableName) + if err != nil { + mlog.Fatalf(`invalid sharding pattern "%s": %+v`, pattern, err) + } + if len(match) < 2 { + continue + } + newTableName = gstr.Replace(pattern, "?", "") + newTableName = gstr.Trim(newTableName, `_.-`) + if shardingNewTableSet.Contains(newTableName) { + tableNames[i] = "" + break + } + shardingNewTableSet.Add(in.Prefix + newTableName) + break + } + } + newTableName = in.Prefix + newTableName + if tableNames[i] != "" { + newTableNames[i] = newTableName + } + } + tableNames = garray.NewStrArrayFrom(tableNames).FilterEmpty().Slice() + newTableNames = garray.NewStrArrayFrom(newTableNames).FilterEmpty().Slice() + in.genItems.Scale() + + internalInput := CGenDaoInternalInput{ + CGenDaoInput: in, + DB: nil, + TableNames: tableNames, + NewTableNames: newTableNames, + ShardingTableSet: shardingNewTableSet, + TableFieldsMap: tableFieldsMap, + } + + // Generate all files using the same flow as database mode. + generateDao(ctx, internalInput) + generateTable(ctx, internalInput) + generateDo(ctx, internalInput) + generateEntity(ctx, internalInput) + + in.genItems.SetClear(in.Clear) +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_clear.go b/cmd/gf/internal/cmd/gendao/gendao_clear.go index 181804641..eb4dbbb90 100644 --- a/cmd/gf/internal/cmd/gendao/gendao_clear.go +++ b/cmd/gf/internal/cmd/gendao/gendao_clear.go @@ -13,6 +13,10 @@ import ( "github.com/gogf/gf/cmd/gf/v2/internal/utility/mlog" ) +// doClear performs cleanup of stale generated files across all generation items. +// It collects all generated file paths from all items, then for each item with +// Clear enabled, removes any .go files in its directories that are NOT in the +// generated file list. This ensures files for dropped/removed tables are cleaned up. func doClear(items *CGenDaoInternalGenItems) { var allGeneratedFilePaths = make([]string, 0) for _, item := range items.Items { @@ -29,6 +33,10 @@ func doClear(items *CGenDaoInternalGenItems) { } } +// doClearItem removes stale .go files for a single generation item. +// It scans all storage directories for .go files and deletes any file +// that is not in the allGeneratedFilePaths list (i.e., no longer corresponds +// to an existing database table). func doClearItem(item CGenDaoInternalGenItem, allGeneratedFilePaths []string) { var generatedFilePaths = make([]string, 0) for _, dirPath := range item.StorageDirPaths { diff --git a/cmd/gf/internal/cmd/gendao/gendao_dao.go b/cmd/gf/internal/cmd/gendao/gendao_dao.go index fe32b3653..0c8b01c22 100644 --- a/cmd/gf/internal/cmd/gendao/gendao_dao.go +++ b/cmd/gf/internal/cmd/gendao/gendao_dao.go @@ -26,6 +26,9 @@ import ( "github.com/gogf/gf/cmd/gf/v2/internal/utility/utils" ) +// generateDao generates dao files (index + internal) for all tables in the input. +// It creates the dao directory structure and iterates over each table to generate +// individual dao files via generateDaoSingle. func generateDao(ctx context.Context, in CGenDaoInternalInput) { var ( dirPathDao = gfile.Join(in.Path, in.DaoPath) @@ -48,21 +51,20 @@ func generateDao(ctx context.Context, in CGenDaoInternalInput) { } } +// generateDaoSingleInput holds all parameters needed to generate dao files for a single table. type generateDaoSingleInput struct { CGenDaoInternalInput - // TableName specifies the table name of the table. - TableName string - // NewTableName specifies the prefix-stripped or custom edited name of the table. - NewTableName string - DirPathDao string - DirPathDaoInternal string - IsSharding bool + TableName string // Original table name as it exists in the database. + NewTableName string // Processed table name after prefix removal and sharding. + DirPathDao string // Directory path for the dao index files. + DirPathDaoInternal string // Directory path for the dao internal implementation files. + IsSharding bool // Whether this table is a sharding table (merged from multiple physical tables). } // generateDaoSingle generates the dao and model content of given table. func generateDaoSingle(ctx context.Context, in generateDaoSingleInput) { // Generating table data preparing. - fieldMap, err := in.DB.TableFields(ctx, in.TableName) + fieldMap, err := getTableFields(ctx, in.CGenDaoInternalInput, in.TableName) if err != nil { mlog.Fatalf(`fetching tables fields failed for table "%s": %+v`, in.TableName, err) } @@ -105,14 +107,21 @@ func generateDaoSingle(ctx context.Context, in generateDaoSingleInput) { }) } +// generateDaoIndexInput holds parameters for generating the dao index file. +// The index file provides the public API (exported struct and constructor) +// for accessing the DAO, delegating to the internal implementation. type generateDaoIndexInput struct { generateDaoSingleInput - TableNameCamelCase string - TableNameCamelLowerCase string - ImportPrefix string - FileName string + TableNameCamelCase string // CamelCase version of the table name (e.g., "UserDetail"). + TableNameCamelLowerCase string // camelCase version of the table name (e.g., "userDetail"). + ImportPrefix string // Go import path prefix for the dao package. + FileName string // Output file name (without extension). } +// generateDaoIndex generates the dao index file for a single table. +// The index file is the public-facing dao file that users import directly. +// It will NOT overwrite an existing file unless OverwriteDao is enabled, +// allowing users to customize the index file without losing changes. func generateDaoIndex(in generateDaoIndexInput) { path := filepath.FromSlash(gfile.Join(in.DirPathDao, in.FileName+".go")) // It should add path to result slice whenever it would generate the path file or not. @@ -147,15 +156,21 @@ func generateDaoIndex(in generateDaoIndexInput) { } } +// generateDaoInternalInput holds parameters for generating the dao internal file. +// The internal file contains the actual DAO implementation with column definitions +// and is always overwritten on regeneration. type generateDaoInternalInput struct { generateDaoSingleInput - TableNameCamelCase string - TableNameCamelLowerCase string - ImportPrefix string - FileName string - FieldMap map[string]*gdb.TableField + TableNameCamelCase string // CamelCase version of the table name. + TableNameCamelLowerCase string // camelCase version of the table name. + ImportPrefix string // Go import path prefix for the dao package. + FileName string // Output file name (without extension). + FieldMap map[string]*gdb.TableField // Map of column name to field metadata. } +// generateDaoInternal generates the dao internal implementation file for a single table. +// This file is always regenerated (overwritten) and contains the Columns struct definition +// with column name constants and their string value assignments. func generateDaoInternal(in generateDaoInternalInput) { var ( ctx = context.Background() diff --git a/cmd/gf/internal/cmd/gendao/gendao_do.go b/cmd/gf/internal/cmd/gendao/gendao_do.go index 83683a58d..42d4a720b 100644 --- a/cmd/gf/internal/cmd/gendao/gendao_do.go +++ b/cmd/gf/internal/cmd/gendao/gendao_do.go @@ -22,6 +22,10 @@ import ( "github.com/gogf/gf/cmd/gf/v2/internal/utility/utils" ) +// generateDo generates DO (Data Object) files for all tables. +// DO structs use "any" type for all scalar fields (replacing concrete types), +// enabling flexible query building with the g.Meta `orm:"do:true"` tag. +// Pointer, slice, and map types are preserved as-is. func generateDo(ctx context.Context, in CGenDaoInternalInput) { var dirPathDo = filepath.FromSlash(gfile.Join(in.Path, in.DoPath)) in.genItems.AppendDirPath(dirPathDo) @@ -30,7 +34,7 @@ func generateDo(ctx context.Context, in CGenDaoInternalInput) { in.NoModelComment = false // Model content. for i, tableName := range in.TableNames { - fieldMap, err := in.DB.TableFields(ctx, tableName) + fieldMap, err := getTableFields(ctx, in, tableName) if err != nil { mlog.Fatalf("fetching tables fields failed for table '%s':\n%v", tableName, err) } @@ -75,6 +79,9 @@ func generateDo(ctx context.Context, in CGenDaoInternalInput) { } } +// generateDoContent renders the DO file content using the template engine. +// It assembles template variables including package imports, struct definition, +// and metadata, then parses the DO template to produce the final file content. func generateDoContent( ctx context.Context, in CGenDaoInternalInput, tableName, tableNameCamelCase, structDefine string, ) string { diff --git a/cmd/gf/internal/cmd/gendao/gendao_entity.go b/cmd/gf/internal/cmd/gendao/gendao_entity.go index 8b059e810..6fd3e3d51 100644 --- a/cmd/gf/internal/cmd/gendao/gendao_entity.go +++ b/cmd/gf/internal/cmd/gendao/gendao_entity.go @@ -20,12 +20,15 @@ import ( "github.com/gogf/gf/cmd/gf/v2/internal/utility/utils" ) +// generateEntity generates entity struct files for all tables. +// Entity structs represent database table rows with concrete Go types, +// including orm tags for field-to-column mapping and json tags for serialization. func generateEntity(ctx context.Context, in CGenDaoInternalInput) { var dirPathEntity = gfile.Join(in.Path, in.EntityPath) in.genItems.AppendDirPath(dirPathEntity) // Model content. for i, tableName := range in.TableNames { - fieldMap, err := in.DB.TableFields(ctx, tableName) + fieldMap, err := getTableFields(ctx, in, tableName) if err != nil { mlog.Fatalf("fetching tables fields failed for table '%s':\n%v", tableName, err) } @@ -60,6 +63,9 @@ func generateEntity(ctx context.Context, in CGenDaoInternalInput) { } } +// generateEntityContent renders the entity file content using the template engine. +// It assembles template variables and parses the entity template to produce +// the final Go source file content with proper imports and struct definition. func generateEntityContent( ctx context.Context, in CGenDaoInternalInput, tableName, tableNameCamelCase, structDefine string, appendImports []string, ) string { diff --git a/cmd/gf/internal/cmd/gendao/gendao_gen_item.go b/cmd/gf/internal/cmd/gendao/gendao_gen_item.go index fcb7bf9be..7876c7d11 100644 --- a/cmd/gf/internal/cmd/gendao/gendao_gen_item.go +++ b/cmd/gf/internal/cmd/gendao/gendao_gen_item.go @@ -7,17 +7,25 @@ package gendao type ( + // CGenDaoInternalGenItems tracks generation state across multiple configuration entries. + // Each configuration entry (e.g., different database links in the config array) + // gets its own CGenDaoInternalGenItem via Scale(). The index field points to the + // current active item. CGenDaoInternalGenItems struct { - index int - Items []CGenDaoInternalGenItem + index int // Index of the current active generation item. + Items []CGenDaoInternalGenItem // List of all generation items, one per config entry. } + + // CGenDaoInternalGenItem tracks generated files and directories for a single + // configuration entry. Used by the Clear feature to identify and remove stale files. CGenDaoInternalGenItem struct { - Clear bool - StorageDirPaths []string - GeneratedFilePaths []string + Clear bool // Whether to clear stale files for this item. + StorageDirPaths []string // Directories where generated files are stored (dao, do, entity, table). + GeneratedFilePaths []string // All file paths generated in this run. } ) +// newCGenDaoInternalGenItems creates a new generation items tracker with an empty item list. func newCGenDaoInternalGenItems() *CGenDaoInternalGenItems { return &CGenDaoInternalGenItems{ index: -1, @@ -25,6 +33,8 @@ func newCGenDaoInternalGenItems() *CGenDaoInternalGenItems { } } +// Scale adds a new generation item and advances the index to it. +// Must be called once per configuration entry before generating files. func (i *CGenDaoInternalGenItems) Scale() { i.Items = append(i.Items, CGenDaoInternalGenItem{ StorageDirPaths: make([]string, 0), @@ -34,10 +44,12 @@ func (i *CGenDaoInternalGenItems) Scale() { i.index++ } +// SetClear enables or disables the clear (stale file removal) flag for the current item. func (i *CGenDaoInternalGenItems) SetClear(clear bool) { i.Items[i.index].Clear = clear } +// AppendDirPath records a directory path used for storing generated files in the current item. func (i *CGenDaoInternalGenItems) AppendDirPath(storageDirPath string) { i.Items[i.index].StorageDirPaths = append( i.Items[i.index].StorageDirPaths, @@ -45,6 +57,7 @@ func (i *CGenDaoInternalGenItems) AppendDirPath(storageDirPath string) { ) } +// AppendGeneratedFilePath records a file path that was generated in the current item. func (i *CGenDaoInternalGenItems) AppendGeneratedFilePath(generatedFilePath string) { i.Items[i.index].GeneratedFilePaths = append( i.Items[i.index].GeneratedFilePaths, diff --git a/cmd/gf/internal/cmd/gendao/gendao_sql_parser.go b/cmd/gf/internal/cmd/gendao/gendao_sql_parser.go new file mode 100644 index 000000000..7092e4de1 --- /dev/null +++ b/cmd/gf/internal/cmd/gendao/gendao_sql_parser.go @@ -0,0 +1,1030 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gendao + +import ( + "sort" + "strings" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/os/gfile" + + "github.com/gogf/gf/cmd/gf/v2/internal/utility/mlog" +) + +// SQLDialect defines supported SQL dialect types. +type SQLDialect string + +const ( + SQLDialectMySQL SQLDialect = "mysql" + SQLDialectPgSQL SQLDialect = "pgsql" + SQLDialectMSSQL SQLDialect = "mssql" + SQLDialectOracle SQLDialect = "oracle" + SQLDialectSQLite SQLDialect = "sqlite" + SQLDialectClickHouse SQLDialect = "clickhouse" +) + +// SQLStatementType identifies the type of a DDL statement. +type SQLStatementType int + +const ( + SQLStatementUnknown SQLStatementType = iota + SQLStatementCreateTable // CREATE TABLE + SQLStatementAlterTable // ALTER TABLE + SQLStatementDropTable // DROP TABLE + SQLStatementRenameTable // RENAME TABLE / ALTER TABLE ... RENAME TO + SQLStatementComment // COMMENT ON COLUMN / sp_addextendedproperty +) + +// SQLParser is the interface for parsing SQL DDL files into table field definitions. +// Each parser must implement CREATE TABLE parsing. ALTER TABLE, DROP TABLE, and +// comment handling are optional and can be delegated to the common layer. +type SQLParser interface { + // ParseCreateTable parses a single CREATE TABLE statement and returns table name and fields. + ParseCreateTable(stmt string) (tableName string, fields map[string]*gdb.TableField, err error) + + // ParseAlterTable parses a single ALTER TABLE statement and applies changes to existing tables. + // Returns the affected table name. + ParseAlterTable(stmt string, tables map[string]map[string]*gdb.TableField) error + + // ParseComment parses a comment statement (COMMENT ON COLUMN / sp_addextendedproperty) + // and applies the comment to the corresponding field. + ParseComment(stmt string, tables map[string]map[string]*gdb.TableField) +} + +// GetSQLParser returns the appropriate SQL parser for the given dialect. +func GetSQLParser(dialect SQLDialect) SQLParser { + switch dialect { + case SQLDialectMySQL: + return &MySQLParser{} + case SQLDialectPgSQL: + return &PgSQLParser{} + case SQLDialectMSSQL: + return &MSSQLParser{} + case SQLDialectOracle: + return &OracleParser{} + case SQLDialectSQLite: + return &SQLiteParser{} + default: + return nil + } +} + +// ParseSQLFilesFromDir parses all .sql files from the given directory using the specified +// dialect parser. Files are processed in sorted order (by filename) to ensure correct +// incremental migration order: CREATE TABLE first, then ALTER TABLE modifications. +func ParseSQLFilesFromDir(sqlDir string, dialect SQLDialect) ( + tableNames []string, + tableFieldsMap map[string]map[string]*gdb.TableField, +) { + parser := GetSQLParser(dialect) + if parser == nil { + mlog.Fatalf(`unsupported SQL dialect "%s"`, dialect) + } + + sqlFiles, err := gfile.ScanDirFile(sqlDir, "*.sql", true) + if err != nil { + mlog.Fatalf(`scanning SQL directory "%s" failed: %+v`, sqlDir, err) + } + if len(sqlFiles) == 0 { + mlog.Fatalf(`no .sql files found in directory "%s"`, sqlDir) + } + + // Sort files by name to ensure deterministic migration order. + // This supports naming conventions like: + // V001_create_tables.sql, V002_add_columns.sql, V003_modify_columns.sql + // 001_init.sql, 002_alter.sql + // 2024-01-01_create.sql, 2024-01-15_alter.sql + sort.Strings(sqlFiles) + + tableFieldsMap = make(map[string]map[string]*gdb.TableField) + + for _, sqlFile := range sqlFiles { + content := gfile.GetContents(sqlFile) + if content == "" { + continue + } + err := processSQL(parser, content, tableFieldsMap) + if err != nil { + mlog.Fatalf(`parsing SQL file "%s" failed: %+v`, sqlFile, err) + } + } + + for tableName := range tableFieldsMap { + tableNames = append(tableNames, tableName) + } + sort.Strings(tableNames) + return +} + +// processSQL processes all SQL statements in a single file content, +// dispatching each statement to the appropriate handler based on its type. +func processSQL( + parser SQLParser, + sql string, + tables map[string]map[string]*gdb.TableField, +) error { + statements := splitSQLStatements(sql) + for _, stmt := range statements { + stmtType := classifyStatement(stmt) + switch stmtType { + case SQLStatementCreateTable: + tableName, fields, err := parser.ParseCreateTable(stmt) + if err != nil { + return err + } + if tableName != "" && len(fields) > 0 { + tables[tableName] = fields + } + + case SQLStatementAlterTable: + err := parser.ParseAlterTable(stmt, tables) + if err != nil { + return err + } + + case SQLStatementDropTable: + applyDropTable(stmt, tables) + + case SQLStatementRenameTable: + applyRenameTable(stmt, tables) + + case SQLStatementComment: + parser.ParseComment(stmt, tables) + } + } + return nil +} + +// classifyStatement determines the type of a SQL DDL statement. +func classifyStatement(stmt string) SQLStatementType { + upper := strings.ToUpper(strings.TrimSpace(stmt)) + // Remove leading block comments + for strings.HasPrefix(upper, "/*") { + end := strings.Index(upper, "*/") + if end < 0 { + break + } + upper = strings.TrimSpace(upper[end+2:]) + } + + words := strings.Fields(upper) + if len(words) < 2 { + return SQLStatementUnknown + } + + switch words[0] { + case "CREATE": + // CREATE [TEMPORARY|TEMP] TABLE + for _, w := range words[1:] { + if w == "TABLE" { + return SQLStatementCreateTable + } + if w != "TEMPORARY" && w != "TEMP" && w != "GLOBAL" && w != "LOCAL" && + w != "UNLOGGED" { + break + } + } + + case "ALTER": + if words[1] == "TABLE" { + // Check if it's ALTER TABLE ... RENAME TO + if strings.Contains(upper, "RENAME TO") || strings.Contains(upper, "RENAME AS") { + return SQLStatementRenameTable + } + return SQLStatementAlterTable + } + + case "DROP": + if words[1] == "TABLE" { + return SQLStatementDropTable + } + + case "RENAME": + // RENAME TABLE old TO new (MySQL syntax) + if words[1] == "TABLE" { + return SQLStatementRenameTable + } + + case "COMMENT": + // COMMENT ON COLUMN / COMMENT ON TABLE + if len(words) >= 3 && words[1] == "ON" { + return SQLStatementComment + } + + case "EXEC", "EXECUTE": + // EXEC sp_addextendedproperty (MSSQL comments) + if strings.Contains(upper, "SP_ADDEXTENDEDPROPERTY") && + strings.Contains(upper, "MS_DESCRIPTION") { + return SQLStatementComment + } + } + + return SQLStatementUnknown +} + +// applyDropTable removes a table from the tables map. +// Handles: DROP TABLE [IF EXISTS] [schema.]table_name +func applyDropTable(stmt string, tables map[string]map[string]*gdb.TableField) { + upper := strings.ToUpper(strings.TrimSpace(stmt)) + upper = strings.TrimPrefix(upper, "DROP") + upper = strings.TrimSpace(upper) + upper = strings.TrimPrefix(upper, "TABLE") + upper = strings.TrimSpace(upper) + if strings.HasPrefix(upper, "IF EXISTS") { + upper = strings.TrimPrefix(upper, "IF EXISTS") + } + + remaining := stmt[len(stmt)-len(upper):] + remaining = strings.TrimSpace(remaining) + + // May be comma-separated list: DROP TABLE t1, t2, t3 + for _, name := range strings.Split(remaining, ",") { + name = strings.TrimSpace(name) + parts := strings.Split(name, ".") + tableName := unquoteIdentifier(parts[len(parts)-1]) + delete(tables, tableName) + } +} + +// applyRenameTable renames a table in the tables map. +// Handles: +// - RENAME TABLE old TO new (MySQL) +// - ALTER TABLE old RENAME TO new (PostgreSQL, SQLite, Oracle) +func applyRenameTable(stmt string, tables map[string]map[string]*gdb.TableField) { + upper := strings.ToUpper(strings.TrimSpace(stmt)) + words := strings.Fields(stmt) + upperWords := strings.Fields(upper) + + if upperWords[0] == "RENAME" && len(upperWords) >= 5 && upperWords[1] == "TABLE" { + // RENAME TABLE old_name TO new_name + oldName := unquoteIdentifier(words[2]) + newName := unquoteIdentifier(words[4]) + if fields, ok := tables[oldName]; ok { + tables[newName] = fields + delete(tables, oldName) + } + } else if upperWords[0] == "ALTER" && len(upperWords) >= 6 && upperWords[1] == "TABLE" { + // ALTER TABLE old_name RENAME TO new_name + oldName := unquoteIdentifier(words[2]) + for i, w := range upperWords { + if w == "RENAME" && i+2 < len(upperWords) && + (upperWords[i+1] == "TO" || upperWords[i+1] == "AS") { + newName := unquoteIdentifier(words[i+2]) + if fields, ok := tables[oldName]; ok { + tables[newName] = fields + delete(tables, oldName) + } + return + } + } + } +} + +// parseAlterTableCommon provides common ALTER TABLE parsing logic that works for +// most SQL dialects. Individual parsers can call this or override with dialect-specific logic. +// +// Supported operations: +// - ADD [COLUMN] column_name type [constraints] +// - DROP [COLUMN] column_name +// - MODIFY [COLUMN] column_name type [constraints] (MySQL, Oracle) +// - ALTER [COLUMN] column_name TYPE type / SET / DROP (PostgreSQL) +// - CHANGE [COLUMN] old_name new_name type [constraints] (MySQL) +// - ADD PRIMARY KEY (col1, col2, ...) +// - DROP PRIMARY KEY +// - RENAME COLUMN old_name TO new_name +func parseAlterTableCommon( + stmt string, + tables map[string]map[string]*gdb.TableField, + columnParser func(def string, index int) (*gdb.TableField, error), +) error { + upper := strings.ToUpper(strings.TrimSpace(stmt)) + words := strings.Fields(stmt) + upperWords := strings.Fields(upper) + + if len(upperWords) < 4 || upperWords[0] != "ALTER" || upperWords[1] != "TABLE" { + return nil + } + + tableName := unquoteIdentifier(words[2]) + fields, exists := tables[tableName] + if !exists { + // Table not yet created, skip + return nil + } + + // The rest after ALTER TABLE tableName + restIdx := 3 + if restIdx >= len(upperWords) { + return nil + } + + // Process the ALTER TABLE actions. Some dialects allow multiple actions separated by commas + // but we handle one action at a time for simplicity and split multi-action later. + return processAlterAction(upperWords, words, restIdx, fields, tableName, tables, columnParser) +} + +func processAlterAction( + upperWords, words []string, + startIdx int, + fields map[string]*gdb.TableField, + tableName string, + tables map[string]map[string]*gdb.TableField, + columnParser func(def string, index int) (*gdb.TableField, error), +) error { + if startIdx >= len(upperWords) { + return nil + } + + action := upperWords[startIdx] + switch action { + case "ADD": + return processAlterAdd(upperWords, words, startIdx+1, fields, columnParser) + + case "DROP": + processAlterDrop(upperWords, words, startIdx+1, fields) + + case "MODIFY": + // MODIFY [COLUMN] column_name type [constraints] (MySQL, Oracle) + return processAlterModify(upperWords, words, startIdx+1, fields, columnParser) + + case "CHANGE": + // CHANGE [COLUMN] old_name new_name type [constraints] (MySQL) + return processAlterChange(upperWords, words, startIdx+1, fields, columnParser) + + case "ALTER": + // ALTER [COLUMN] column_name ... (PostgreSQL: SET DEFAULT, DROP DEFAULT, SET NOT NULL, etc.) + processAlterColumn(upperWords, words, startIdx+1, fields) + + case "RENAME": + // RENAME COLUMN old_name TO new_name + processAlterRenameColumn(upperWords, words, startIdx+1, fields) + } + + return nil +} + +// processAlterAdd handles ALTER TABLE ... ADD [COLUMN] ... or ADD PRIMARY KEY ... +func processAlterAdd( + upperWords, words []string, + idx int, + fields map[string]*gdb.TableField, + columnParser func(def string, index int) (*gdb.TableField, error), +) error { + if idx >= len(upperWords) { + return nil + } + + // Skip optional COLUMN keyword + colIdx := idx + if upperWords[colIdx] == "COLUMN" { + colIdx++ + } + + // ADD PRIMARY KEY (col1, col2) + if upperWords[idx] == "PRIMARY" || upperWords[idx] == "UNIQUE" || + upperWords[idx] == "INDEX" || upperWords[idx] == "KEY" || + upperWords[idx] == "CONSTRAINT" || upperWords[idx] == "FOREIGN" { + if strings.Contains(strings.Join(upperWords[idx:], " "), "PRIMARY KEY") { + fullStmt := strings.Join(words[idx:], " ") + pkCols := findPrimaryKeysFromConstraints([]string{fullStmt}) + for _, pkCol := range pkCols { + if f, ok := fields[pkCol]; ok { + f.Key = "PRI" + } + } + } + return nil + } + + if colIdx >= len(words) { + return nil + } + + // ADD [COLUMN] column_def ... + // Build the column definition from remaining words + def := strings.Join(words[colIdx:], " ") + nextIndex := getNextFieldIndex(fields) + field, err := columnParser(def, nextIndex) + if err != nil { + return nil // skip unparseable + } + if field != nil { + fields[field.Name] = field + } + return nil +} + +// processAlterDrop handles ALTER TABLE ... DROP [COLUMN] column_name or DROP PRIMARY KEY +func processAlterDrop(upperWords, words []string, idx int, fields map[string]*gdb.TableField) { + if idx >= len(upperWords) { + return + } + + // DROP PRIMARY KEY + if upperWords[idx] == "PRIMARY" { + for _, f := range fields { + if f.Key == "PRI" { + f.Key = "" + } + } + return + } + + // DROP [COLUMN] column_name [CASCADE|RESTRICT] + colIdx := idx + if upperWords[colIdx] == "COLUMN" { + colIdx++ + } + if colIdx >= len(words) { + return + } + + colName := unquoteIdentifier(words[colIdx]) + delete(fields, colName) + + // Reindex remaining fields + reindexFields(fields) +} + +// processAlterModify handles ALTER TABLE ... MODIFY [COLUMN] column_name type [constraints] +func processAlterModify( + upperWords, words []string, + idx int, + fields map[string]*gdb.TableField, + columnParser func(def string, index int) (*gdb.TableField, error), +) error { + if idx >= len(upperWords) { + return nil + } + + colIdx := idx + if upperWords[colIdx] == "COLUMN" { + colIdx++ + } + if colIdx >= len(words) { + return nil + } + + colName := unquoteIdentifier(words[colIdx]) + def := strings.Join(words[colIdx:], " ") + + existingIndex := 0 + if existing, ok := fields[colName]; ok { + existingIndex = existing.Index + } + + field, err := columnParser(def, existingIndex) + if err != nil { + return nil + } + if field != nil { + // Preserve the original index + if existing, ok := fields[colName]; ok { + field.Index = existing.Index + } + fields[field.Name] = field + } + return nil +} + +// processAlterChange handles ALTER TABLE ... CHANGE [COLUMN] old_name new_name type [constraints] +func processAlterChange( + upperWords, words []string, + idx int, + fields map[string]*gdb.TableField, + columnParser func(def string, index int) (*gdb.TableField, error), +) error { + if idx >= len(upperWords) { + return nil + } + + colIdx := idx + if upperWords[colIdx] == "COLUMN" { + colIdx++ + } + if colIdx+1 >= len(words) { + return nil + } + + oldName := unquoteIdentifier(words[colIdx]) + // New definition starts from the new column name + def := strings.Join(words[colIdx+1:], " ") + + existingIndex := 0 + if existing, ok := fields[oldName]; ok { + existingIndex = existing.Index + } + + field, err := columnParser(def, existingIndex) + if err != nil { + return nil + } + if field != nil { + // Remove old field + delete(fields, oldName) + if existing, ok := fields[oldName]; ok { + field.Index = existing.Index + } else { + field.Index = existingIndex + } + fields[field.Name] = field + } + return nil +} + +// processAlterColumn handles ALTER TABLE ... ALTER [COLUMN] column_name ... +// PostgreSQL style: SET DEFAULT, DROP DEFAULT, SET NOT NULL, DROP NOT NULL, TYPE +func processAlterColumn(upperWords, words []string, idx int, fields map[string]*gdb.TableField) { + if idx >= len(upperWords) { + return + } + + colIdx := idx + if upperWords[colIdx] == "COLUMN" { + colIdx++ + } + if colIdx >= len(words) { + return + } + + colName := unquoteIdentifier(words[colIdx]) + field, ok := fields[colName] + if !ok { + return + } + + actionIdx := colIdx + 1 + if actionIdx >= len(upperWords) { + return + } + + switch upperWords[actionIdx] { + case "SET": + if actionIdx+1 < len(upperWords) { + switch upperWords[actionIdx+1] { + case "NOT": + // SET NOT NULL + if actionIdx+2 < len(upperWords) && upperWords[actionIdx+2] == "NULL" { + field.Null = false + } + case "DEFAULT": + // SET DEFAULT value + if actionIdx+2 < len(words) { + defaultVal, _ := extractDefaultValue("DEFAULT " + strings.Join(words[actionIdx+2:], " ")) + field.Default = defaultVal + } + case "DATA": + // SET DATA TYPE type_name (PostgreSQL) + if actionIdx+2 < len(upperWords) && upperWords[actionIdx+2] == "TYPE" && + actionIdx+3 < len(words) { + field.Type = strings.Join(words[actionIdx+3:], " ") + } + } + } + case "DROP": + if actionIdx+1 < len(upperWords) { + switch upperWords[actionIdx+1] { + case "NOT": + // DROP NOT NULL + if actionIdx+2 < len(upperWords) && upperWords[actionIdx+2] == "NULL" { + field.Null = true + } + case "DEFAULT": + // DROP DEFAULT + field.Default = nil + } + } + case "TYPE": + // TYPE new_type (PostgreSQL: ALTER COLUMN col TYPE varchar(200)) + if actionIdx+1 < len(words) { + // Collect the type, which may include USING clause + typeParts := make([]string, 0) + for j := actionIdx + 1; j < len(words); j++ { + if strings.ToUpper(words[j]) == "USING" { + break + } + typeParts = append(typeParts, words[j]) + } + if len(typeParts) > 0 { + field.Type = strings.Join(typeParts, " ") + } + } + } +} + +// processAlterRenameColumn handles ALTER TABLE ... RENAME COLUMN old TO new +func processAlterRenameColumn(upperWords, words []string, idx int, fields map[string]*gdb.TableField) { + if idx >= len(upperWords) { + return + } + + colIdx := idx + if upperWords[colIdx] == "COLUMN" { + colIdx++ + } + if colIdx+2 >= len(words) { + return + } + + // RENAME [COLUMN] old_name TO new_name + oldName := unquoteIdentifier(words[colIdx]) + // Find "TO" + for i := colIdx + 1; i < len(upperWords)-1; i++ { + if upperWords[i] == "TO" { + newName := unquoteIdentifier(words[i+1]) + if field, ok := fields[oldName]; ok { + field.Name = newName + fields[newName] = field + delete(fields, oldName) + } + return + } + } +} + +// getNextFieldIndex returns the next available field index. +func getNextFieldIndex(fields map[string]*gdb.TableField) int { + maxIndex := -1 + for _, f := range fields { + if f.Index > maxIndex { + maxIndex = f.Index + } + } + return maxIndex + 1 +} + +// reindexFields re-assigns sequential indices to all fields after a deletion. +func reindexFields(fields map[string]*gdb.TableField) { + type indexedField struct { + name string + index int + } + sorted := make([]indexedField, 0, len(fields)) + for name, f := range fields { + sorted = append(sorted, indexedField{name: name, index: f.Index}) + } + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].index < sorted[j].index + }) + for i, sf := range sorted { + fields[sf.name].Index = i + } +} + +// splitSQLStatements splits SQL content into individual statements by semicolons, +// handling quoted strings and parentheses properly. +func splitSQLStatements(sql string) []string { + var ( + statements []string + current strings.Builder + inSingle bool + inDouble bool + inBlock bool // block comment + depth int + prev byte + ) + for i := 0; i < len(sql); i++ { + ch := sql[i] + switch { + case inBlock: + current.WriteByte(ch) + if ch == '/' && prev == '*' { + inBlock = false + } + case ch == '/' && i+1 < len(sql) && sql[i+1] == '*' && !inSingle && !inDouble: + inBlock = true + current.WriteByte(ch) + case ch == '-' && i+1 < len(sql) && sql[i+1] == '-' && !inSingle && !inDouble: + // Line comment - skip to end of line + for i < len(sql) && sql[i] != '\n' { + i++ + } + case ch == '\'' && !inDouble: + inSingle = !inSingle + current.WriteByte(ch) + case ch == '"' && !inSingle: + inDouble = !inDouble + current.WriteByte(ch) + case ch == '(' && !inSingle && !inDouble: + depth++ + current.WriteByte(ch) + case ch == ')' && !inSingle && !inDouble: + depth-- + current.WriteByte(ch) + case ch == ';' && !inSingle && !inDouble && depth == 0: + stmt := strings.TrimSpace(current.String()) + if stmt != "" { + statements = append(statements, stmt) + } + current.Reset() + default: + current.WriteByte(ch) + } + prev = ch + } + // Last statement without semicolon + if stmt := strings.TrimSpace(current.String()); stmt != "" { + statements = append(statements, stmt) + } + return statements +} + +// unquoteIdentifier removes quotes from SQL identifiers. +// Handles: `name`, "name", [name], 'name' +func unquoteIdentifier(name string) string { + name = strings.TrimSpace(name) + if len(name) < 2 { + return name + } + switch { + case name[0] == '`' && name[len(name)-1] == '`': + return name[1 : len(name)-1] + case name[0] == '"' && name[len(name)-1] == '"': + return name[1 : len(name)-1] + case name[0] == '[' && name[len(name)-1] == ']': + return name[1 : len(name)-1] + } + return name +} + +// extractTableName extracts the table name from a CREATE TABLE statement header. +// It handles: CREATE TABLE name, CREATE TABLE IF NOT EXISTS name, +// CREATE TABLE schema.name, etc. +func extractTableName(header string) string { + header = strings.TrimSpace(header) + // Remove CREATE [TEMPORARY] TABLE [IF NOT EXISTS] + upper := strings.ToUpper(header) + upper = strings.TrimPrefix(upper, "CREATE") + upper = strings.TrimSpace(upper) + if strings.HasPrefix(upper, "TEMPORARY") || strings.HasPrefix(upper, "TEMP") { + idx := strings.Index(upper, "TABLE") + if idx >= 0 { + upper = upper[idx:] + } + } + upper = strings.TrimPrefix(upper, "TABLE") + upper = strings.TrimSpace(upper) + if strings.HasPrefix(upper, "IF NOT EXISTS") { + upper = strings.TrimPrefix(upper, "IF NOT EXISTS") + } + + // Now get the actual name from original string at same offset + remaining := header[len(header)-len(upper):] + remaining = strings.TrimSpace(remaining) + + // Handle schema.table + parts := strings.Split(remaining, ".") + tableName := parts[len(parts)-1] + tableName = strings.TrimSpace(tableName) + + return unquoteIdentifier(tableName) +} + +// splitColumns splits the column definitions part of a CREATE TABLE statement +// into individual column/constraint definitions, properly handling nested parentheses. +func splitColumns(body string) []string { + var ( + result []string + current strings.Builder + depth int + inStr bool + quote byte + ) + for i := 0; i < len(body); i++ { + ch := body[i] + switch { + case inStr: + current.WriteByte(ch) + if ch == quote && (i+1 >= len(body) || body[i+1] != quote) { + inStr = false + } + case ch == '\'' || ch == '"': + inStr = true + quote = ch + current.WriteByte(ch) + case ch == '(': + depth++ + current.WriteByte(ch) + case ch == ')': + depth-- + current.WriteByte(ch) + case ch == ',' && depth == 0: + if s := strings.TrimSpace(current.String()); s != "" { + result = append(result, s) + } + current.Reset() + default: + current.WriteByte(ch) + } + } + if s := strings.TrimSpace(current.String()); s != "" { + result = append(result, s) + } + return result +} + +// isConstraintKeyword checks if the given word starts a table-level constraint +// (not a column definition). +func isConstraintKeyword(word string) bool { + upper := strings.ToUpper(word) + switch upper { + case "PRIMARY", "UNIQUE", "INDEX", "KEY", "CHECK", "FOREIGN", "CONSTRAINT", + "CLUSTERED", "NONCLUSTERED", "SPATIAL", "FULLTEXT": + return true + } + return false +} + +// findPrimaryKeysFromConstraints scans constraint definitions for PRIMARY KEY +// and returns the column names that form the primary key. +func findPrimaryKeysFromConstraints(columnDefs []string) []string { + var pkColumns []string + for _, def := range columnDefs { + upper := strings.ToUpper(strings.TrimSpace(def)) + if !strings.Contains(upper, "PRIMARY KEY") { + continue + } + // Extract column names from PRIMARY KEY (col1, col2, ...) + idx := strings.Index(upper, "PRIMARY KEY") + rest := def[idx+len("PRIMARY KEY"):] + rest = strings.TrimSpace(rest) + // Skip optional CLUSTERED/NONCLUSTERED keyword (MSSQL). + upperRest := strings.ToUpper(rest) + if strings.HasPrefix(upperRest, "CLUSTERED") { + rest = strings.TrimSpace(rest[len("CLUSTERED"):]) + } else if strings.HasPrefix(upperRest, "NONCLUSTERED") { + rest = strings.TrimSpace(rest[len("NONCLUSTERED"):]) + } + if len(rest) > 0 && rest[0] == '(' { + end := strings.Index(rest, ")") + if end > 0 { + cols := rest[1:end] + for _, col := range strings.Split(cols, ",") { + col = strings.TrimSpace(col) + // Remove ASC/DESC + parts := strings.Fields(col) + if len(parts) > 0 { + pkColumns = append(pkColumns, unquoteIdentifier(parts[0])) + } + } + } + } + } + return pkColumns +} + +// extractBodyAndTrailing splits CREATE TABLE ... (...) ... into body and trailing parts. +// It returns the content inside the outermost parentheses and anything after. +func extractBodyAndTrailing(sql string) (body, trailing string, ok bool) { + // Find the first '(' that starts the column definitions + depth := 0 + startIdx := -1 + endIdx := -1 + inStr := false + var quote byte + for i := 0; i < len(sql); i++ { + ch := sql[i] + if inStr { + if ch == quote && (i+1 >= len(sql) || sql[i+1] != quote) { + inStr = false + } + continue + } + if ch == '\'' || ch == '"' { + inStr = true + quote = ch + continue + } + if ch == '(' { + if depth == 0 { + startIdx = i + } + depth++ + } else if ch == ')' { + depth-- + if depth == 0 { + endIdx = i + break + } + } + } + if startIdx < 0 || endIdx < 0 { + return "", "", false + } + body = sql[startIdx+1 : endIdx] + trailing = strings.TrimSpace(sql[endIdx+1:]) + return body, trailing, true +} + +// extractDefaultValue extracts the default value string from a column definition fragment. +// Returns the default value and the remaining string after the default clause. +func extractDefaultValue(s string) (defaultVal any, rest string) { + upper := strings.ToUpper(strings.TrimSpace(s)) + if !strings.HasPrefix(upper, "DEFAULT") { + return nil, s + } + s = strings.TrimSpace(s[7:]) // skip "DEFAULT" + if len(s) == 0 { + return nil, "" + } + + // NULL + if strings.HasPrefix(strings.ToUpper(s), "NULL") { + return nil, strings.TrimSpace(s[4:]) + } + + // Quoted string + if s[0] == '\'' { + end := 1 + for end < len(s) { + if s[end] == '\'' { + if end+1 < len(s) && s[end+1] == '\'' { + end += 2 + continue + } + val := s[1:end] + val = strings.ReplaceAll(val, "''", "'") + return val, strings.TrimSpace(s[end+1:]) + } + end++ + } + return s[1:], "" + } + + // Parenthesized expression like (getdate()), ((0)) + if s[0] == '(' { + depth := 0 + for i := 0; i < len(s); i++ { + if s[i] == '(' { + depth++ + } else if s[i] == ')' { + depth-- + if depth == 0 { + return s[:i+1], strings.TrimSpace(s[i+1:]) + } + } + } + return s, "" + } + + // Unquoted value (number, keyword like CURRENT_TIMESTAMP, etc.) + parts := strings.Fields(s) + if len(parts) > 0 { + val := parts[0] + // Remove trailing comma if any + val = strings.TrimRight(val, ",") + rest = strings.TrimSpace(s[len(parts[0]):]) + return val, rest + } + return nil, "" +} + +// mysqlTokenize performs simple tokenization of a column definition string, +// respecting quoted identifiers and parenthesized type parameters. +func mysqlTokenize(def string) []string { + var ( + tokens []string + current strings.Builder + inStr bool + quote byte + depth int + ) + for i := 0; i < len(def); i++ { + ch := def[i] + switch { + case inStr: + current.WriteByte(ch) + if ch == quote && (i+1 >= len(def) || def[i+1] != quote) { + inStr = false + } + case ch == '\'' || ch == '"' || ch == '`': + if current.Len() == 0 || depth > 0 { + inStr = true + quote = ch + } + current.WriteByte(ch) + case ch == '(': + depth++ + current.WriteByte(ch) + case ch == ')': + depth-- + current.WriteByte(ch) + case (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r') && depth == 0: + if current.Len() > 0 { + tokens = append(tokens, current.String()) + current.Reset() + } + default: + current.WriteByte(ch) + } + } + if current.Len() > 0 { + tokens = append(tokens, current.String()) + } + return tokens +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_sql_parser_mssql.go b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_mssql.go new file mode 100644 index 000000000..025f2221a --- /dev/null +++ b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_mssql.go @@ -0,0 +1,211 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gendao + +import ( + "fmt" + "strings" + + "github.com/gogf/gf/v2/database/gdb" +) + +// MSSQLParser implements SQLParser for SQL Server (T-SQL) DDL. +type MSSQLParser struct{} + +// ParseCreateTable parses a single MSSQL CREATE TABLE statement. +func (p *MSSQLParser) ParseCreateTable(stmt string) (string, map[string]*gdb.TableField, error) { + body, _, ok := extractBodyAndTrailing(stmt) + if !ok { + return "", nil, nil + } + + parenIdx := strings.Index(stmt, "(") + header := stmt[:parenIdx] + tableName := extractTableName(header) + if tableName == "" { + return "", nil, fmt.Errorf("cannot extract table name from: %s", header) + } + + columnDefs := splitColumns(body) + fields := make(map[string]*gdb.TableField) + pkColumns := findPrimaryKeysFromConstraints(columnDefs) + + fieldIndex := 0 + for _, def := range columnDefs { + def = strings.TrimSpace(def) + if def == "" { + continue + } + firstWord := strings.ToUpper(strings.Fields(def)[0]) + if isConstraintKeyword(firstWord) { + continue + } + + field, err := p.parseColumnDef(def, fieldIndex) + if err != nil { + continue + } + if field != nil { + fields[field.Name] = field + fieldIndex++ + } + } + + for _, pkCol := range pkColumns { + if f, ok := fields[pkCol]; ok { + f.Key = "PRI" + } + } + + return tableName, fields, nil +} + +// ParseAlterTable parses MSSQL ALTER TABLE statements. +func (p *MSSQLParser) ParseAlterTable(stmt string, tables map[string]map[string]*gdb.TableField) error { + return parseAlterTableCommon(stmt, tables, p.parseColumnDef) +} + +// ParseComment parses EXEC sp_addextendedproperty to extract column comments. +func (p *MSSQLParser) ParseComment(stmt string, tables map[string]map[string]*gdb.TableField) { + upper := strings.ToUpper(strings.TrimSpace(stmt)) + if !strings.Contains(upper, "SP_ADDEXTENDEDPROPERTY") || + !strings.Contains(upper, "MS_DESCRIPTION") { + return + } + + // Extract quoted string values + var values []string + inQuote := false + var current strings.Builder + for i := 0; i < len(stmt); i++ { + ch := stmt[i] + if ch == '\'' { + if inQuote { + if i+1 < len(stmt) && stmt[i+1] == '\'' { + current.WriteByte('\'') + i++ + continue + } + values = append(values, current.String()) + current.Reset() + inQuote = false + } else { + inQuote = true + } + } else if inQuote { + current.WriteByte(ch) + } + } + + if len(values) < 8 { + return + } + + var ( + comment string + tableName string + columnName string + ) + + for i := 0; i < len(values)-1; i++ { + switch strings.ToUpper(values[i]) { + case "MS_DESCRIPTION": + comment = values[i+1] + case "TABLE": + tableName = values[i+1] + case "COLUMN": + columnName = values[i+1] + } + } + + if tableName != "" && columnName != "" && comment != "" { + if fields, ok := tables[tableName]; ok { + if field, ok := fields[columnName]; ok { + field.Comment = comment + } + } + } +} + +// parseColumnDef parses a single MSSQL column definition string into a TableField. +// It handles MSSQL-specific syntax including bracket-quoted identifiers and +// type parameters like varchar(max). +func (p *MSSQLParser) parseColumnDef(def string, index int) (*gdb.TableField, error) { + tokens := mysqlTokenize(def) + if len(tokens) < 2 { + return nil, fmt.Errorf("invalid column definition: %s", def) + } + + field := &gdb.TableField{ + Index: index, + Name: unquoteIdentifier(tokens[0]), + Null: true, + } + + field.Type = tokens[1] + + rest := "" + if len(tokens) > 2 { + rest = strings.Join(tokens[2:], " ") + } + if !strings.Contains(field.Type, "(") && strings.HasPrefix(strings.TrimSpace(rest), "(") { + end := strings.Index(rest, ")") + if end >= 0 { + field.Type += rest[:end+1] + rest = strings.TrimSpace(rest[end+1:]) + } + } + + p.parseColumnAttributes(field, rest) + + return field, nil +} + +// parseColumnAttributes parses MSSQL column constraint keywords including +// NOT NULL, NULL, PRIMARY KEY, UNIQUE, IDENTITY (auto-increment), and DEFAULT. +func (p *MSSQLParser) parseColumnAttributes(field *gdb.TableField, attrs string) { + words := strings.Fields(attrs) + upperWords := strings.Fields(strings.ToUpper(attrs)) + + for i := 0; i < len(upperWords); i++ { + switch upperWords[i] { + case "NOT": + if i+1 < len(upperWords) && upperWords[i+1] == "NULL" { + field.Null = false + i++ + } + case "NULL": + field.Null = true + case "PRIMARY": + if i+1 < len(upperWords) && upperWords[i+1] == "KEY" { + field.Key = "PRI" + i++ + } + case "UNIQUE": + if field.Key == "" { + field.Key = "UNI" + } + case "IDENTITY": + field.Extra = "auto_increment" + if i+1 < len(words) && strings.HasPrefix(words[i+1], "(") { + i++ + } + default: + if strings.HasPrefix(upperWords[i], "IDENTITY(") || strings.HasPrefix(upperWords[i], "IDENTITY (") { + field.Extra = "auto_increment" + } + case "DEFAULT": + if i+1 < len(words) { + defaultVal, _ := extractDefaultValue("DEFAULT " + strings.Join(words[i+1:], " ")) + field.Default = defaultVal + if defaultVal != nil { + i++ + } + } + } + } +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_sql_parser_mssql_test.go b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_mssql_test.go new file mode 100644 index 000000000..94d02917f --- /dev/null +++ b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_mssql_test.go @@ -0,0 +1,72 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gendao + +import ( + "testing" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/test/gtest" +) + +func Test_MSSQL_CreateTable_Basic(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MSSQLParser{} + sql := ` +CREATE TABLE [dbo].[users] ( + [id] INT IDENTITY(1,1) NOT NULL, + [name] NVARCHAR(100) NOT NULL, + [email] NVARCHAR(200) NULL, + [balance] DECIMAL(18,2) DEFAULT 0, + [created_at] DATETIME2 NOT NULL DEFAULT GETDATE(), + CONSTRAINT [PK_users] PRIMARY KEY CLUSTERED ([id]) +); +EXEC sp_addextendedproperty 'MS_Description', 'User ID', 'SCHEMA', 'dbo', 'TABLE', 'users', 'COLUMN', 'id'; +EXEC sp_addextendedproperty 'MS_Description', 'User name', 'SCHEMA', 'dbo', 'TABLE', 'users', 'COLUMN', 'name'; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 5) + + t.Assert(fields["id"].Extra, "auto_increment") + t.Assert(fields["id"].Null, false) + t.Assert(fields["id"].Key, "PRI") + t.Assert(fields["id"].Comment, "User ID") + + t.Assert(fields["name"].Comment, "User name") + t.Assert(fields["name"].Null, false) + + t.Assert(fields["email"].Null, true) + }) +} + +func Test_MSSQL_AlterTable(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MSSQLParser{} + sql := ` +CREATE TABLE users ( + id INT IDENTITY(1,1) NOT NULL, + name NVARCHAR(100) NOT NULL, + CONSTRAINT PK_users PRIMARY KEY (id) +); +ALTER TABLE users ADD email NVARCHAR(200) NULL; +ALTER TABLE users DROP COLUMN name; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 2) // id, email + _, ok := fields["name"] + t.Assert(ok, false) + t.Assert(fields["email"].Null, true) + }) +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_sql_parser_mysql.go b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_mysql.go new file mode 100644 index 000000000..73523a35c --- /dev/null +++ b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_mysql.go @@ -0,0 +1,199 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gendao + +import ( + "fmt" + "strings" + + "github.com/gogf/gf/v2/database/gdb" +) + +// MySQLParser implements SQLParser for MySQL/MariaDB/TiDB DDL. +type MySQLParser struct{} + +// ParseCreateTable parses a single MySQL CREATE TABLE statement. +func (p *MySQLParser) ParseCreateTable(stmt string) (string, map[string]*gdb.TableField, error) { + body, trailing, ok := extractBodyAndTrailing(stmt) + if !ok { + return "", nil, nil + } + + parenIdx := strings.Index(stmt, "(") + header := stmt[:parenIdx] + tableName := extractTableName(header) + if tableName == "" { + return "", nil, fmt.Errorf("cannot extract table name from: %s", header) + } + + columnDefs := splitColumns(body) + fields := make(map[string]*gdb.TableField) + pkColumns := findPrimaryKeysFromConstraints(columnDefs) + + fieldIndex := 0 + for _, def := range columnDefs { + def = strings.TrimSpace(def) + if def == "" { + continue + } + firstWord := strings.ToUpper(strings.Fields(def)[0]) + if isConstraintKeyword(firstWord) { + continue + } + + field, err := p.parseColumnDef(def, fieldIndex) + if err != nil { + continue + } + if field != nil { + fields[field.Name] = field + fieldIndex++ + } + } + + for _, pkCol := range pkColumns { + if f, ok := fields[pkCol]; ok { + f.Key = "PRI" + } + } + + // Extract inline comments from trailing table options (not used for field generation) + _ = trailing + + return tableName, fields, nil +} + +// ParseAlterTable parses MySQL ALTER TABLE statements. +func (p *MySQLParser) ParseAlterTable(stmt string, tables map[string]map[string]*gdb.TableField) error { + return parseAlterTableCommon(stmt, tables, p.parseColumnDef) +} + +// ParseComment handles MySQL-style comments (inline COMMENT keyword is handled in parseColumnDef). +func (p *MySQLParser) ParseComment(stmt string, tables map[string]map[string]*gdb.TableField) { + // MySQL uses inline COMMENT 'xxx' in column definitions, + // which is already handled by parseColumnDef. No separate COMMENT ON statement. +} + +// parseColumnDef parses a single MySQL column definition string into a TableField. +// It extracts the column name, data type (including UNSIGNED modifier), and delegates +// attribute parsing (NULL, DEFAULT, PRIMARY KEY, COMMENT, etc.) to parseColumnAttributes. +func (p *MySQLParser) parseColumnDef(def string, index int) (*gdb.TableField, error) { + tokens := mysqlTokenize(def) + if len(tokens) < 2 { + return nil, fmt.Errorf("invalid column definition: %s", def) + } + + field := &gdb.TableField{ + Index: index, + Name: unquoteIdentifier(tokens[0]), + Null: true, + } + + typeStr := tokens[1] + rest := "" + if len(tokens) > 2 { + rest = strings.Join(tokens[2:], " ") + } + + // Check if rest starts with '(' meaning the type params are in rest + if !strings.Contains(typeStr, "(") && strings.HasPrefix(strings.TrimSpace(rest), "(") { + endParen := strings.Index(rest, ")") + if endParen >= 0 { + typeStr += rest[:endParen+1] + rest = strings.TrimSpace(rest[endParen+1:]) + } + } + + field.Type = typeStr + + // Handle UNSIGNED + upperRest := strings.ToUpper(rest) + if strings.HasPrefix(upperRest, "UNSIGNED") { + field.Type += " unsigned" + rest = strings.TrimSpace(rest[8:]) + } + + p.parseColumnAttributes(field, rest) + + return field, nil +} + +// parseColumnAttributes parses MySQL column constraint keywords from the attribute string +// following the column type. It handles NOT NULL, NULL, PRIMARY KEY, UNIQUE, AUTO_INCREMENT, +// DEFAULT, COMMENT, and ON UPDATE clauses. +func (p *MySQLParser) parseColumnAttributes(field *gdb.TableField, attrs string) { + words := strings.Fields(attrs) + upperWords := strings.Fields(strings.ToUpper(attrs)) + + for i := 0; i < len(upperWords); i++ { + switch upperWords[i] { + case "NOT": + if i+1 < len(upperWords) && upperWords[i+1] == "NULL" { + field.Null = false + i++ + } + case "NULL": + field.Null = true + case "PRIMARY": + if i+1 < len(upperWords) && upperWords[i+1] == "KEY" { + field.Key = "PRI" + i++ + } + case "UNIQUE": + if field.Key == "" { + field.Key = "UNI" + } + if i+1 < len(upperWords) && upperWords[i+1] == "KEY" { + i++ + } + case "KEY": + if field.Key == "" { + field.Key = "MUL" + } + case "AUTO_INCREMENT": + field.Extra = "auto_increment" + case "DEFAULT": + if i+1 < len(words) { + defaultVal, _ := extractDefaultValue("DEFAULT " + strings.Join(words[i+1:], " ")) + field.Default = defaultVal + if defaultVal != nil { + if strings.HasPrefix(words[i+1], "'") { + for j := i + 1; j < len(words); j++ { + if strings.HasSuffix(words[j], "'") { + i = j + break + } + } + } else { + i++ + } + } + } + case "COMMENT": + if i+1 < len(words) { + comment := strings.Join(words[i+1:], " ") + comment = strings.TrimSpace(comment) + if len(comment) >= 2 && comment[0] == '\'' && comment[len(comment)-1] == '\'' { + comment = comment[1 : len(comment)-1] + comment = strings.ReplaceAll(comment, "''", "'") + } + field.Comment = comment + return + } + case "ON": + if i+1 < len(upperWords) && upperWords[i+1] == "UPDATE" { + if i+2 < len(upperWords) { + if field.Extra != "" { + field.Extra += ", " + } + field.Extra += "on update " + strings.ToLower(words[i+2]) + i += 2 + } + } + } + } +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_sql_parser_mysql_test.go b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_mysql_test.go new file mode 100644 index 000000000..b8f72ab0b --- /dev/null +++ b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_mysql_test.go @@ -0,0 +1,300 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gendao + +import ( + "testing" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/test/gtest" +) + +func Test_MySQL_CreateTable_Basic(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := ` +CREATE TABLE users ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 'User ID', + name VARCHAR(100) NOT NULL DEFAULT '' COMMENT 'User name', + email VARCHAR(200) NULL COMMENT 'Email address', + age INT(11) DEFAULT 0, + score DECIMAL(10,2) DEFAULT 0.00, + status TINYINT(1) NOT NULL DEFAULT 1, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NULL ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (id), + UNIQUE KEY uk_email (email) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='User table'; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + t.Assert(len(tables), 1) + + fields := tables["users"] + t.Assert(len(fields), 8) + + // Check id field + t.Assert(fields["id"].Name, "id") + t.Assert(fields["id"].Type, "BIGINT(20) unsigned") + t.Assert(fields["id"].Null, false) + t.Assert(fields["id"].Key, "PRI") + t.Assert(fields["id"].Extra, "auto_increment") + t.Assert(fields["id"].Comment, "User ID") + t.Assert(fields["id"].Index, 0) + + // Check name field + t.Assert(fields["name"].Name, "name") + t.Assert(fields["name"].Null, false) + t.Assert(fields["name"].Comment, "User name") + + // Check email field + t.Assert(fields["email"].Null, true) + + // Check created_at + t.Assert(fields["created_at"].Null, false) + }) +} + +func Test_MySQL_AlterTable_AddColumn(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := ` +CREATE TABLE users ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(100) NOT NULL, + PRIMARY KEY (id) +); +ALTER TABLE users ADD COLUMN email VARCHAR(200) NULL COMMENT 'Email'; +ALTER TABLE users ADD COLUMN age INT DEFAULT 0; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 4) + t.Assert(fields["email"].Name, "email") + t.Assert(fields["email"].Null, true) + t.Assert(fields["email"].Comment, "Email") + t.Assert(fields["age"].Name, "age") + }) +} + +func Test_MySQL_AlterTable_DropColumn(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := ` +CREATE TABLE users ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(100), + old_field VARCHAR(50), + PRIMARY KEY (id) +); +ALTER TABLE users DROP COLUMN old_field; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 2) + _, ok := fields["old_field"] + t.Assert(ok, false) + }) +} + +func Test_MySQL_AlterTable_ModifyColumn(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := ` +CREATE TABLE users ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(100), + PRIMARY KEY (id) +); +ALTER TABLE users MODIFY COLUMN name VARCHAR(200) NOT NULL COMMENT 'Full name'; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(fields["name"].Type, "VARCHAR(200)") + t.Assert(fields["name"].Null, false) + t.Assert(fields["name"].Comment, "Full name") + }) +} + +func Test_MySQL_AlterTable_ChangeColumn(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := ` +CREATE TABLE users ( + id INT NOT NULL AUTO_INCREMENT, + old_name VARCHAR(100), + PRIMARY KEY (id) +); +ALTER TABLE users CHANGE COLUMN old_name new_name VARCHAR(200) NOT NULL; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + _, ok := fields["old_name"] + t.Assert(ok, false) + t.Assert(fields["new_name"].Name, "new_name") + t.Assert(fields["new_name"].Type, "VARCHAR(200)") + }) +} + +func Test_MySQL_AlterTable_AddPrimaryKey(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := ` +CREATE TABLE users ( + id INT NOT NULL, + name VARCHAR(100) +); +ALTER TABLE users ADD PRIMARY KEY (id); +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + t.Assert(tables["users"]["id"].Key, "PRI") + }) +} + +func Test_MySQL_DropTable(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := ` +CREATE TABLE temp_log (id INT, msg TEXT); +CREATE TABLE users (id INT, name VARCHAR(100)); +DROP TABLE IF EXISTS temp_log; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + t.Assert(len(tables), 1) + _, ok := tables["temp_log"] + t.Assert(ok, false) + _, ok = tables["users"] + t.Assert(ok, true) + }) +} + +func Test_MySQL_MultipleMigrations(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + + // Simulate V1: initial schema + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, ` +CREATE TABLE users ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + name VARCHAR(50) NOT NULL, + PRIMARY KEY (id) +); +`, tables) + t.AssertNil(err) + + // Simulate V2: add columns + err = processSQL(parser, ` +ALTER TABLE users ADD COLUMN email VARCHAR(200) NULL; +ALTER TABLE users ADD COLUMN phone VARCHAR(20) NULL; +`, tables) + t.AssertNil(err) + + // Simulate V3: modify + drop + err = processSQL(parser, ` +ALTER TABLE users MODIFY COLUMN name VARCHAR(100) NOT NULL COMMENT 'Full name'; +ALTER TABLE users DROP COLUMN phone; +`, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 3) // id, name, email + t.Assert(fields["name"].Type, "VARCHAR(100)") + t.Assert(fields["name"].Comment, "Full name") + _, ok := fields["phone"] + t.Assert(ok, false) + t.Assert(fields["email"].Null, true) + }) +} + +func Test_MySQL_FullMigrationScenario(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + tables := make(map[string]map[string]*gdb.TableField) + + // V001: Initial tables + err := processSQL(parser, ` +CREATE TABLE IF NOT EXISTS users ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 'Primary key', + username VARCHAR(50) NOT NULL COMMENT 'Username', + password VARCHAR(128) NOT NULL COMMENT 'Hashed password', + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id), + UNIQUE KEY uk_username (username) +); + +CREATE TABLE IF NOT EXISTS orders ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + user_id BIGINT UNSIGNED NOT NULL, + amount DECIMAL(10,2) NOT NULL DEFAULT 0.00, + PRIMARY KEY (id) +); +`, tables) + t.AssertNil(err) + t.Assert(len(tables), 2) + + // V002: Add email, phone + err = processSQL(parser, ` +ALTER TABLE users ADD COLUMN email VARCHAR(200) NULL COMMENT 'User email'; +ALTER TABLE users ADD COLUMN phone VARCHAR(20) NULL COMMENT 'Phone number'; +`, tables) + t.AssertNil(err) + t.Assert(len(tables["users"]), 6) + + // V003: Modify, rename, drop + err = processSQL(parser, ` +ALTER TABLE users MODIFY COLUMN username VARCHAR(100) NOT NULL COMMENT 'Login name'; +ALTER TABLE users CHANGE COLUMN phone mobile VARCHAR(20) NULL COMMENT 'Mobile number'; +ALTER TABLE users DROP COLUMN password; +ALTER TABLE orders ADD COLUMN status TINYINT(1) NOT NULL DEFAULT 0 COMMENT 'Order status'; +`, tables) + t.AssertNil(err) + + userFields := tables["users"] + t.Assert(len(userFields), 5) // id, username, email, mobile, created_at + t.Assert(userFields["username"].Type, "VARCHAR(100)") + t.Assert(userFields["username"].Comment, "Login name") + _, ok := userFields["password"] + t.Assert(ok, false) + _, ok = userFields["phone"] + t.Assert(ok, false) + t.Assert(userFields["mobile"].Name, "mobile") + t.Assert(userFields["mobile"].Comment, "Mobile number") + + orderFields := tables["orders"] + t.Assert(len(orderFields), 4) + t.Assert(orderFields["status"].Default, "0") + + // V004: Drop table + err = processSQL(parser, ` +DROP TABLE IF EXISTS orders; +`, tables) + t.AssertNil(err) + t.Assert(len(tables), 1) + _, ok = tables["orders"] + t.Assert(ok, false) + }) +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_sql_parser_oracle.go b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_oracle.go new file mode 100644 index 000000000..ee4d25647 --- /dev/null +++ b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_oracle.go @@ -0,0 +1,209 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gendao + +import ( + "fmt" + "strings" + + "github.com/gogf/gf/v2/database/gdb" +) + +// OracleParser implements SQLParser for Oracle/DM DDL. +type OracleParser struct{} + +// ParseCreateTable parses a single Oracle CREATE TABLE statement. +func (p *OracleParser) ParseCreateTable(stmt string) (string, map[string]*gdb.TableField, error) { + body, _, ok := extractBodyAndTrailing(stmt) + if !ok { + return "", nil, nil + } + + parenIdx := strings.Index(stmt, "(") + header := stmt[:parenIdx] + tableName := extractTableName(header) + if tableName == "" { + return "", nil, fmt.Errorf("cannot extract table name from: %s", header) + } + + columnDefs := splitColumns(body) + fields := make(map[string]*gdb.TableField) + pkColumns := findPrimaryKeysFromConstraints(columnDefs) + + fieldIndex := 0 + for _, def := range columnDefs { + def = strings.TrimSpace(def) + if def == "" { + continue + } + firstWord := strings.ToUpper(strings.Fields(def)[0]) + if isConstraintKeyword(firstWord) { + continue + } + + field, err := p.parseColumnDef(def, fieldIndex) + if err != nil { + continue + } + if field != nil { + fields[field.Name] = field + fieldIndex++ + } + } + + for _, pkCol := range pkColumns { + if f, ok := fields[pkCol]; ok { + f.Key = "PRI" + } + upperPk := strings.ToUpper(pkCol) + if f, ok := fields[upperPk]; ok { + f.Key = "PRI" + } + } + + return tableName, fields, nil +} + +// ParseAlterTable parses Oracle ALTER TABLE statements. +func (p *OracleParser) ParseAlterTable(stmt string, tables map[string]map[string]*gdb.TableField) error { + return parseAlterTableCommon(stmt, tables, p.parseColumnDef) +} + +// ParseComment parses COMMENT ON COLUMN table.column IS 'comment'. +func (p *OracleParser) ParseComment(stmt string, tables map[string]map[string]*gdb.TableField) { + upper := strings.ToUpper(strings.TrimSpace(stmt)) + if !strings.HasPrefix(upper, "COMMENT ON COLUMN") { + return + } + + rest := strings.TrimSpace(stmt[len("COMMENT ON COLUMN"):]) + isIdx := strings.Index(strings.ToUpper(rest), " IS ") + if isIdx < 0 { + return + } + ref := strings.TrimSpace(rest[:isIdx]) + comment := strings.TrimSpace(rest[isIdx+4:]) + + if len(comment) >= 2 && comment[0] == '\'' && comment[len(comment)-1] == '\'' { + comment = comment[1 : len(comment)-1] + comment = strings.ReplaceAll(comment, "''", "'") + } + + parts := strings.Split(ref, ".") + var tableName, columnName string + switch len(parts) { + case 2: + tableName = unquoteIdentifier(parts[0]) + columnName = unquoteIdentifier(parts[1]) + case 3: + tableName = unquoteIdentifier(parts[1]) + columnName = unquoteIdentifier(parts[2]) + default: + return + } + + if fields, ok := tables[tableName]; ok { + if field, ok := fields[columnName]; ok { + field.Comment = comment + } + } +} + +// parseColumnDef parses a single Oracle column definition string into a TableField. +// It handles Oracle-specific types including TIMESTAMP WITH TIME ZONE and +// TIMESTAMP WITH LOCAL TIME ZONE. +func (p *OracleParser) parseColumnDef(def string, index int) (*gdb.TableField, error) { + tokens := mysqlTokenize(def) + if len(tokens) < 2 { + return nil, fmt.Errorf("invalid column definition: %s", def) + } + + field := &gdb.TableField{ + Index: index, + Name: unquoteIdentifier(tokens[0]), + Null: true, + } + + field.Type = tokens[1] + + rest := "" + if len(tokens) > 2 { + rest = strings.Join(tokens[2:], " ") + } + + if !strings.Contains(field.Type, "(") && strings.HasPrefix(strings.TrimSpace(rest), "(") { + end := strings.Index(rest, ")") + if end >= 0 { + field.Type += rest[:end+1] + rest = strings.TrimSpace(rest[end+1:]) + } + } + + // Handle TIMESTAMP WITH TIME ZONE / WITH LOCAL TIME ZONE + upperType := strings.ToUpper(field.Type) + upperRest := strings.ToUpper(rest) + if upperType == "TIMESTAMP" { + if strings.HasPrefix(upperRest, "WITH LOCAL TIME ZONE") { + field.Type = "timestamp with local time zone" + rest = strings.TrimSpace(rest[len("WITH LOCAL TIME ZONE"):]) + } else if strings.HasPrefix(upperRest, "WITH TIME ZONE") { + field.Type = "timestamp with time zone" + rest = strings.TrimSpace(rest[len("WITH TIME ZONE"):]) + } + } + + p.parseColumnAttributes(field, rest) + + return field, nil +} + +// parseColumnAttributes parses Oracle column constraint keywords including +// NOT NULL, NULL, PRIMARY KEY, UNIQUE, DEFAULT, and GENERATED ... AS IDENTITY. +func (p *OracleParser) parseColumnAttributes(field *gdb.TableField, attrs string) { + words := strings.Fields(attrs) + upperWords := strings.Fields(strings.ToUpper(attrs)) + + for i := 0; i < len(upperWords); i++ { + switch upperWords[i] { + case "NOT": + if i+1 < len(upperWords) && upperWords[i+1] == "NULL" { + field.Null = false + i++ + } + case "NULL": + field.Null = true + case "PRIMARY": + if i+1 < len(upperWords) && upperWords[i+1] == "KEY" { + field.Key = "PRI" + i++ + } + case "UNIQUE": + if field.Key == "" { + field.Key = "UNI" + } + case "DEFAULT": + if i+1 < len(words) { + defaultVal, _ := extractDefaultValue("DEFAULT " + strings.Join(words[i+1:], " ")) + field.Default = defaultVal + if defaultVal != nil { + i++ + } + } + case "GENERATED": + rest := strings.Join(upperWords[i:], " ") + if strings.Contains(rest, "AS IDENTITY") { + field.Extra = "auto_increment" + for j := i + 1; j < len(upperWords); j++ { + if upperWords[j] == "IDENTITY" { + i = j + break + } + } + } + } + } +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_sql_parser_oracle_test.go b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_oracle_test.go new file mode 100644 index 000000000..fe3db006a --- /dev/null +++ b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_oracle_test.go @@ -0,0 +1,97 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gendao + +import ( + "testing" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/test/gtest" +) + +func Test_Oracle_CreateTable_Basic(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &OracleParser{} + sql := ` +CREATE TABLE users ( + ID NUMBER(10) NOT NULL, + NAME VARCHAR2(100) NOT NULL, + EMAIL VARCHAR2(200), + CREATED_AT TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP, + CONSTRAINT PK_USERS PRIMARY KEY (ID) +); +COMMENT ON COLUMN users.ID IS 'User ID'; +COMMENT ON COLUMN users.NAME IS 'User name'; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 4) + + t.Assert(fields["ID"].Key, "PRI") + t.Assert(fields["ID"].Null, false) + t.Assert(fields["ID"].Comment, "User ID") + + t.Assert(fields["NAME"].Null, false) + t.Assert(fields["NAME"].Comment, "User name") + + t.Assert(fields["CREATED_AT"].Type, "timestamp with time zone") + }) +} + +func Test_Oracle_AlterTable(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &OracleParser{} + sql := ` +CREATE TABLE users ( + ID NUMBER(10) NOT NULL, + NAME VARCHAR2(100), + CONSTRAINT PK_USERS PRIMARY KEY (ID) +); +ALTER TABLE users ADD EMAIL VARCHAR2(200); +ALTER TABLE users MODIFY NAME VARCHAR2(200) NOT NULL; +COMMENT ON COLUMN users.EMAIL IS 'Email address'; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 3) + t.Assert(fields["EMAIL"].Comment, "Email address") + t.Assert(fields["NAME"].Type, "VARCHAR2(200)") + t.Assert(fields["NAME"].Null, false) + }) +} + +func Test_Oracle_AlterTable_DropColumn(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &OracleParser{} + sql := ` +CREATE TABLE users ( + ID NUMBER(10) NOT NULL, + NAME VARCHAR2(100) NOT NULL, + OLD_COL VARCHAR2(50), + EMAIL VARCHAR2(200), + CONSTRAINT PK_USERS PRIMARY KEY (ID) +); +ALTER TABLE users DROP COLUMN OLD_COL; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 3) + _, ok := fields["OLD_COL"] + t.Assert(ok, false) + t.Assert(fields["NAME"].Name, "NAME") + t.Assert(fields["EMAIL"].Name, "EMAIL") + }) +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_sql_parser_pgsql.go b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_pgsql.go new file mode 100644 index 000000000..5f127e208 --- /dev/null +++ b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_pgsql.go @@ -0,0 +1,268 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gendao + +import ( + "fmt" + "strings" + + "github.com/gogf/gf/v2/database/gdb" +) + +// PgSQLParser implements SQLParser for PostgreSQL DDL. +type PgSQLParser struct{} + +// ParseCreateTable parses a single PostgreSQL CREATE TABLE statement. +func (p *PgSQLParser) ParseCreateTable(stmt string) (string, map[string]*gdb.TableField, error) { + body, _, ok := extractBodyAndTrailing(stmt) + if !ok { + return "", nil, nil + } + + parenIdx := strings.Index(stmt, "(") + header := stmt[:parenIdx] + tableName := extractTableName(header) + if tableName == "" { + return "", nil, fmt.Errorf("cannot extract table name from: %s", header) + } + + columnDefs := splitColumns(body) + fields := make(map[string]*gdb.TableField) + pkColumns := findPrimaryKeysFromConstraints(columnDefs) + + fieldIndex := 0 + for _, def := range columnDefs { + def = strings.TrimSpace(def) + if def == "" { + continue + } + firstWord := strings.ToUpper(strings.Fields(def)[0]) + if isConstraintKeyword(firstWord) { + continue + } + + field, err := p.parseColumnDef(def, fieldIndex) + if err != nil { + continue + } + if field != nil { + fields[field.Name] = field + fieldIndex++ + } + } + + for _, pkCol := range pkColumns { + if f, ok := fields[pkCol]; ok { + f.Key = "PRI" + } + } + + return tableName, fields, nil +} + +// ParseAlterTable parses PostgreSQL ALTER TABLE statements. +func (p *PgSQLParser) ParseAlterTable(stmt string, tables map[string]map[string]*gdb.TableField) error { + return parseAlterTableCommon(stmt, tables, p.parseColumnDef) +} + +// ParseComment parses COMMENT ON COLUMN schema.table.column IS 'comment' statements. +func (p *PgSQLParser) ParseComment(stmt string, tables map[string]map[string]*gdb.TableField) { + upper := strings.ToUpper(strings.TrimSpace(stmt)) + if !strings.HasPrefix(upper, "COMMENT ON COLUMN") { + return + } + + rest := strings.TrimSpace(stmt[len("COMMENT ON COLUMN"):]) + isIdx := strings.Index(strings.ToUpper(rest), " IS ") + if isIdx < 0 { + return + } + ref := strings.TrimSpace(rest[:isIdx]) + comment := strings.TrimSpace(rest[isIdx+4:]) + + if len(comment) >= 2 && comment[0] == '\'' && comment[len(comment)-1] == '\'' { + comment = comment[1 : len(comment)-1] + comment = strings.ReplaceAll(comment, "''", "'") + } + + parts := strings.Split(ref, ".") + var tableName, columnName string + switch len(parts) { + case 2: + tableName = unquoteIdentifier(parts[0]) + columnName = unquoteIdentifier(parts[1]) + case 3: + tableName = unquoteIdentifier(parts[1]) + columnName = unquoteIdentifier(parts[2]) + default: + return + } + + if fields, ok := tables[tableName]; ok { + if field, ok := fields[columnName]; ok { + field.Comment = comment + } + } +} + +// parseColumnDef parses a single PostgreSQL column definition string into a TableField. +// It handles PostgreSQL-specific types like SERIAL/BIGSERIAL (auto-increment shorthand), +// CHARACTER VARYING, DOUBLE PRECISION, TIMESTAMP WITH TIME ZONE, and array types. +func (p *PgSQLParser) parseColumnDef(def string, index int) (*gdb.TableField, error) { + tokens := mysqlTokenize(def) + if len(tokens) < 2 { + return nil, fmt.Errorf("invalid column definition: %s", def) + } + + field := &gdb.TableField{ + Index: index, + Name: unquoteIdentifier(tokens[0]), + Null: true, + } + + // Handle SERIAL types + typeToken := strings.ToUpper(tokens[1]) + switch typeToken { + case "SERIAL": + field.Type = "int" + field.Extra = "auto_increment" + field.Null = false + case "BIGSERIAL": + field.Type = "bigint" + field.Extra = "auto_increment" + field.Null = false + case "SMALLSERIAL": + field.Type = "smallint" + field.Extra = "auto_increment" + field.Null = false + default: + field.Type = tokens[1] + } + + rest := "" + if len(tokens) > 2 { + rest = strings.Join(tokens[2:], " ") + } + upperType := strings.ToUpper(field.Type) + upperRest := strings.ToUpper(rest) + + switch { + case upperType == "CHARACTER" && strings.HasPrefix(upperRest, "VARYING"): + rest = strings.TrimSpace(rest[len("VARYING"):]) + if strings.HasPrefix(rest, "(") { + end := strings.Index(rest, ")") + if end >= 0 { + field.Type = "character varying" + rest[:end+1] + rest = strings.TrimSpace(rest[end+1:]) + } + } else { + field.Type = "character varying" + } + case upperType == "DOUBLE" && strings.HasPrefix(upperRest, "PRECISION"): + field.Type = "double precision" + rest = strings.TrimSpace(rest[len("PRECISION"):]) + case (upperType == "TIMESTAMP" || upperType == "TIME") && + (strings.HasPrefix(upperRest, "WITH TIME ZONE") || strings.HasPrefix(upperRest, "WITHOUT TIME ZONE")): + if strings.HasPrefix(upperRest, "WITH TIME ZONE") { + if upperType == "TIMESTAMP" { + field.Type = "timestamptz" + } else { + field.Type = "time with time zone" + } + rest = strings.TrimSpace(rest[len("WITH TIME ZONE"):]) + } else { + field.Type = strings.ToLower(upperType) + rest = strings.TrimSpace(rest[len("WITHOUT TIME ZONE"):]) + } + case !strings.Contains(field.Type, "(") && strings.HasPrefix(strings.TrimSpace(rest), "("): + end := strings.Index(rest, ")") + if end >= 0 { + field.Type += rest[:end+1] + rest = strings.TrimSpace(rest[end+1:]) + } + } + + // Handle array types + if strings.HasPrefix(rest, "[]") { + field.Type += "[]" + rest = strings.TrimSpace(rest[2:]) + } else if strings.HasPrefix(strings.ToUpper(rest), "ARRAY") { + field.Type += "[]" + rest = strings.TrimSpace(rest[5:]) + } + + p.parseColumnAttributes(field, rest) + + return field, nil +} + +// parseColumnAttributes parses PostgreSQL column constraint keywords including +// NOT NULL, NULL, PRIMARY KEY, UNIQUE, DEFAULT, GENERATED ... AS IDENTITY, and REFERENCES. +func (p *PgSQLParser) parseColumnAttributes(field *gdb.TableField, attrs string) { + words := strings.Fields(attrs) + upperWords := strings.Fields(strings.ToUpper(attrs)) + + for i := 0; i < len(upperWords); i++ { + switch upperWords[i] { + case "NOT": + if i+1 < len(upperWords) && upperWords[i+1] == "NULL" { + field.Null = false + i++ + } + case "NULL": + field.Null = true + case "PRIMARY": + if i+1 < len(upperWords) && upperWords[i+1] == "KEY" { + field.Key = "PRI" + i++ + } + case "UNIQUE": + if field.Key == "" { + field.Key = "UNI" + } + case "DEFAULT": + if i+1 < len(words) { + defaultVal, _ := extractDefaultValue("DEFAULT " + strings.Join(words[i+1:], " ")) + field.Default = defaultVal + if defaultVal != nil { + i++ + } + } + case "GENERATED": + if containsSequence(upperWords[i:], "ALWAYS", "AS", "IDENTITY") || + containsSequence(upperWords[i:], "BY", "DEFAULT", "AS", "IDENTITY") { + field.Extra = "auto_increment" + for j := i + 1; j < len(upperWords); j++ { + if upperWords[j] == "IDENTITY" { + i = j + break + } + } + } + case "REFERENCES": + for j := i + 1; j < len(upperWords); j++ { + i = j + if strings.Contains(words[j], ")") { + break + } + } + } + } +} + +// containsSequence checks if words slice contains the given word sequence starting from index 1. +func containsSequence(words []string, seq ...string) bool { + if len(words) < len(seq)+1 { + return false + } + for i, s := range seq { + if words[i+1] != s { + return false + } + } + return true +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_sql_parser_pgsql_test.go b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_pgsql_test.go new file mode 100644 index 000000000..f40982e63 --- /dev/null +++ b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_pgsql_test.go @@ -0,0 +1,232 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gendao + +import ( + "testing" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/test/gtest" +) + +func Test_PgSQL_CreateTable_Basic(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &PgSQLParser{} + sql := ` +CREATE TABLE users ( + id BIGSERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email CHARACTER VARYING(200), + score DOUBLE PRECISION DEFAULT 0.0, + metadata JSONB, + created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, + is_active BOOLEAN NOT NULL DEFAULT TRUE +); +COMMENT ON COLUMN users.name IS 'User full name'; +COMMENT ON COLUMN users.email IS 'Email address'; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 7) + + // BIGSERIAL should be auto_increment bigint + t.Assert(fields["id"].Type, "bigint") + t.Assert(fields["id"].Extra, "auto_increment") + t.Assert(fields["id"].Key, "PRI") + + // CHARACTER VARYING + t.AssertNE(fields["email"], nil) + + // DOUBLE PRECISION + t.Assert(fields["score"].Type, "double precision") + + // JSONB + t.Assert(fields["metadata"].Type, "JSONB") + + // TIMESTAMP WITH TIME ZONE + t.Assert(fields["created_at"].Type, "timestamptz") + + // COMMENT ON COLUMN + t.Assert(fields["name"].Comment, "User full name") + t.Assert(fields["email"].Comment, "Email address") + }) +} + +func Test_PgSQL_AlterTable_AddColumn(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &PgSQLParser{} + sql := ` +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL +); +ALTER TABLE users ADD COLUMN email VARCHAR(200); +COMMENT ON COLUMN users.email IS 'User email'; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 3) + t.Assert(fields["email"].Name, "email") + t.Assert(fields["email"].Comment, "User email") + }) +} + +func Test_PgSQL_AlterTable_AlterColumnType(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &PgSQLParser{} + sql := ` +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) +); +ALTER TABLE users ALTER COLUMN name TYPE VARCHAR(200); +ALTER TABLE users ALTER COLUMN name SET NOT NULL; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(fields["name"].Type, "VARCHAR(200)") + t.Assert(fields["name"].Null, false) + }) +} + +func Test_PgSQL_AlterTable_DropColumn(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &PgSQLParser{} + sql := ` +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + name VARCHAR(100), + old_col TEXT +); +ALTER TABLE users DROP COLUMN old_col; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 2) + _, ok := fields["old_col"] + t.Assert(ok, false) + }) +} + +func Test_PgSQL_AlterTable_RenameColumn(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &PgSQLParser{} + sql := ` +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + old_name VARCHAR(100) +); +ALTER TABLE users RENAME COLUMN old_name TO new_name; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + _, ok := fields["old_name"] + t.Assert(ok, false) + t.Assert(fields["new_name"].Name, "new_name") + }) +} + +func Test_PgSQL_MultipleMigrations(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &PgSQLParser{} + tables := make(map[string]map[string]*gdb.TableField) + + // V1 + err := processSQL(parser, ` +CREATE TABLE products ( + id BIGSERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + price NUMERIC(10,2) DEFAULT 0.00 +); +`, tables) + t.AssertNil(err) + + // V2: add, alter, comment + err = processSQL(parser, ` +ALTER TABLE products ADD COLUMN category VARCHAR(50); +ALTER TABLE products ALTER COLUMN name TYPE VARCHAR(200); +ALTER TABLE products ALTER COLUMN name SET NOT NULL; +COMMENT ON COLUMN products.category IS 'Product category'; +`, tables) + t.AssertNil(err) + + // V3: rename, drop + err = processSQL(parser, ` +ALTER TABLE products RENAME COLUMN category TO product_category; +`, tables) + t.AssertNil(err) + + fields := tables["products"] + t.Assert(len(fields), 4) + t.Assert(fields["name"].Type, "VARCHAR(200)") + t.Assert(fields["name"].Null, false) + _, ok := fields["category"] + t.Assert(ok, false) + t.Assert(fields["product_category"].Name, "product_category") + t.Assert(fields["product_category"].Comment, "Product category") + }) +} + +func Test_PgSQL_FullMigrationScenario(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &PgSQLParser{} + tables := make(map[string]map[string]*gdb.TableField) + + // V001: Initial + err := processSQL(parser, ` +CREATE TABLE users ( + id BIGSERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(200) UNIQUE +); +COMMENT ON COLUMN users.name IS 'User name'; +`, tables) + t.AssertNil(err) + + // V002: Add, alter type, set not null + err = processSQL(parser, ` +ALTER TABLE users ADD COLUMN avatar TEXT; +ALTER TABLE users ALTER COLUMN name TYPE VARCHAR(200); +ALTER TABLE users ALTER COLUMN email SET NOT NULL; +COMMENT ON COLUMN users.avatar IS 'Avatar URL'; +`, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 4) + t.Assert(fields["name"].Type, "VARCHAR(200)") + t.Assert(fields["email"].Null, false) + t.Assert(fields["avatar"].Comment, "Avatar URL") + + // V003: Rename column, drop not null + err = processSQL(parser, ` +ALTER TABLE users RENAME COLUMN avatar TO profile_image; +ALTER TABLE users ALTER COLUMN email DROP NOT NULL; +`, tables) + t.AssertNil(err) + + _, ok := fields["avatar"] + t.Assert(ok, false) + t.Assert(fields["profile_image"].Name, "profile_image") + t.Assert(fields["email"].Null, true) + }) +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_sql_parser_sqlite.go b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_sqlite.go new file mode 100644 index 000000000..462f382e4 --- /dev/null +++ b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_sqlite.go @@ -0,0 +1,159 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gendao + +import ( + "fmt" + "strings" + + "github.com/gogf/gf/v2/database/gdb" +) + +// SQLiteParser implements SQLParser for SQLite DDL. +type SQLiteParser struct{} + +// ParseCreateTable parses a single SQLite CREATE TABLE statement. +func (p *SQLiteParser) ParseCreateTable(stmt string) (string, map[string]*gdb.TableField, error) { + body, _, ok := extractBodyAndTrailing(stmt) + if !ok { + return "", nil, nil + } + + parenIdx := strings.Index(stmt, "(") + header := stmt[:parenIdx] + tableName := extractTableName(header) + if tableName == "" { + return "", nil, fmt.Errorf("cannot extract table name from: %s", header) + } + + columnDefs := splitColumns(body) + fields := make(map[string]*gdb.TableField) + pkColumns := findPrimaryKeysFromConstraints(columnDefs) + + fieldIndex := 0 + for _, def := range columnDefs { + def = strings.TrimSpace(def) + if def == "" { + continue + } + firstWord := strings.ToUpper(strings.Fields(def)[0]) + if isConstraintKeyword(firstWord) { + continue + } + + field, err := p.parseColumnDef(def, fieldIndex) + if err != nil { + continue + } + if field != nil { + fields[field.Name] = field + fieldIndex++ + } + } + + for _, pkCol := range pkColumns { + if f, ok := fields[pkCol]; ok { + f.Key = "PRI" + } + } + + return tableName, fields, nil +} + +// ParseAlterTable parses SQLite ALTER TABLE statements. +// Note: SQLite only supports ADD COLUMN and RENAME COLUMN in ALTER TABLE. +func (p *SQLiteParser) ParseAlterTable(stmt string, tables map[string]map[string]*gdb.TableField) error { + return parseAlterTableCommon(stmt, tables, p.parseColumnDef) +} + +// ParseComment is a no-op for SQLite as it doesn't support COMMENT ON statements. +func (p *SQLiteParser) ParseComment(stmt string, tables map[string]map[string]*gdb.TableField) { + // SQLite does not support comments on columns. +} + +// parseColumnDef parses a single SQLite column definition string into a TableField. +// SQLite has flexible typing (type affinity), so columns may have no explicit type, +// in which case "text" is used as the default type. +func (p *SQLiteParser) parseColumnDef(def string, index int) (*gdb.TableField, error) { + tokens := mysqlTokenize(def) + if len(tokens) < 1 { + return nil, fmt.Errorf("invalid column definition: %s", def) + } + + field := &gdb.TableField{ + Index: index, + Name: unquoteIdentifier(tokens[0]), + Null: true, + } + + if len(tokens) < 2 { + field.Type = "text" + return field, nil + } + + field.Type = tokens[1] + + rest := "" + if len(tokens) > 2 { + rest = strings.Join(tokens[2:], " ") + } + + if !strings.Contains(field.Type, "(") && strings.HasPrefix(strings.TrimSpace(rest), "(") { + end := strings.Index(rest, ")") + if end >= 0 { + field.Type += rest[:end+1] + rest = strings.TrimSpace(rest[end+1:]) + } + } + + p.parseColumnAttributes(field, rest) + + return field, nil +} + +// parseColumnAttributes parses SQLite column constraint keywords including +// NOT NULL, NULL, PRIMARY KEY (with optional AUTOINCREMENT), UNIQUE, and DEFAULT. +func (p *SQLiteParser) parseColumnAttributes(field *gdb.TableField, attrs string) { + words := strings.Fields(attrs) + upperWords := strings.Fields(strings.ToUpper(attrs)) + + for i := 0; i < len(upperWords); i++ { + switch upperWords[i] { + case "NOT": + if i+1 < len(upperWords) && upperWords[i+1] == "NULL" { + field.Null = false + i++ + } + case "NULL": + field.Null = true + case "PRIMARY": + if i+1 < len(upperWords) && upperWords[i+1] == "KEY" { + field.Key = "PRI" + field.Null = false + i++ + if i+1 < len(upperWords) && upperWords[i+1] == "AUTOINCREMENT" { + field.Extra = "auto_increment" + i++ + } + } + case "AUTOINCREMENT": + field.Extra = "auto_increment" + case "UNIQUE": + if field.Key == "" { + field.Key = "UNI" + } + case "DEFAULT": + if i+1 < len(words) { + defaultVal, _ := extractDefaultValue("DEFAULT " + strings.Join(words[i+1:], " ")) + field.Default = defaultVal + if defaultVal != nil { + i++ + } + } + } + } +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_sql_parser_sqlite_test.go b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_sqlite_test.go new file mode 100644 index 000000000..eede7ce2c --- /dev/null +++ b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_sqlite_test.go @@ -0,0 +1,112 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gendao + +import ( + "testing" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/test/gtest" +) + +func Test_SQLite_CreateTable_Basic(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &SQLiteParser{} + sql := ` +CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + email TEXT, + age INTEGER DEFAULT 0, + score REAL DEFAULT 0.0, + is_active BOOLEAN NOT NULL DEFAULT 1 +); +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 6) + + t.Assert(fields["id"].Key, "PRI") + t.Assert(fields["id"].Extra, "auto_increment") + t.Assert(fields["id"].Null, false) + + t.Assert(fields["name"].Null, false) + t.Assert(fields["email"].Null, true) + t.Assert(fields["age"].Default, "0") + }) +} + +func Test_SQLite_AlterTable_AddColumn(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &SQLiteParser{} + sql := ` +CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL +); +ALTER TABLE users ADD COLUMN email TEXT; +ALTER TABLE users ADD COLUMN phone TEXT DEFAULT ''; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 4) + t.Assert(fields["email"].Name, "email") + t.Assert(fields["phone"].Name, "phone") + }) +} + +func Test_SQLite_AlterTable_DropColumn(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &SQLiteParser{} + sql := ` +CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + old_col TEXT, + email TEXT +); +ALTER TABLE users DROP COLUMN old_col; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + t.Assert(len(fields), 3) + _, ok := fields["old_col"] + t.Assert(ok, false) + t.Assert(fields["name"].Name, "name") + t.Assert(fields["email"].Name, "email") + }) +} + +func Test_SQLite_RenameColumn(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &SQLiteParser{} + sql := ` +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + old_name TEXT NOT NULL +); +ALTER TABLE users RENAME COLUMN old_name TO new_name; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + + fields := tables["users"] + _, ok := fields["old_name"] + t.Assert(ok, false) + t.Assert(fields["new_name"].Name, "new_name") + }) +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_sql_parser_test.go b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_test.go new file mode 100644 index 000000000..254df3f5c --- /dev/null +++ b/cmd/gf/internal/cmd/gendao/gendao_sql_parser_test.go @@ -0,0 +1,302 @@ +// Copyright GoFrame gf Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gendao + +import ( + "testing" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/test/gtest" +) + +// =========================== +// Common parser utilities tests +// =========================== + +func Test_splitSQLStatements(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + stmts := splitSQLStatements("CREATE TABLE t1 (id INT); ALTER TABLE t1 ADD COLUMN name VARCHAR(100);") + t.Assert(len(stmts), 2) + t.AssertIN("CREATE TABLE t1 (id INT)", stmts) + }) +} + +func Test_splitSQLStatements_WithComments(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + sql := ` +-- This is a comment +CREATE TABLE t1 (id INT); +/* Block comment */ +ALTER TABLE t1 ADD COLUMN name VARCHAR(100); +` + stmts := splitSQLStatements(sql) + t.Assert(len(stmts), 2) + }) +} + +func Test_splitSQLStatements_WithQuotedSemicolon(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + sql := `CREATE TABLE t1 (id INT, name VARCHAR(100) DEFAULT 'a;b');` + stmts := splitSQLStatements(sql) + t.Assert(len(stmts), 1) + }) +} + +func Test_classifyStatement(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + t.Assert(classifyStatement("CREATE TABLE users (id INT)"), SQLStatementCreateTable) + t.Assert(classifyStatement("CREATE TEMPORARY TABLE tmp (id INT)"), SQLStatementCreateTable) + t.Assert(classifyStatement("ALTER TABLE users ADD COLUMN email VARCHAR(100)"), SQLStatementAlterTable) + t.Assert(classifyStatement("ALTER TABLE users RENAME TO customers"), SQLStatementRenameTable) + t.Assert(classifyStatement("DROP TABLE IF EXISTS users"), SQLStatementDropTable) + t.Assert(classifyStatement("RENAME TABLE old_name TO new_name"), SQLStatementRenameTable) + t.Assert(classifyStatement("COMMENT ON COLUMN users.name IS 'User name'"), SQLStatementComment) + t.Assert(classifyStatement("SELECT * FROM users"), SQLStatementUnknown) + t.Assert(classifyStatement("INSERT INTO users VALUES (1)"), SQLStatementUnknown) + }) +} + +func Test_unquoteIdentifier(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + t.Assert(unquoteIdentifier("`users`"), "users") + t.Assert(unquoteIdentifier(`"users"`), "users") + t.Assert(unquoteIdentifier("[users]"), "users") + t.Assert(unquoteIdentifier("users"), "users") + }) +} + +func Test_extractTableName(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + t.Assert(extractTableName("CREATE TABLE users"), "users") + t.Assert(extractTableName("CREATE TABLE IF NOT EXISTS users"), "users") + t.Assert(extractTableName("CREATE TABLE `users`"), "users") + t.Assert(extractTableName("CREATE TABLE mydb.users"), "users") + t.Assert(extractTableName("CREATE TEMPORARY TABLE temp_users"), "temp_users") + }) +} + +func Test_applyDropTable(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + tables := map[string]map[string]*gdb.TableField{ + "users": {}, + "logs": {}, + } + applyDropTable("DROP TABLE IF EXISTS users", tables) + t.Assert(len(tables), 1) + _, ok := tables["users"] + t.Assert(ok, false) + }) +} + +func Test_applyRenameTable_MySQL(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + tables := map[string]map[string]*gdb.TableField{ + "old_name": {"id": {Index: 0, Name: "id", Type: "int"}}, + } + applyRenameTable("RENAME TABLE old_name TO new_name", tables) + t.Assert(len(tables), 1) + _, ok := tables["new_name"] + t.Assert(ok, true) + }) +} + +func Test_applyRenameTable_PgSQL(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + tables := map[string]map[string]*gdb.TableField{ + "old_name": {"id": {Index: 0, Name: "id", Type: "int"}}, + } + applyRenameTable("ALTER TABLE old_name RENAME TO new_name", tables) + t.Assert(len(tables), 1) + _, ok := tables["new_name"] + t.Assert(ok, true) + }) +} + +// =========================== +// Abnormal/edge-case parsing tests +// =========================== + +func Test_processSQL_OnlyDMLStatements(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := ` +INSERT INTO users (id, name) VALUES (1, 'Alice'); +INSERT INTO users (id, name) VALUES (2, 'Bob'); +DELETE FROM users WHERE id = 1; +UPDATE users SET name = 'Charlie' WHERE id = 2; +SELECT * FROM users; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + t.Assert(len(tables), 0) + }) +} + +func Test_processSQL_EmptySQL(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + tables := make(map[string]map[string]*gdb.TableField) + + // Empty string + err := processSQL(parser, "", tables) + t.AssertNil(err) + t.Assert(len(tables), 0) + + // Only whitespace and newlines + err = processSQL(parser, " \n\n \t ", tables) + t.AssertNil(err) + t.Assert(len(tables), 0) + }) +} + +func Test_processSQL_OnlyComments(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := ` +-- This is a line comment +/* This is a block comment */ +-- Another comment +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + t.Assert(len(tables), 0) + }) +} + +func Test_processSQL_AlterNonExistentTable(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := ` +ALTER TABLE non_existent ADD COLUMN email VARCHAR(200); +ALTER TABLE non_existent DROP COLUMN name; +ALTER TABLE non_existent MODIFY COLUMN name VARCHAR(200); +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + t.Assert(len(tables), 0) + }) +} + +func Test_processSQL_DropNonExistentTable(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := `DROP TABLE IF EXISTS non_existent;` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + t.Assert(len(tables), 0) + }) +} + +func Test_processSQL_MixedDDLAndDML(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := ` +INSERT INTO logs (msg) VALUES ('starting migration'); +CREATE TABLE users ( + id INT NOT NULL AUTO_INCREMENT, + name VARCHAR(100) NOT NULL, + PRIMARY KEY (id) +); +INSERT INTO users (name) VALUES ('Alice'); +ALTER TABLE users ADD COLUMN email VARCHAR(200); +UPDATE users SET email = 'alice@example.com' WHERE id = 1; +DELETE FROM logs WHERE msg = 'starting migration'; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + // Only DDL statements should be processed; DML should be skipped. + t.Assert(len(tables), 1) + fields := tables["users"] + t.Assert(len(fields), 3) + t.Assert(fields["id"].Key, "PRI") + t.Assert(fields["email"].Name, "email") + }) +} + +func Test_processSQL_CommentOnNonExistentTable(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &PgSQLParser{} + sql := `COMMENT ON COLUMN non_existent.col1 IS 'some comment';` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + t.Assert(len(tables), 0) + }) +} + +func Test_processSQL_RenameNonExistentTable(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := `RENAME TABLE non_existent TO new_name;` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + t.Assert(len(tables), 0) + }) +} + +func Test_processSQL_DropColumnFromNonExistentTable(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + parser := &MySQLParser{} + sql := ` +CREATE TABLE users (id INT, name VARCHAR(100), PRIMARY KEY (id)); +ALTER TABLE orders DROP COLUMN status; +` + tables := make(map[string]map[string]*gdb.TableField) + err := processSQL(parser, sql, tables) + t.AssertNil(err) + // users table should still exist, orders ALTER should be silently ignored. + t.Assert(len(tables), 1) + t.Assert(len(tables["users"]), 2) + }) +} + +// =========================== +// CheckLocalTypeForFieldType Tests +// =========================== + +func Test_CheckLocalTypeForFieldType(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + tests := []struct { + fieldType string + expected string + }{ + {"int(10)", "int"}, + {"int(10) unsigned", "uint"}, + {"bigint(20)", "int64"}, + {"bigint(20) unsigned", "uint64"}, + {"tinyint(1)", "int"}, + {"varchar(100)", "string"}, + {"text", "string"}, + {"datetime", "datetime"}, + {"timestamp", "datetime"}, + {"timestamptz", "datetime"}, + {"date", "date"}, + {"time", "time"}, + {"json", "json"}, + {"jsonb", "jsonb"}, + {"float", "float64"}, + {"double", "float64"}, + {"decimal(10,2)", "string"}, + {"bool", "bool"}, + {"boolean", "bool"}, + {"blob", "[]byte"}, + {"binary(16)", "[]byte"}, + {"bit(1)", "bool"}, + } + for _, tt := range tests { + localType, err := gdb.CheckLocalTypeForFieldType(tt.fieldType) + t.AssertNil(err) + t.Assert(string(localType), tt.expected) + } + }) +} diff --git a/cmd/gf/internal/cmd/gendao/gendao_structure.go b/cmd/gf/internal/cmd/gendao/gendao_structure.go index 7601bc616..d5a51aae4 100644 --- a/cmd/gf/internal/cmd/gendao/gendao_structure.go +++ b/cmd/gf/internal/cmd/gendao/gendao_structure.go @@ -20,14 +20,20 @@ import ( "github.com/gogf/gf/v2/text/gstr" ) +// generateStructDefinitionInput holds parameters for generating a Go struct definition +// from database table fields. type generateStructDefinitionInput struct { CGenDaoInternalInput - TableName string // Table name. - StructName string // Struct name. - FieldMap map[string]*gdb.TableField // Table field map. - IsDo bool // Is generating DTO struct. + TableName string // Original database table name. + StructName string // Go struct name (CamelCase of table name). + FieldMap map[string]*gdb.TableField // Map of column name to field metadata. + IsDo bool // Whether generating a DO struct (uses g.Meta orm tag). } +// generateStructDefinition generates a complete Go struct definition string from table fields. +// It returns the struct source code and a list of additional import paths needed +// by custom type mappings. The fields are rendered in a table-aligned format +// using tablewriter for consistent code formatting. func generateStructDefinition(ctx context.Context, in generateStructDefinitionInput) (string, []string) { var appendImports []string buffer := bytes.NewBuffer(nil) @@ -59,6 +65,10 @@ func generateStructDefinition(ctx context.Context, in generateStructDefinitionIn return buffer.String(), appendImports } +// getTypeMappingInfo looks up a database field type in the type mapping configuration. +// It handles exact matches first, then tries to extract the base type name from +// parameterized types like "varchar(255)" or "numeric(10,2) unsigned". +// Returns the mapped Go type name and its import path (if any). func getTypeMappingInfo( ctx context.Context, fieldType string, inTypeMapping map[DBFieldTypeName]CustomAttributeType, ) (typeNameStr, importStr string) { @@ -105,9 +115,17 @@ func generateStructFieldDefinition( } if localTypeNameStr == "" { - localTypeName, err = in.DB.CheckLocalTypeForField(ctx, field.Type, nil) - if err != nil { - panic(err) + if in.DB != nil { + localTypeName, err = in.DB.CheckLocalTypeForField(ctx, field.Type, nil) + if err != nil { + panic(err) + } + } else { + // SQL file mode: use standalone type checking without database connection. + localTypeName, err = gdb.CheckLocalTypeForFieldType(field.Type) + if err != nil { + panic(err) + } } localTypeNameStr = string(localTypeName) switch localTypeName { @@ -181,11 +199,12 @@ func generateStructFieldDefinition( return attrLines, appendImport } +// FieldNameCase defines the naming convention for converting field names to Go identifiers. type FieldNameCase string const ( - FieldNameCaseCamel FieldNameCase = "CaseCamel" - FieldNameCaseCamelLower FieldNameCase = "CaseCamelLower" + FieldNameCaseCamel FieldNameCase = "CaseCamel" // PascalCase: "user_name" -> "UserName" + FieldNameCaseCamelLower FieldNameCase = "CaseCamelLower" // camelCase: "user_name" -> "userName" ) // formatFieldName formats and returns a new field name that is used for golang codes generating. diff --git a/cmd/gf/internal/cmd/gendao/gendao_table.go b/cmd/gf/internal/cmd/gendao/gendao_table.go index 4af71fc1c..f444cf7b5 100644 --- a/cmd/gf/internal/cmd/gendao/gendao_table.go +++ b/cmd/gf/internal/cmd/gendao/gendao_table.go @@ -62,7 +62,7 @@ type generateTableSingleInput struct { // generateTableSingle generates dao files for a single table. func generateTableSingle(ctx context.Context, in generateTableSingleInput) { // Generating table data preparing. - fieldMap, err := in.DB.TableFields(ctx, in.TableName) + fieldMap, err := getTableFields(ctx, in.CGenDaoInternalInput, in.TableName) if err != nil { mlog.Fatalf(`fetching tables fields failed for table "%s": %+v`, in.TableName, err) } diff --git a/cmd/gf/internal/cmd/gendao/gendao_tag.go b/cmd/gf/internal/cmd/gendao/gendao_tag.go index 25bc84258..bdd01567b 100644 --- a/cmd/gf/internal/cmd/gendao/gendao_tag.go +++ b/cmd/gf/internal/cmd/gendao/gendao_tag.go @@ -74,6 +74,8 @@ CONFIGURATION SUPPORT CGenDaoBriefTypeMapping = `custom local type mapping for generated struct attributes relevant to fields of table` CGenDaoBriefFieldMapping = `custom local type mapping for generated struct attributes relevant to specific fields of table` CGenDaoBriefShardingPattern = `sharding pattern for table name, e.g. "users_?" will be replace tables "users_001,users_002,..." to "users" dao` + CGenDaoBriefSqlDir = `directory path of SQL DDL files for generating dao/do/entity without database connection` + CGenDaoBriefSqlType = `SQL dialect type when using sqlDir, options: mysql|pgsql|mssql|oracle|sqlite, default is "mysql"` CGenDaoBriefGroup = ` specifying the configuration group name of database for generated ORM instance, it's not necessary and the default value is "default" @@ -95,21 +97,23 @@ generated json tag case for model struct, cases are as follows: CGenDaoBriefTplDaoDoPathPath = `template file path for dao do file` CGenDaoBriefTplDaoEntityPath = `template file path for dao entity file` - tplVarTableName = `TplTableName` - tplVarTableNameCamelCase = `TplTableNameCamelCase` - tplVarTableNameCamelLowerCase = `TplTableNameCamelLowerCase` - tplVarTableSharding = `TplTableSharding` - tplVarTableShardingPrefix = `TplTableShardingPrefix` - tplVarTableFields = `TplTableFields` - tplVarPackageImports = `TplPackageImports` - tplVarImportPrefix = `TplImportPrefix` - tplVarStructDefine = `TplStructDefine` - tplVarColumnDefine = `TplColumnDefine` - tplVarColumnNames = `TplColumnNames` - tplVarGroupName = `TplGroupName` - tplVarDatetimeStr = `TplDatetimeStr` - tplVarCreatedAtDatetimeStr = `TplCreatedAtDatetimeStr` - tplVarPackageName = `TplPackageName` + // Template variable names used by gview for rendering Go file templates. + // These are passed to tplView.Assigns() and referenced in template files. + tplVarTableName = `TplTableName` // Original database table name. + tplVarTableNameCamelCase = `TplTableNameCamelCase` // PascalCase table name (e.g., "UserDetail"). + tplVarTableNameCamelLowerCase = `TplTableNameCamelLowerCase` // camelCase table name (e.g., "userDetail"). + tplVarTableSharding = `TplTableSharding` // Boolean: whether this is a sharding table. + tplVarTableShardingPrefix = `TplTableShardingPrefix` // Sharding table name prefix (e.g., "user_"). + tplVarTableFields = `TplTableFields` // Generated table field definitions. + tplVarPackageImports = `TplPackageImports` // Generated import block string. + tplVarImportPrefix = `TplImportPrefix` // Go import path prefix for internal dao package. + tplVarStructDefine = `TplStructDefine` // Generated struct definition string. + tplVarColumnDefine = `TplColumnDefine` // Column struct field definitions for dao internal. + tplVarColumnNames = `TplColumnNames` // Column name-to-string assignments for dao internal. + tplVarGroupName = `TplGroupName` // Database configuration group name. + tplVarDatetimeStr = `TplDatetimeStr` // Current datetime string for file headers. + tplVarCreatedAtDatetimeStr = `TplCreatedAtDatetimeStr` // "Created at " string (empty if WithTime is false). + tplVarPackageName = `TplPackageName` // Go package name for the generated file. ) func init() { @@ -145,6 +149,8 @@ func init() { `CGenDaoBriefTypeMapping`: CGenDaoBriefTypeMapping, `CGenDaoBriefFieldMapping`: CGenDaoBriefFieldMapping, `CGenDaoBriefShardingPattern`: CGenDaoBriefShardingPattern, + `CGenDaoBriefSqlDir`: CGenDaoBriefSqlDir, + `CGenDaoBriefSqlType`: CGenDaoBriefSqlType, `CGenDaoBriefGroup`: CGenDaoBriefGroup, `CGenDaoBriefJsonCase`: CGenDaoBriefJsonCase, `CGenDaoBriefTplDaoIndexPath`: CGenDaoBriefTplDaoIndexPath, diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index cc99d8e5a..6ee5a3838 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -864,6 +864,24 @@ const ( fieldTypeTimestampz = "timestamptz" fieldTypeJson = "json" fieldTypeJsonb = "jsonb" + + // PostgreSQL specific types. + fieldTypeInt2 = "int2" + fieldTypeInt4 = "int4" + fieldTypeInteger = "integer" + fieldTypeInt8 = "int8" + fieldTypeFloat4 = "float4" + fieldTypeFloat8 = "float8" + fieldTypeDoublePrecision = "double precision" + fieldTypeBoolean = "boolean" + + // Oracle specific types. + fieldTypeNumber = "number" + + // MSSQL specific types. + fieldTypeDatetime2 = "datetime2" + fieldTypeDatetimeOffset = "datetimeoffset" + fieldTypeSmalldatetime = "smalldatetime" ) var ( diff --git a/database/gdb/gdb_core_structure.go b/database/gdb/gdb_core_structure.go index 65176aeeb..4ca2fb41d 100644 --- a/database/gdb/gdb_core_structure.go +++ b/database/gdb/gdb_core_structure.go @@ -226,6 +226,13 @@ Default: // GetFormattedDBTypeNameForField retrieves and returns the formatted database type name // eg. `int(10) unsigned` -> `int`, `varchar(100)` -> `varchar`, etc. func (c *Core) GetFormattedDBTypeNameForField(fieldType string) (typeName, typePattern string) { + return FormatDBTypeName(fieldType) +} + +// FormatDBTypeName retrieves and returns the formatted database type name and pattern +// from raw field type string without requiring a database connection. +// eg. `int(10) unsigned` -> (`int`, `10`), `varchar(100)` -> (`varchar`, `100`). +func FormatDBTypeName(fieldType string) (typeName, typePattern string) { match, _ := gregex.MatchString(`(.+?)\((.+)\)`, fieldType) if len(match) == 3 { typeName = gstr.Trim(match[1]) @@ -246,11 +253,17 @@ func (c *Core) GetFormattedDBTypeNameForField(fieldType string) (typeName, typeP // The `fieldType` is retrieved from ColumnTypes of db driver, example: // UNSIGNED INT func (c *Core) CheckLocalTypeForField(ctx context.Context, fieldType string, _ any) (LocalType, error) { + return CheckLocalTypeForFieldType(fieldType) +} + +// CheckLocalTypeForFieldType checks and returns corresponding local type for given db field type string +// without requiring a database connection. +func CheckLocalTypeForFieldType(fieldType string) (LocalType, error) { var ( typeName string typePattern string ) - typeName, typePattern = c.GetFormattedDBTypeNameForField(fieldType) + typeName, typePattern = FormatDBTypeName(fieldType) switch typeName { case fieldTypeBinary, @@ -268,7 +281,10 @@ func (c *Core) CheckLocalTypeForField(ctx context.Context, fieldType string, _ a fieldTypeSmallint, fieldTypeMediumInt, fieldTypeMediumint, - fieldTypeSerial: + fieldTypeSerial, + fieldTypeInt2, + fieldTypeInt4, + fieldTypeInteger: if gstr.ContainsI(fieldType, "unsigned") { return LocalTypeUint, nil } @@ -277,7 +293,8 @@ func (c *Core) CheckLocalTypeForField(ctx context.Context, fieldType string, _ a case fieldTypeBigInt, fieldTypeBigint, - fieldTypeBigserial: + fieldTypeBigserial, + fieldTypeInt8: if gstr.ContainsI(fieldType, "unsigned") { return LocalTypeUint64, nil } @@ -298,11 +315,15 @@ func (c *Core) CheckLocalTypeForField(ctx context.Context, fieldType string, _ a fieldTypeDecimal, fieldTypeMoney, fieldTypeNumeric, - fieldTypeSmallmoney: + fieldTypeSmallmoney, + fieldTypeNumber: return LocalTypeString, nil case fieldTypeFloat, - fieldTypeDouble: + fieldTypeDouble, + fieldTypeFloat4, + fieldTypeFloat8, + fieldTypeDoublePrecision: return LocalTypeFloat64, nil case @@ -317,7 +338,8 @@ func (c *Core) CheckLocalTypeForField(ctx context.Context, fieldType string, _ a return LocalTypeInt64Bytes, nil case - fieldTypeBool: + fieldTypeBool, + fieldTypeBoolean: return LocalTypeBool, nil case @@ -331,7 +353,10 @@ func (c *Core) CheckLocalTypeForField(ctx context.Context, fieldType string, _ a case fieldTypeDatetime, fieldTypeTimestamp, - fieldTypeTimestampz: + fieldTypeTimestampz, + fieldTypeDatetime2, + fieldTypeDatetimeOffset, + fieldTypeSmalldatetime: return LocalTypeDatetime, nil case @@ -345,7 +370,10 @@ func (c *Core) CheckLocalTypeForField(ctx context.Context, fieldType string, _ a default: // Auto-detect field type, using key match. switch { - case strings.Contains(typeName, "text") || strings.Contains(typeName, "char") || strings.Contains(typeName, "character"): + case strings.Contains(typeName, "text") || strings.Contains(typeName, "char") || + strings.Contains(typeName, "character") || strings.Contains(typeName, "clob") || + strings.Contains(typeName, "ntext") || strings.Contains(typeName, "xml") || + strings.Contains(typeName, "string"): return LocalTypeString, nil case strings.Contains(typeName, "float") || strings.Contains(typeName, "double") || strings.Contains(typeName, "numeric"): @@ -354,7 +382,9 @@ func (c *Core) CheckLocalTypeForField(ctx context.Context, fieldType string, _ a case strings.Contains(typeName, "bool"): return LocalTypeBool, nil - case strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob"): + case strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob") || + strings.Contains(typeName, "bytea") || strings.Contains(typeName, "image") || + strings.Contains(typeName, "raw"): return LocalTypeBytes, nil case strings.Contains(typeName, "int"):