diff --git a/cmd/gf/internal/cmd/gendao/gendao_do.go b/cmd/gf/internal/cmd/gendao/gendao_do.go index c7e5152c4..200ce8d9c 100644 --- a/cmd/gf/internal/cmd/gendao/gendao_do.go +++ b/cmd/gf/internal/cmd/gendao/gendao_do.go @@ -32,7 +32,7 @@ func generateDo(ctx context.Context, db gdb.DB, tableNames, newTableNames []stri var ( newTableName = newTableNames[i] doFilePath = gfile.Join(doDirPath, gstr.CaseSnake(newTableName)+".go") - structDefinition = generateStructDefinition(generateStructDefinitionInput{ + structDefinition = generateStructDefinition(ctx, generateStructDefinitionInput{ CGenDaoInternalInput: in, StructName: gstr.CaseCamel(newTableName), FieldMap: fieldMap, diff --git a/cmd/gf/internal/cmd/gendao/gendao_entity.go b/cmd/gf/internal/cmd/gendao/gendao_entity.go index 77ce2e901..232d7afe0 100644 --- a/cmd/gf/internal/cmd/gendao/gendao_entity.go +++ b/cmd/gf/internal/cmd/gendao/gendao_entity.go @@ -28,7 +28,7 @@ func generateEntity(ctx context.Context, db gdb.DB, tableNames, newTableNames [] in, newTableName, gstr.CaseCamel(newTableName), - generateStructDefinition(generateStructDefinitionInput{ + generateStructDefinition(ctx, generateStructDefinitionInput{ CGenDaoInternalInput: in, StructName: gstr.CaseCamel(newTableName), FieldMap: fieldMap, diff --git a/cmd/gf/internal/cmd/gendao/gendao_structure.go b/cmd/gf/internal/cmd/gendao/gendao_structure.go index 99e73cf5a..f15d3efdb 100644 --- a/cmd/gf/internal/cmd/gendao/gendao_structure.go +++ b/cmd/gf/internal/cmd/gendao/gendao_structure.go @@ -2,8 +2,8 @@ package gendao import ( "bytes" + "context" "fmt" - "strings" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/frame/g" @@ -19,13 +19,20 @@ type generateStructDefinitionInput struct { IsDo bool // Is generating DTO struct. } -func generateStructDefinition(in generateStructDefinitionInput) string { +const ( + typeDate = "date" + typeDatetime = "datetime" + typeInt64Bytes = "int64-bytes" + typeUint64Bytes = "uint64-bytes" +) + +func generateStructDefinition(ctx context.Context, in generateStructDefinitionInput) string { buffer := bytes.NewBuffer(nil) array := make([][]string, len(in.FieldMap)) names := sortFieldKeyForDao(in.FieldMap) for index, name := range names { field := in.FieldMap[name] - array[index] = generateStructFieldDefinition(field, in) + array[index] = generateStructFieldDefinition(ctx, field, in) } tw := tablewriter.NewWriter(buffer) tw.SetBorder(false) @@ -50,92 +57,39 @@ func generateStructDefinition(in generateStructDefinitionInput) string { } // generateStructFieldForModel generates and returns the attribute definition for specified field. -func generateStructFieldDefinition(field *gdb.TableField, in generateStructDefinitionInput) []string { +func generateStructFieldDefinition( + ctx context.Context, field *gdb.TableField, in generateStructDefinitionInput, +) []string { var ( + err error typeName string jsonTag = getJsonTagFromCase(field.Name, in.JsonCase) ) - t, _ := gregex.ReplaceString(`\(.+\)`, "", field.Type) - t = gstr.Split(gstr.Trim(t), " ")[0] - t = gstr.ToLower(t) - - switch t { - case "binary", "varbinary", "blob", "tinyblob", "mediumblob", "longblob": - typeName = "[]byte" - - case "bit", "int", "int2", "tinyint", "small_int", "smallint", "medium_int", "mediumint", "serial": - if gstr.ContainsI(field.Type, "unsigned") { - typeName = "uint" - } else { - typeName = "int" - } - - case "int4", "int8", "big_int", "bigint", "bigserial": - if gstr.ContainsI(field.Type, "unsigned") { - typeName = "uint64" - } else { - typeName = "int64" - } - - // pgsql int32 slice. - case "_int2": - if gstr.ContainsI(field.Type, "unsigned") { - typeName = "[]uint" - } else { - typeName = "[]int" - } - - // pgsql int64 slice. - case "_int4", "_int8": - if gstr.ContainsI(field.Type, "unsigned") { - typeName = "[]uint64" - } else { - typeName = "[]int64" - } - - case "real": - typeName = "float32" - - case "float", "double", "decimal", "smallmoney", "numeric": - typeName = "float64" - - case "bool": - typeName = "bool" - - case "datetime", "timestamp", "date", "time": + typeName, err = gdb.CheckValueForLocalType(ctx, field.Type, nil) + if err != nil { + panic(err) + } + switch typeName { + case typeDate, typeDatetime: if in.StdTime { typeName = "time.Time" } else { typeName = "*gtime.Time" } + + case typeInt64Bytes: + typeName = "int64" + + case typeUint64Bytes: + typeName = "uint64" + + // Special type handle. case "json", "jsonb": if in.GJsonSupport { typeName = "*gjson.Json" } else { typeName = "string" } - default: - // Automatically detect its data type. - switch { - case strings.Contains(t, "int"): - typeName = "int" - case strings.Contains(t, "text") || strings.Contains(t, "char"): - typeName = "string" - case strings.Contains(t, "float") || strings.Contains(t, "double"): - typeName = "float64" - case strings.Contains(t, "bool"): - typeName = "bool" - case strings.Contains(t, "binary") || strings.Contains(t, "blob"): - typeName = "[]byte" - case strings.Contains(t, "date") || strings.Contains(t, "time"): - if in.StdTime { - typeName = "time.Time" - } else { - typeName = "*gtime.Time" - } - default: - typeName = "string" - } } var ( diff --git a/contrib/drivers/mysql/mysql_core_test.go b/contrib/drivers/mysql/mysql_core_test.go index 9103fde7c..6f55d4789 100644 --- a/contrib/drivers/mysql/mysql_core_test.go +++ b/contrib/drivers/mysql/mysql_core_test.go @@ -1190,24 +1190,26 @@ func Test_DB_TableField(t *testing.T) { "field_varchar": "abc", "field_varbinary": "aaa", } - res, err := db.Model(name).Data(data).Insert() - if err != nil { - gtest.Fatal(err) - } + gtest.C(t, func(t *gtest.T) { + res, err := db.Model(name).Data(data).Insert() + if err != nil { + t.Fatal(err) + } - n, err := res.RowsAffected() - if err != nil { - gtest.Fatal(err) - } else { - gtest.Assert(n, 1) - } + n, err := res.RowsAffected() + if err != nil { + t.Fatal(err) + } else { + t.Assert(n, 1) + } - result, err := db.Model(name).Fields("*").Where("field_int = ?", 2).All() - if err != nil { - gtest.Fatal(err) - } + result, err := db.Model(name).Fields("*").Where("field_int = ?", 2).All() + if err != nil { + t.Fatal(err) + } + t.Assert(result[0], data) + }) - gtest.Assert(result[0], data) } func Test_DB_Prefix(t *testing.T) { diff --git a/database/gdb/gdb_core_structure.go b/database/gdb/gdb_core_structure.go index ee82c0538..c741e4e40 100644 --- a/database/gdb/gdb_core_structure.go +++ b/database/gdb/gdb_core_structure.go @@ -17,8 +17,6 @@ import ( "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/internal/json" "github.com/gogf/gf/v2/os/gtime" - "github.com/gogf/gf/v2/text/gregex" - "github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gutil" ) @@ -130,53 +128,42 @@ func (c *Core) ConvertValueForLocal(ctx context.Context, fieldType string, field if fieldType == "" { return fieldValue, nil } - typeName, _ := gregex.ReplaceString(`\(.+\)`, "", fieldType) - typeName = strings.ToLower(typeName) + typeName, err := CheckValueForLocalType(ctx, fieldType, fieldValue) + if err != nil { + return nil, err + } switch typeName { - case - "binary", - "varbinary", - "blob", - "tinyblob", - "mediumblob", - "longblob": + case typeBytes: + if strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob") { + return fieldValue, nil + } return gconv.Bytes(fieldValue), nil - case - "int", - "tinyint", - "small_int", - "smallint", - "medium_int", - "mediumint", - "serial": - if gstr.ContainsI(fieldType, "unsigned") { - return gconv.Uint(gconv.String(fieldValue)), nil - } + case typeInt: return gconv.Int(gconv.String(fieldValue)), nil - case - "big_int", - "bigint", - "bigserial": - if gstr.ContainsI(fieldType, "unsigned") { - return gconv.Uint64(gconv.String(fieldValue)), nil - } + case typeUint: + return gconv.Uint(gconv.String(fieldValue)), nil + + case typeInt64: return gconv.Int64(gconv.String(fieldValue)), nil - case "real": + case typeUint64: + return gconv.Uint64(gconv.String(fieldValue)), nil + + case typeInt64Bytes: + return gbinary.BeDecodeToInt64(gconv.Bytes(fieldValue)), nil + + case typeUint64Bytes: + return gbinary.BeDecodeToUint64(gconv.Bytes(fieldValue)), nil + + case typeFloat32: return gconv.Float32(gconv.String(fieldValue)), nil - case - "float", - "double", - "decimal", - "money", - "numeric", - "smallmoney": + case typeFloat64: return gconv.Float64(gconv.String(fieldValue)), nil - case "bit": + case typeBool: s := gconv.String(fieldValue) // mssql is true|false string. if strings.EqualFold(s, "true") { @@ -185,12 +172,9 @@ func (c *Core) ConvertValueForLocal(ctx context.Context, fieldType string, field if strings.EqualFold(s, "false") { return 0, nil } - return gbinary.BeDecodeToInt64(gconv.Bytes(fieldValue)), nil - - case "bool": return gconv.Bool(fieldValue), nil - case "date": + case typeDate: // Date without time. if t, ok := fieldValue.(time.Time); ok { return gtime.NewFromTime(t).Format("Y-m-d"), nil @@ -198,10 +182,7 @@ func (c *Core) ConvertValueForLocal(ctx context.Context, fieldType string, field t, _ := gtime.StrToTime(gconv.String(fieldValue)) return t.Format("Y-m-d"), nil - case - "datetime", - "timestamp", - "timestamptz": + case typeDatetime: if t, ok := fieldValue.(time.Time); ok { return gtime.NewFromTime(t), nil } @@ -209,42 +190,7 @@ func (c *Core) ConvertValueForLocal(ctx context.Context, fieldType string, field return t, nil default: - // Auto-detect field type, using key match. - switch { - case strings.Contains(typeName, "text") || strings.Contains(typeName, "char") || strings.Contains(typeName, "character"): - return gconv.String(fieldValue), nil - - case strings.Contains(typeName, "float") || strings.Contains(typeName, "double") || strings.Contains(typeName, "numeric"): - return gconv.Float64(gconv.String(fieldValue)), nil - - case strings.Contains(typeName, "bool"): - return gconv.Bool(gconv.String(fieldValue)), nil - - case strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob"): - return fieldValue, nil - - case strings.Contains(typeName, "int"): - return gconv.Int(gconv.String(fieldValue)), nil - - case strings.Contains(typeName, "time"): - s := gconv.String(fieldValue) - t, err := gtime.StrToTime(s) - if err != nil { - return s, nil - } - return t, nil - - case strings.Contains(typeName, "date"): - s := gconv.String(fieldValue) - t, err := gtime.StrToTime(s) - if err != nil { - return s, nil - } - return t, nil - - default: - return gconv.String(fieldValue), nil - } + return gconv.String(fieldValue), nil } } diff --git a/database/gdb/gdb_core_utility.go b/database/gdb/gdb_core_utility.go index f8324d5dd..d995e0df5 100644 --- a/database/gdb/gdb_core_utility.go +++ b/database/gdb/gdb_core_utility.go @@ -124,7 +124,7 @@ func (c *Core) QuoteString(s string) string { // if true it does nothing to the table name, or else adds the prefix to the table name. func (c *Core) QuotePrefixTableName(table string) string { charLeft, charRight := c.db.GetChars() - return doHandleTableName(table, c.db.GetPrefix(), charLeft, charRight) + return doQuoteTableName(table, c.db.GetPrefix(), charLeft, charRight) } // GetChars returns the security char for current database. diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 0f22ed18d..454d8fa82 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -42,11 +42,6 @@ type iInterfaces interface { Interfaces() []interface{} } -// iMapStrAny is the interface support for converting struct parameter to map. -type iMapStrAny interface { - MapStrAny() map[string]interface{} -} - // iTableName is the interface for retrieving table name fro struct. type iTableName interface { TableName() string @@ -187,7 +182,7 @@ func DataToMapDeep(value interface{}) map[string]interface{} { // // Note that, this will automatically check the table prefix whether already added, if true it does // nothing to the table name, or else adds the prefix to the table name and returns new table name with prefix. -func doHandleTableName(table, prefix, charLeft, charRight string) string { +func doQuoteTableName(table, prefix, charLeft, charRight string) string { var ( index = 0 chars = charLeft + charRight diff --git a/database/gdb/gdb_func_structure.go b/database/gdb/gdb_func_structure.go new file mode 100644 index 000000000..f422a459c --- /dev/null +++ b/database/gdb/gdb_func_structure.go @@ -0,0 +1,157 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gdb + +import ( + "context" + "strings" + + "github.com/gogf/gf/v2/text/gregex" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" +) + +const ( + typeString = "string" + typeDate = "date" + typeDatetime = "datetime" + typeInt = "int" + typeUint = "uint" + typeInt64 = "int64" + typeUint64 = "uint64" + typeInt64Slice = "[]int64" + typeUint64Slice = "[]uint64" + typeInt64Bytes = "int64-bytes" + typeUint64Bytes = "uint64-bytes" + typeFloat32 = "float32" + typeFloat64 = "float64" + typeBytes = "[]byte" + typeBool = "bool" +) + +// CheckValueForLocalType checks and returns corresponding type for given db type. +func CheckValueForLocalType(ctx context.Context, fieldType string, fieldValue interface{}) (string, error) { + var ( + typeName string + typePattern string + ) + match, _ := gregex.MatchString(`(.+?)\((.+)\)`, fieldType) + if len(match) == 3 { + typeName = gstr.Trim(match[1]) + typePattern = gstr.Trim(match[2]) + } else { + typeName = fieldType + } + typeName = strings.ToLower(typeName) + switch typeName { + case + "binary", + "varbinary", + "blob", + "tinyblob", + "mediumblob", + "longblob": + return typeBytes, nil + + case + "int", + "tinyint", + "small_int", + "smallint", + "medium_int", + "mediumint", + "serial": + if typePattern == "1" { + return typeBool, nil + } + if gstr.ContainsI(fieldType, "unsigned") { + return typeUint, nil + } + return typeInt, nil + + case "_int4", "_int8": + if gstr.ContainsI(fieldType, "unsigned") { + return typeUint64Slice, nil + } + return typeInt64Slice, nil + + case + "big_int", + "bigint", + "bigserial": + if gstr.ContainsI(fieldType, "unsigned") { + return typeUint64, nil + } + return typeInt64, nil + + case "real": + return typeFloat32, nil + + case + "float", + "double", + "decimal", + "money", + "numeric", + "smallmoney": + return typeFloat64, nil + + case "bit": + if typePattern == "1" { + return typeBool, nil + } + s := gconv.String(fieldValue) + // mssql is true|false string. + if strings.EqualFold(s, "true") || strings.EqualFold(s, "false") { + return typeBool, nil + } + if gstr.ContainsI(fieldType, "unsigned") { + return typeUint64Bytes, nil + } + return typeInt64Bytes, nil + + case "bool": + return typeBool, nil + + case "date": + return typeDate, nil + + case + "datetime", + "timestamp", + "timestamptz": + return typeDatetime, nil + + default: + // Auto-detect field type, using key match. + switch { + case strings.Contains(typeName, "text") || strings.Contains(typeName, "char") || strings.Contains(typeName, "character"): + return typeString, nil + + case strings.Contains(typeName, "float") || strings.Contains(typeName, "double") || strings.Contains(typeName, "numeric"): + return typeFloat64, nil + + case strings.Contains(typeName, "bool"): + return typeBool, nil + + case strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob"): + return typeBytes, nil + + case strings.Contains(typeName, "int"): + return typeInt, nil + + case strings.Contains(typeName, "time"): + return typeDatetime, nil + + case strings.Contains(typeName, "date"): + return typeDatetime, nil + + default: + return typeString, nil + } + } +} diff --git a/database/gdb/gdb_z_mysql_internal_test.go b/database/gdb/gdb_z_mysql_internal_test.go index 32a0bce9a..840dd4224 100644 --- a/database/gdb/gdb_z_mysql_internal_test.go +++ b/database/gdb/gdb_z_mysql_internal_test.go @@ -72,7 +72,7 @@ func Test_Func_addTablePrefix(t *testing.T) { "UserCenter..user as u, user_detail as ut": "`UserCenter`..`user` as u,`user_detail` as ut", } for k, v := range array { - t.Assert(doHandleTableName(k, prefix, "`", "`"), v) + t.Assert(doQuoteTableName(k, prefix, "`", "`"), v) } }) gtest.C(t, func(t *gtest.T) { @@ -91,7 +91,7 @@ func Test_Func_addTablePrefix(t *testing.T) { "UserCenter..user as u, user_detail as ut": "`UserCenter`..`gf_user` as u,`gf_user_detail` as ut", } for k, v := range array { - t.Assert(doHandleTableName(k, prefix, "`", "`"), v) + t.Assert(doQuoteTableName(k, prefix, "`", "`"), v) } }) }