improve field type check from db to golang (#2023)

This commit is contained in:
John Guo
2022-07-22 16:44:24 +08:00
committed by GitHub
parent b7794a8783
commit 863bea1ad1
9 changed files with 236 additions and 182 deletions

View File

@ -32,7 +32,7 @@ func generateDo(ctx context.Context, db gdb.DB, tableNames, newTableNames []stri
var (
newTableName = newTableNames[i]
doFilePath = gfile.Join(doDirPath, gstr.CaseSnake(newTableName)+".go")
structDefinition = generateStructDefinition(generateStructDefinitionInput{
structDefinition = generateStructDefinition(ctx, generateStructDefinitionInput{
CGenDaoInternalInput: in,
StructName: gstr.CaseCamel(newTableName),
FieldMap: fieldMap,

View File

@ -28,7 +28,7 @@ func generateEntity(ctx context.Context, db gdb.DB, tableNames, newTableNames []
in,
newTableName,
gstr.CaseCamel(newTableName),
generateStructDefinition(generateStructDefinitionInput{
generateStructDefinition(ctx, generateStructDefinitionInput{
CGenDaoInternalInput: in,
StructName: gstr.CaseCamel(newTableName),
FieldMap: fieldMap,

View File

@ -2,8 +2,8 @@ package gendao
import (
"bytes"
"context"
"fmt"
"strings"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
@ -19,13 +19,20 @@ type generateStructDefinitionInput struct {
IsDo bool // Is generating DTO struct.
}
func generateStructDefinition(in generateStructDefinitionInput) string {
const (
typeDate = "date"
typeDatetime = "datetime"
typeInt64Bytes = "int64-bytes"
typeUint64Bytes = "uint64-bytes"
)
func generateStructDefinition(ctx context.Context, in generateStructDefinitionInput) string {
buffer := bytes.NewBuffer(nil)
array := make([][]string, len(in.FieldMap))
names := sortFieldKeyForDao(in.FieldMap)
for index, name := range names {
field := in.FieldMap[name]
array[index] = generateStructFieldDefinition(field, in)
array[index] = generateStructFieldDefinition(ctx, field, in)
}
tw := tablewriter.NewWriter(buffer)
tw.SetBorder(false)
@ -50,92 +57,39 @@ func generateStructDefinition(in generateStructDefinitionInput) string {
}
// generateStructFieldForModel generates and returns the attribute definition for specified field.
func generateStructFieldDefinition(field *gdb.TableField, in generateStructDefinitionInput) []string {
func generateStructFieldDefinition(
ctx context.Context, field *gdb.TableField, in generateStructDefinitionInput,
) []string {
var (
err error
typeName string
jsonTag = getJsonTagFromCase(field.Name, in.JsonCase)
)
t, _ := gregex.ReplaceString(`\(.+\)`, "", field.Type)
t = gstr.Split(gstr.Trim(t), " ")[0]
t = gstr.ToLower(t)
switch t {
case "binary", "varbinary", "blob", "tinyblob", "mediumblob", "longblob":
typeName = "[]byte"
case "bit", "int", "int2", "tinyint", "small_int", "smallint", "medium_int", "mediumint", "serial":
if gstr.ContainsI(field.Type, "unsigned") {
typeName = "uint"
} else {
typeName = "int"
}
case "int4", "int8", "big_int", "bigint", "bigserial":
if gstr.ContainsI(field.Type, "unsigned") {
typeName = "uint64"
} else {
typeName = "int64"
}
// pgsql int32 slice.
case "_int2":
if gstr.ContainsI(field.Type, "unsigned") {
typeName = "[]uint"
} else {
typeName = "[]int"
}
// pgsql int64 slice.
case "_int4", "_int8":
if gstr.ContainsI(field.Type, "unsigned") {
typeName = "[]uint64"
} else {
typeName = "[]int64"
}
case "real":
typeName = "float32"
case "float", "double", "decimal", "smallmoney", "numeric":
typeName = "float64"
case "bool":
typeName = "bool"
case "datetime", "timestamp", "date", "time":
typeName, err = gdb.CheckValueForLocalType(ctx, field.Type, nil)
if err != nil {
panic(err)
}
switch typeName {
case typeDate, typeDatetime:
if in.StdTime {
typeName = "time.Time"
} else {
typeName = "*gtime.Time"
}
case typeInt64Bytes:
typeName = "int64"
case typeUint64Bytes:
typeName = "uint64"
// Special type handle.
case "json", "jsonb":
if in.GJsonSupport {
typeName = "*gjson.Json"
} else {
typeName = "string"
}
default:
// Automatically detect its data type.
switch {
case strings.Contains(t, "int"):
typeName = "int"
case strings.Contains(t, "text") || strings.Contains(t, "char"):
typeName = "string"
case strings.Contains(t, "float") || strings.Contains(t, "double"):
typeName = "float64"
case strings.Contains(t, "bool"):
typeName = "bool"
case strings.Contains(t, "binary") || strings.Contains(t, "blob"):
typeName = "[]byte"
case strings.Contains(t, "date") || strings.Contains(t, "time"):
if in.StdTime {
typeName = "time.Time"
} else {
typeName = "*gtime.Time"
}
default:
typeName = "string"
}
}
var (

View File

@ -1190,24 +1190,26 @@ func Test_DB_TableField(t *testing.T) {
"field_varchar": "abc",
"field_varbinary": "aaa",
}
res, err := db.Model(name).Data(data).Insert()
if err != nil {
gtest.Fatal(err)
}
gtest.C(t, func(t *gtest.T) {
res, err := db.Model(name).Data(data).Insert()
if err != nil {
t.Fatal(err)
}
n, err := res.RowsAffected()
if err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(n, 1)
}
n, err := res.RowsAffected()
if err != nil {
t.Fatal(err)
} else {
t.Assert(n, 1)
}
result, err := db.Model(name).Fields("*").Where("field_int = ?", 2).All()
if err != nil {
gtest.Fatal(err)
}
result, err := db.Model(name).Fields("*").Where("field_int = ?", 2).All()
if err != nil {
t.Fatal(err)
}
t.Assert(result[0], data)
})
gtest.Assert(result[0], data)
}
func Test_DB_Prefix(t *testing.T) {

View File

@ -17,8 +17,6 @@ import (
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/internal/json"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/gutil"
)
@ -130,53 +128,42 @@ func (c *Core) ConvertValueForLocal(ctx context.Context, fieldType string, field
if fieldType == "" {
return fieldValue, nil
}
typeName, _ := gregex.ReplaceString(`\(.+\)`, "", fieldType)
typeName = strings.ToLower(typeName)
typeName, err := CheckValueForLocalType(ctx, fieldType, fieldValue)
if err != nil {
return nil, err
}
switch typeName {
case
"binary",
"varbinary",
"blob",
"tinyblob",
"mediumblob",
"longblob":
case typeBytes:
if strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob") {
return fieldValue, nil
}
return gconv.Bytes(fieldValue), nil
case
"int",
"tinyint",
"small_int",
"smallint",
"medium_int",
"mediumint",
"serial":
if gstr.ContainsI(fieldType, "unsigned") {
return gconv.Uint(gconv.String(fieldValue)), nil
}
case typeInt:
return gconv.Int(gconv.String(fieldValue)), nil
case
"big_int",
"bigint",
"bigserial":
if gstr.ContainsI(fieldType, "unsigned") {
return gconv.Uint64(gconv.String(fieldValue)), nil
}
case typeUint:
return gconv.Uint(gconv.String(fieldValue)), nil
case typeInt64:
return gconv.Int64(gconv.String(fieldValue)), nil
case "real":
case typeUint64:
return gconv.Uint64(gconv.String(fieldValue)), nil
case typeInt64Bytes:
return gbinary.BeDecodeToInt64(gconv.Bytes(fieldValue)), nil
case typeUint64Bytes:
return gbinary.BeDecodeToUint64(gconv.Bytes(fieldValue)), nil
case typeFloat32:
return gconv.Float32(gconv.String(fieldValue)), nil
case
"float",
"double",
"decimal",
"money",
"numeric",
"smallmoney":
case typeFloat64:
return gconv.Float64(gconv.String(fieldValue)), nil
case "bit":
case typeBool:
s := gconv.String(fieldValue)
// mssql is true|false string.
if strings.EqualFold(s, "true") {
@ -185,12 +172,9 @@ func (c *Core) ConvertValueForLocal(ctx context.Context, fieldType string, field
if strings.EqualFold(s, "false") {
return 0, nil
}
return gbinary.BeDecodeToInt64(gconv.Bytes(fieldValue)), nil
case "bool":
return gconv.Bool(fieldValue), nil
case "date":
case typeDate:
// Date without time.
if t, ok := fieldValue.(time.Time); ok {
return gtime.NewFromTime(t).Format("Y-m-d"), nil
@ -198,10 +182,7 @@ func (c *Core) ConvertValueForLocal(ctx context.Context, fieldType string, field
t, _ := gtime.StrToTime(gconv.String(fieldValue))
return t.Format("Y-m-d"), nil
case
"datetime",
"timestamp",
"timestamptz":
case typeDatetime:
if t, ok := fieldValue.(time.Time); ok {
return gtime.NewFromTime(t), nil
}
@ -209,42 +190,7 @@ func (c *Core) ConvertValueForLocal(ctx context.Context, fieldType string, field
return t, nil
default:
// Auto-detect field type, using key match.
switch {
case strings.Contains(typeName, "text") || strings.Contains(typeName, "char") || strings.Contains(typeName, "character"):
return gconv.String(fieldValue), nil
case strings.Contains(typeName, "float") || strings.Contains(typeName, "double") || strings.Contains(typeName, "numeric"):
return gconv.Float64(gconv.String(fieldValue)), nil
case strings.Contains(typeName, "bool"):
return gconv.Bool(gconv.String(fieldValue)), nil
case strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob"):
return fieldValue, nil
case strings.Contains(typeName, "int"):
return gconv.Int(gconv.String(fieldValue)), nil
case strings.Contains(typeName, "time"):
s := gconv.String(fieldValue)
t, err := gtime.StrToTime(s)
if err != nil {
return s, nil
}
return t, nil
case strings.Contains(typeName, "date"):
s := gconv.String(fieldValue)
t, err := gtime.StrToTime(s)
if err != nil {
return s, nil
}
return t, nil
default:
return gconv.String(fieldValue), nil
}
return gconv.String(fieldValue), nil
}
}

View File

@ -124,7 +124,7 @@ func (c *Core) QuoteString(s string) string {
// if true it does nothing to the table name, or else adds the prefix to the table name.
func (c *Core) QuotePrefixTableName(table string) string {
charLeft, charRight := c.db.GetChars()
return doHandleTableName(table, c.db.GetPrefix(), charLeft, charRight)
return doQuoteTableName(table, c.db.GetPrefix(), charLeft, charRight)
}
// GetChars returns the security char for current database.

View File

@ -42,11 +42,6 @@ type iInterfaces interface {
Interfaces() []interface{}
}
// iMapStrAny is the interface support for converting struct parameter to map.
type iMapStrAny interface {
MapStrAny() map[string]interface{}
}
// iTableName is the interface for retrieving table name fro struct.
type iTableName interface {
TableName() string
@ -187,7 +182,7 @@ func DataToMapDeep(value interface{}) map[string]interface{} {
//
// Note that, this will automatically check the table prefix whether already added, if true it does
// nothing to the table name, or else adds the prefix to the table name and returns new table name with prefix.
func doHandleTableName(table, prefix, charLeft, charRight string) string {
func doQuoteTableName(table, prefix, charLeft, charRight string) string {
var (
index = 0
chars = charLeft + charRight

View File

@ -0,0 +1,157 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gdb
import (
"context"
"strings"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
)
const (
typeString = "string"
typeDate = "date"
typeDatetime = "datetime"
typeInt = "int"
typeUint = "uint"
typeInt64 = "int64"
typeUint64 = "uint64"
typeInt64Slice = "[]int64"
typeUint64Slice = "[]uint64"
typeInt64Bytes = "int64-bytes"
typeUint64Bytes = "uint64-bytes"
typeFloat32 = "float32"
typeFloat64 = "float64"
typeBytes = "[]byte"
typeBool = "bool"
)
// CheckValueForLocalType checks and returns corresponding type for given db type.
func CheckValueForLocalType(ctx context.Context, fieldType string, fieldValue interface{}) (string, error) {
var (
typeName string
typePattern string
)
match, _ := gregex.MatchString(`(.+?)\((.+)\)`, fieldType)
if len(match) == 3 {
typeName = gstr.Trim(match[1])
typePattern = gstr.Trim(match[2])
} else {
typeName = fieldType
}
typeName = strings.ToLower(typeName)
switch typeName {
case
"binary",
"varbinary",
"blob",
"tinyblob",
"mediumblob",
"longblob":
return typeBytes, nil
case
"int",
"tinyint",
"small_int",
"smallint",
"medium_int",
"mediumint",
"serial":
if typePattern == "1" {
return typeBool, nil
}
if gstr.ContainsI(fieldType, "unsigned") {
return typeUint, nil
}
return typeInt, nil
case "_int4", "_int8":
if gstr.ContainsI(fieldType, "unsigned") {
return typeUint64Slice, nil
}
return typeInt64Slice, nil
case
"big_int",
"bigint",
"bigserial":
if gstr.ContainsI(fieldType, "unsigned") {
return typeUint64, nil
}
return typeInt64, nil
case "real":
return typeFloat32, nil
case
"float",
"double",
"decimal",
"money",
"numeric",
"smallmoney":
return typeFloat64, nil
case "bit":
if typePattern == "1" {
return typeBool, nil
}
s := gconv.String(fieldValue)
// mssql is true|false string.
if strings.EqualFold(s, "true") || strings.EqualFold(s, "false") {
return typeBool, nil
}
if gstr.ContainsI(fieldType, "unsigned") {
return typeUint64Bytes, nil
}
return typeInt64Bytes, nil
case "bool":
return typeBool, nil
case "date":
return typeDate, nil
case
"datetime",
"timestamp",
"timestamptz":
return typeDatetime, nil
default:
// Auto-detect field type, using key match.
switch {
case strings.Contains(typeName, "text") || strings.Contains(typeName, "char") || strings.Contains(typeName, "character"):
return typeString, nil
case strings.Contains(typeName, "float") || strings.Contains(typeName, "double") || strings.Contains(typeName, "numeric"):
return typeFloat64, nil
case strings.Contains(typeName, "bool"):
return typeBool, nil
case strings.Contains(typeName, "binary") || strings.Contains(typeName, "blob"):
return typeBytes, nil
case strings.Contains(typeName, "int"):
return typeInt, nil
case strings.Contains(typeName, "time"):
return typeDatetime, nil
case strings.Contains(typeName, "date"):
return typeDatetime, nil
default:
return typeString, nil
}
}
}

View File

@ -72,7 +72,7 @@ func Test_Func_addTablePrefix(t *testing.T) {
"UserCenter..user as u, user_detail as ut": "`UserCenter`..`user` as u,`user_detail` as ut",
}
for k, v := range array {
t.Assert(doHandleTableName(k, prefix, "`", "`"), v)
t.Assert(doQuoteTableName(k, prefix, "`", "`"), v)
}
})
gtest.C(t, func(t *gtest.T) {
@ -91,7 +91,7 @@ func Test_Func_addTablePrefix(t *testing.T) {
"UserCenter..user as u, user_detail as ut": "`UserCenter`..`gf_user` as u,`gf_user_detail` as ut",
}
for k, v := range array {
t.Assert(doHandleTableName(k, prefix, "`", "`"), v)
t.Assert(doQuoteTableName(k, prefix, "`", "`"), v)
}
})
}