diff --git a/.example/other/test.go b/.example/other/test.go index 1ffd72098..3c5d6fd37 100644 --- a/.example/other/test.go +++ b/.example/other/test.go @@ -1,23 +1,12 @@ package main import ( - "encoding/json" + "encoding/hex" "fmt" - "gopkg.in/yaml.v3" ) func main() { - data := []byte(` -m: - k: v - `) - var result map[string]interface{} - if err := yaml.Unmarshal(data, &result); err != nil { - panic(err) - } - b, err := json.Marshal(result) - if err != nil { - panic(err) - } + b := []byte{3, 0, 0} fmt.Println(string(b)) + fmt.Println(hex.EncodeToString(b)) } diff --git a/.travis.yml b/.travis.yml index b9e38f278..5ef9c9afa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,8 @@ language: go go: + - "1.11.x" + - "1.12.x" - "1.13.x" - "1.14.x" diff --git a/DONATOR.MD b/DONATOR.MD index c7a954564..5bc75eea9 100644 --- a/DONATOR.MD +++ b/DONATOR.MD @@ -71,6 +71,8 @@ We currently accept donation by Alipay/WechatPay, please note your github/gitee |[侯哥](http://www.macnie.com)|wechat|¥10.00| |如果🍋|alipay|¥100.00| 错过的奶茶^_^ |蔡蔡|wechat|¥666.00| gf真强大,让项目省心 +|jack|wechat|¥100.00| +|sbilly|wechat|¥100.00| 祝好! diff --git a/README.MD b/README.MD index 9d1336400..2c46baf2c 100644 --- a/README.MD +++ b/README.MD @@ -9,7 +9,7 @@ English | [简体中文](README_ZH.MD) -`GF(GoFrame)` is a modular, full-featured and production-ready application development framework +`GF(GoFrame)` is a modular, high-performance and production-ready application development framework of golang. Providing a series of core components and dozens of practical modules, such as: cache, logging, containers, timer, resource, validator, database orm, etc. Supporting web server integrated with router, cookie, session, middleware, logger, configure, @@ -27,9 +27,67 @@ require github.com/gogf/gf latest # Limitation ``` -golang version >= 1.13 +golang version >= 1.11 ``` +# Architecture +
+ +
+ +# Performance + +Here's the most popular Golang frameworks and libraries performance testing result in `WEB Server`. Performance testing cases source codes are hosted at: https://github.com/gogf/gf-performance + +## Environment + + OS : Ubuntu 18.04 amd64 + CPU : AMD A8-6600K x 4 + MEM : 32GB + GO : v1.13.4 + +## Testing Tool + +`ab`: Apache HTTP server benchmarking tool. + +Command: +``` +ab -t 10 -c 100 http://127.0.0.1:3000/hello +ab -t 10 -c 100 http://127.0.0.1:3000/query?id=10000 +ab -t 10 -c 100 http://127.0.0.1:3000/json +``` +The concurrency starts from `100` to `10000`. + +> Run `5` times for each case of each project and pick up the best testing result. + +## 1. Hello World + + + + + + + + + + + +
ThroughputsMean LatencyP99 Latency
+ +## 2. Json Response + + + + + + + + + + + +
ThroughputsMean LatencyP99 Latency
+ # Documentation * 中文官网: https://goframe.org @@ -43,12 +101,6 @@ golang version >= 1.13 > It's recommended learning `GoFrame` through its awesome source codes and API reference. -# Architecture -
- -
- - # License `GF` is licensed under the [MIT License](LICENSE), 100% free and open-source, forever. diff --git a/README_ZH.MD b/README_ZH.MD index 2ca02be7d..a9ae9c171 100644 --- a/README_ZH.MD +++ b/README_ZH.MD @@ -41,7 +41,7 @@ require github.com/gogf/gf latest # 限制 ```shell -golang版本 >= 1.13 +golang版本 >= 1.11 ``` # 架构 @@ -49,6 +49,59 @@ golang版本 >= 1.13 +# 性能 + +以下是目前最流行的`WEB Server` Golang框架/类库性能测试结果。 +性能测试用例源代码仓库: https://github.com/gogf/gf-performance + +## 环境: + + OS : Ubuntu 18.04 amd64 + CPU : AMD A8-6600K x 4 + MEM : 32GB + GO : v1.13.4 + +## 工具 + +`ab`: Apache HTTP server benchmarking tool. + +测试命令: +``` +ab -t 10 -c 100 http://127.0.0.1:3000/hello +ab -t 10 -c 100 http://127.0.0.1:3000/query?id=10000 +ab -t 10 -c 100 http://127.0.0.1:3000/json +``` +并发客户端数量从 `100` 递增到 `10000`。 + +> 每个项目的每个用例均运行`5`次,取最优的结果展示。 + +## 1. Hello World + + + + + + + + + + + +
ThroughputsMean LatencyP99 Latency
+ +## 2. Json Response + + + + + + + + + + + +
ThroughputsMean LatencyP99 Latency
# 文档 diff --git a/container/garray/garray_func.go b/container/garray/garray_func.go index 7ae0822b3..572cba444 100644 --- a/container/garray/garray_func.go +++ b/container/garray/garray_func.go @@ -8,6 +8,12 @@ package garray import "strings" +// apiInterfaces is used for type assert api for Interfaces. +type apiInterfaces interface { + Interfaces() []interface{} +} + +// defaultComparatorInt for int comparison. func defaultComparatorInt(a, b int) int { if a < b { return -1 @@ -18,6 +24,7 @@ func defaultComparatorInt(a, b int) int { return 0 } +// defaultComparatorStr for string comparison. func defaultComparatorStr(a, b string) int { return strings.Compare(a, b) } diff --git a/container/garray/garray_normal_any.go b/container/garray/garray_normal_any.go index 12d952585..6ffebe89e 100644 --- a/container/garray/garray_normal_any.go +++ b/container/garray/garray_normal_any.go @@ -23,7 +23,7 @@ import ( // Array is a golang array with rich features. type Array struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex array []interface{} } @@ -44,7 +44,7 @@ func NewArray(safe ...bool) *Array { // which is false in default. func NewArraySize(size int, cap int, safe ...bool) *Array { return &Array{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), array: make([]interface{}, size, cap), } } @@ -79,7 +79,7 @@ func NewFromCopy(array []interface{}, safe ...bool) *Array { // which is false in default. func NewArrayFrom(array []interface{}, safe ...bool) *Array { return &Array{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), array: array, } } @@ -91,7 +91,7 @@ func NewArrayFromCopy(array []interface{}, safe ...bool) *Array { newArray := make([]interface{}, len(array)) copy(newArray, array) return &Array{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), array: newArray, } } @@ -533,23 +533,7 @@ func (a *Array) RLockFunc(f func(array []interface{})) *Array { // The difference between Merge and Append is Append supports only specified slice type, // but Merge supports more parameter types. func (a *Array) Merge(array interface{}) *Array { - switch v := array.(type) { - case *Array: - a.Append(gconv.Interfaces(v.Slice())...) - case *IntArray: - a.Append(gconv.Interfaces(v.Slice())...) - case *StrArray: - a.Append(gconv.Interfaces(v.Slice())...) - case *SortedArray: - a.Append(gconv.Interfaces(v.Slice())...) - case *SortedIntArray: - a.Append(gconv.Interfaces(v.Slice())...) - case *SortedStrArray: - a.Append(gconv.Interfaces(v.Slice())...) - default: - a.Append(gconv.Interfaces(array)...) - } - return a + return a.Append(gconv.Interfaces(array)...) } // Fill fills an array with num entries of the value , @@ -668,6 +652,9 @@ func (a *Array) Reverse() *Array { func (a *Array) Join(glue string) string { a.mu.RLock() defer a.mu.RUnlock() + if len(a.array) == 0 { + return "" + } buffer := bytes.NewBuffer(nil) for k, v := range a.array { buffer.WriteString(gconv.String(v)) @@ -694,7 +681,7 @@ func (a *Array) Iterator(f func(k int, v interface{}) bool) { a.IteratorAsc(f) } -// IteratorAsc iterates the array in ascending order with given callback function . +// IteratorAsc iterates the array readonly in ascending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (a *Array) IteratorAsc(f func(k int, v interface{}) bool) { a.mu.RLock() @@ -706,7 +693,7 @@ func (a *Array) IteratorAsc(f func(k int, v interface{}) bool) { } } -// IteratorDesc iterates the array in descending order with given callback function . +// IteratorDesc iterates the array readonly in descending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (a *Array) IteratorDesc(f func(k int, v interface{}) bool) { a.mu.RLock() @@ -749,8 +736,7 @@ func (a *Array) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (a *Array) UnmarshalJSON(b []byte) error { - if a.mu == nil { - a.mu = rwmutex.New() + if a.array == nil { a.array = make([]interface{}, 0) } a.mu.Lock() @@ -763,9 +749,6 @@ func (a *Array) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for array. func (a *Array) UnmarshalValue(value interface{}) error { - if a.mu == nil { - a.mu = rwmutex.New() - } a.mu.Lock() defer a.mu.Unlock() switch value.(type) { @@ -806,6 +789,16 @@ func (a *Array) FilterEmpty() *Array { return a } +// Walk applies a user supplied function to every item of array. +func (a *Array) Walk(f func(value interface{}) interface{}) *Array { + a.mu.Lock() + defer a.mu.Unlock() + for i, v := range a.array { + a.array[i] = f(v) + } + return a +} + // IsEmpty checks whether the array is empty. func (a *Array) IsEmpty() bool { return a.Len() == 0 diff --git a/container/garray/garray_normal_int.go b/container/garray/garray_normal_int.go index ecaa20c3f..0cce6d71c 100644 --- a/container/garray/garray_normal_int.go +++ b/container/garray/garray_normal_int.go @@ -21,7 +21,7 @@ import ( // IntArray is a golang int array with rich features. type IntArray struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex array []int } @@ -37,7 +37,7 @@ func NewIntArray(safe ...bool) *IntArray { // which is false in default. func NewIntArraySize(size int, cap int, safe ...bool) *IntArray { return &IntArray{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), array: make([]int, size, cap), } } @@ -62,7 +62,7 @@ func NewIntArrayRange(start, end, step int, safe ...bool) *IntArray { // which is false in default. func NewIntArrayFrom(array []int, safe ...bool) *IntArray { return &IntArray{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), array: array, } } @@ -74,7 +74,7 @@ func NewIntArrayFromCopy(array []int, safe ...bool) *IntArray { newArray := make([]int, len(array)) copy(newArray, array) return &IntArray{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), array: newArray, } } @@ -535,23 +535,7 @@ func (a *IntArray) RLockFunc(f func(array []int)) *IntArray { // The difference between Merge and Append is Append supports only specified slice type, // but Merge supports more parameter types. func (a *IntArray) Merge(array interface{}) *IntArray { - switch v := array.(type) { - case *Array: - a.Append(gconv.Ints(v.Slice())...) - case *IntArray: - a.Append(gconv.Ints(v.Slice())...) - case *StrArray: - a.Append(gconv.Ints(v.Slice())...) - case *SortedArray: - a.Append(gconv.Ints(v.Slice())...) - case *SortedIntArray: - a.Append(gconv.Ints(v.Slice())...) - case *SortedStrArray: - a.Append(gconv.Ints(v.Slice())...) - default: - a.Append(gconv.Ints(array)...) - } - return a + return a.Append(gconv.Ints(array)...) } // Fill fills an array with num entries of the value , @@ -670,6 +654,9 @@ func (a *IntArray) Reverse() *IntArray { func (a *IntArray) Join(glue string) string { a.mu.RLock() defer a.mu.RUnlock() + if len(a.array) == 0 { + return "" + } buffer := bytes.NewBuffer(nil) for k, v := range a.array { buffer.WriteString(gconv.String(v)) @@ -696,7 +683,7 @@ func (a *IntArray) Iterator(f func(k int, v int) bool) { a.IteratorAsc(f) } -// IteratorAsc iterates the array in ascending order with given callback function . +// IteratorAsc iterates the array readonly in ascending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (a *IntArray) IteratorAsc(f func(k int, v int) bool) { a.mu.RLock() @@ -708,7 +695,7 @@ func (a *IntArray) IteratorAsc(f func(k int, v int) bool) { } } -// IteratorDesc iterates the array in descending order with given callback function . +// IteratorDesc iterates the array readonly in descending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (a *IntArray) IteratorDesc(f func(k int, v int) bool) { a.mu.RLock() @@ -734,8 +721,7 @@ func (a *IntArray) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (a *IntArray) UnmarshalJSON(b []byte) error { - if a.mu == nil { - a.mu = rwmutex.New() + if a.array == nil { a.array = make([]int, 0) } a.mu.Lock() @@ -748,9 +734,6 @@ func (a *IntArray) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for array. func (a *IntArray) UnmarshalValue(value interface{}) error { - if a.mu == nil { - a.mu = rwmutex.New() - } a.mu.Lock() defer a.mu.Unlock() switch value.(type) { @@ -776,6 +759,16 @@ func (a *IntArray) FilterEmpty() *IntArray { return a } +// Walk applies a user supplied function to every item of array. +func (a *IntArray) Walk(f func(value int) int) *IntArray { + a.mu.Lock() + defer a.mu.Unlock() + for i, v := range a.array { + a.array[i] = f(v) + } + return a +} + // IsEmpty checks whether the array is empty. func (a *IntArray) IsEmpty() bool { return a.Len() == 0 diff --git a/container/garray/garray_normal_str.go b/container/garray/garray_normal_str.go index 78fc98e0e..d24038a27 100644 --- a/container/garray/garray_normal_str.go +++ b/container/garray/garray_normal_str.go @@ -23,7 +23,7 @@ import ( // StrArray is a golang string array with rich features. type StrArray struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex array []string } @@ -39,7 +39,7 @@ func NewStrArray(safe ...bool) *StrArray { // which is false in default. func NewStrArraySize(size int, cap int, safe ...bool) *StrArray { return &StrArray{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), array: make([]string, size, cap), } } @@ -49,7 +49,7 @@ func NewStrArraySize(size int, cap int, safe ...bool) *StrArray { // which is false in default. func NewStrArrayFrom(array []string, safe ...bool) *StrArray { return &StrArray{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), array: array, } } @@ -61,7 +61,7 @@ func NewStrArrayFromCopy(array []string, safe ...bool) *StrArray { newArray := make([]string, len(array)) copy(newArray, array) return &StrArray{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), array: newArray, } } @@ -526,23 +526,7 @@ func (a *StrArray) RLockFunc(f func(array []string)) *StrArray { // The difference between Merge and Append is Append supports only specified slice type, // but Merge supports more parameter types. func (a *StrArray) Merge(array interface{}) *StrArray { - switch v := array.(type) { - case *Array: - a.Append(gconv.Strings(v.Slice())...) - case *IntArray: - a.Append(gconv.Strings(v.Slice())...) - case *StrArray: - a.Append(gconv.Strings(v.Slice())...) - case *SortedArray: - a.Append(gconv.Strings(v.Slice())...) - case *SortedIntArray: - a.Append(gconv.Strings(v.Slice())...) - case *SortedStrArray: - a.Append(gconv.Strings(v.Slice())...) - default: - a.Append(gconv.Strings(array)...) - } - return a + return a.Append(gconv.Strings(array)...) } // Fill fills an array with num entries of the value , @@ -661,6 +645,9 @@ func (a *StrArray) Reverse() *StrArray { func (a *StrArray) Join(glue string) string { a.mu.RLock() defer a.mu.RUnlock() + if len(a.array) == 0 { + return "" + } buffer := bytes.NewBuffer(nil) for k, v := range a.array { buffer.WriteString(v) @@ -687,7 +674,7 @@ func (a *StrArray) Iterator(f func(k int, v string) bool) { a.IteratorAsc(f) } -// IteratorAsc iterates the array in ascending order with given callback function . +// IteratorAsc iterates the array readonly in ascending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (a *StrArray) IteratorAsc(f func(k int, v string) bool) { a.mu.RLock() @@ -699,7 +686,7 @@ func (a *StrArray) IteratorAsc(f func(k int, v string) bool) { } } -// IteratorDesc iterates the array in descending order with given callback function . +// IteratorDesc iterates the array readonly in descending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (a *StrArray) IteratorDesc(f func(k int, v string) bool) { a.mu.RLock() @@ -736,8 +723,7 @@ func (a *StrArray) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (a *StrArray) UnmarshalJSON(b []byte) error { - if a.mu == nil { - a.mu = rwmutex.New() + if a.array == nil { a.array = make([]string, 0) } a.mu.Lock() @@ -750,9 +736,6 @@ func (a *StrArray) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for array. func (a *StrArray) UnmarshalValue(value interface{}) error { - if a.mu == nil { - a.mu = rwmutex.New() - } a.mu.Lock() defer a.mu.Unlock() switch value.(type) { @@ -778,6 +761,16 @@ func (a *StrArray) FilterEmpty() *StrArray { return a } +// Walk applies a user supplied function to every item of array. +func (a *StrArray) Walk(f func(value string) string) *StrArray { + a.mu.Lock() + defer a.mu.Unlock() + for i, v := range a.array { + a.array[i] = f(v) + } + return a +} + // IsEmpty checks whether the array is empty. func (a *StrArray) IsEmpty() bool { return a.Len() == 0 diff --git a/container/garray/garray_sorted_any.go b/container/garray/garray_sorted_any.go index b4c3b4e63..31d222dac 100644 --- a/container/garray/garray_sorted_any.go +++ b/container/garray/garray_sorted_any.go @@ -16,7 +16,6 @@ import ( "math" "sort" - "github.com/gogf/gf/container/gtype" "github.com/gogf/gf/internal/rwmutex" "github.com/gogf/gf/util/gconv" "github.com/gogf/gf/util/grand" @@ -25,9 +24,9 @@ import ( // SortedArray is a golang sorted array with rich features. // It's using increasing order in default. type SortedArray struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex array []interface{} - unique *gtype.Bool // Whether enable unique feature(false) + unique bool // Whether enable unique feature(false) comparator func(a, b interface{}) int // Comparison function(it returns -1: a < b; 0: a == b; 1: a > b) } @@ -46,8 +45,7 @@ func NewSortedArray(comparator func(a, b interface{}) int, safe ...bool) *Sorted // which is false in default. func NewSortedArraySize(cap int, comparator func(a, b interface{}) int, safe ...bool) *SortedArray { return &SortedArray{ - mu: rwmutex.New(safe...), - unique: gtype.NewBool(), + mu: rwmutex.Create(safe...), array: make([]interface{}, 0, cap), comparator: comparator, } @@ -75,7 +73,7 @@ func NewSortedArrayFrom(array []interface{}, comparator func(a, b interface{}) i a := NewSortedArraySize(0, comparator, safe...) a.array = array sort.Slice(a.array, func(i, j int) bool { - return a.comparator(a.array[i], a.array[j]) < 0 + return a.getComparator()(a.array[i], a.array[j]) < 0 }) return a } @@ -95,19 +93,19 @@ func (a *SortedArray) SetArray(array []interface{}) *SortedArray { defer a.mu.Unlock() a.array = array sort.Slice(a.array, func(i, j int) bool { - return a.comparator(a.array[i], a.array[j]) < 0 + return a.getComparator()(a.array[i], a.array[j]) < 0 }) return a } // SetComparator sets/changes the comparator for sorting. +// It resorts the array as the comparator is changed. func (a *SortedArray) SetComparator(comparator func(a, b interface{}) int) { a.mu.Lock() defer a.mu.Unlock() a.comparator = comparator - // Resort the array if comparator is changed. sort.Slice(a.array, func(i, j int) bool { - return a.comparator(a.array[i], a.array[j]) < 0 + return a.getComparator()(a.array[i], a.array[j]) < 0 }) } @@ -118,7 +116,7 @@ func (a *SortedArray) Sort() *SortedArray { a.mu.Lock() defer a.mu.Unlock() sort.Slice(a.array, func(i, j int) bool { - return a.comparator(a.array[i], a.array[j]) < 0 + return a.getComparator()(a.array[i], a.array[j]) < 0 }) return a } @@ -132,7 +130,7 @@ func (a *SortedArray) Add(values ...interface{}) *SortedArray { defer a.mu.Unlock() for _, value := range values { index, cmp := a.binSearch(value, false) - if a.unique.Val() && cmp == 0 { + if a.unique && cmp == 0 { continue } if index < 0 { @@ -390,7 +388,7 @@ func (a *SortedArray) Len() int { // Note that, if it's in concurrent-safe usage, it returns a copy of underlying data, // or else a pointer to the underlying data. func (a *SortedArray) Slice() []interface{} { - array := ([]interface{})(nil) + var array []interface{} if a.mu.IsSafe() { a.mu.RLock() defer a.mu.RUnlock() @@ -439,8 +437,8 @@ func (a *SortedArray) binSearch(value interface{}, lock bool) (index int, result mid := 0 cmp := -2 for min <= max { - mid = int((min + max) / 2) - cmp = a.comparator(value, a.array[mid]) + mid = (min + max) / 2 + cmp = a.getComparator()(value, a.array[mid]) switch { case cmp < 0: max = mid - 1 @@ -457,8 +455,8 @@ func (a *SortedArray) binSearch(value interface{}, lock bool) (index int, result // which means it does not contain any repeated items. // It also do unique check, remove all repeated items. func (a *SortedArray) SetUnique(unique bool) *SortedArray { - oldUnique := a.unique.Val() - a.unique.Set(unique) + oldUnique := a.unique + a.unique = unique if unique && oldUnique != unique { a.Unique() } @@ -477,7 +475,7 @@ func (a *SortedArray) Unique() *SortedArray { if i == len(a.array)-1 { break } - if a.comparator(a.array[i], a.array[i+1]) == 0 { + if a.getComparator()(a.array[i], a.array[i+1]) == 0 { a.array = append(a.array[:i+1], a.array[i+1+1:]...) } else { i++ @@ -509,6 +507,12 @@ func (a *SortedArray) Clear() *SortedArray { func (a *SortedArray) LockFunc(f func(array []interface{})) *SortedArray { a.mu.Lock() defer a.mu.Unlock() + + // Keep the array always sorted. + defer sort.Slice(a.array, func(i, j int) bool { + return a.getComparator()(a.array[i], a.array[j]) < 0 + }) + f(a.array) return a } @@ -526,23 +530,7 @@ func (a *SortedArray) RLockFunc(f func(array []interface{})) *SortedArray { // The difference between Merge and Append is Append supports only specified slice type, // but Merge supports more parameter types. func (a *SortedArray) Merge(array interface{}) *SortedArray { - switch v := array.(type) { - case *Array: - a.Add(gconv.Interfaces(v.Slice())...) - case *IntArray: - a.Add(gconv.Interfaces(v.Slice())...) - case *StrArray: - a.Add(gconv.Interfaces(v.Slice())...) - case *SortedArray: - a.Add(gconv.Interfaces(v.Slice())...) - case *SortedIntArray: - a.Add(gconv.Interfaces(v.Slice())...) - case *SortedStrArray: - a.Add(gconv.Interfaces(v.Slice())...) - default: - a.Add(gconv.Interfaces(array)...) - } - return a + return a.Add(gconv.Interfaces(array)...) } // Chunk splits an array into multiple arrays, @@ -596,6 +584,9 @@ func (a *SortedArray) Rands(size int) []interface{} { func (a *SortedArray) Join(glue string) string { a.mu.RLock() defer a.mu.RUnlock() + if len(a.array) == 0 { + return "" + } buffer := bytes.NewBuffer(nil) for k, v := range a.array { buffer.WriteString(gconv.String(v)) @@ -622,7 +613,7 @@ func (a *SortedArray) Iterator(f func(k int, v interface{}) bool) { a.IteratorAsc(f) } -// IteratorAsc iterates the array in ascending order with given callback function . +// IteratorAsc iterates the array readonly in ascending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (a *SortedArray) IteratorAsc(f func(k int, v interface{}) bool) { a.mu.RLock() @@ -634,7 +625,7 @@ func (a *SortedArray) IteratorAsc(f func(k int, v interface{}) bool) { } } -// IteratorDesc iterates the array in descending order with given callback function . +// IteratorDesc iterates the array readonly in descending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (a *SortedArray) IteratorDesc(f func(k int, v interface{}) bool) { a.mu.RLock() @@ -676,12 +667,10 @@ func (a *SortedArray) MarshalJSON() ([]byte, error) { } // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. +// Note that the comparator is set as string comparator in default. func (a *SortedArray) UnmarshalJSON(b []byte) error { - if a.mu == nil { - a.mu = rwmutex.New() + if a.comparator == nil { a.array = make([]interface{}, 0) - a.unique = gtype.NewBool() - // Note that the comparator is string comparator in default. a.comparator = gutil.ComparatorString } a.mu.Lock() @@ -698,11 +687,9 @@ func (a *SortedArray) UnmarshalJSON(b []byte) error { } // UnmarshalValue is an interface implement which sets any type of value for array. +// Note that the comparator is set as string comparator in default. func (a *SortedArray) UnmarshalValue(value interface{}) (err error) { - if a.mu == nil { - a.mu = rwmutex.New() - a.unique = gtype.NewBool() - // Note that the comparator is string comparator in default. + if a.comparator == nil { a.comparator = gutil.ComparatorString } a.mu.Lock() @@ -764,7 +751,30 @@ func (a *SortedArray) FilterEmpty() *SortedArray { return a } +// Walk applies a user supplied function to every item of array. +func (a *SortedArray) Walk(f func(value interface{}) interface{}) *SortedArray { + a.mu.Lock() + defer a.mu.Unlock() + // Keep the array always sorted. + defer sort.Slice(a.array, func(i, j int) bool { + return a.getComparator()(a.array[i], a.array[j]) < 0 + }) + for i, v := range a.array { + a.array[i] = f(v) + } + return a +} + // IsEmpty checks whether the array is empty. func (a *SortedArray) IsEmpty() bool { return a.Len() == 0 } + +// getComparator returns the comparator if it's previously set, +// or else it panics. +func (a *SortedArray) getComparator() func(a, b interface{}) int { + if a.comparator == nil { + panic("comparator is missing for sorted array") + } + return a.comparator +} diff --git a/container/garray/garray_sorted_int.go b/container/garray/garray_sorted_int.go index 182f3f211..10a94fbaa 100644 --- a/container/garray/garray_sorted_int.go +++ b/container/garray/garray_sorted_int.go @@ -13,7 +13,6 @@ import ( "math" "sort" - "github.com/gogf/gf/container/gtype" "github.com/gogf/gf/internal/rwmutex" "github.com/gogf/gf/util/gconv" "github.com/gogf/gf/util/grand" @@ -22,9 +21,9 @@ import ( // SortedIntArray is a golang sorted int array with rich features. // It's using increasing order in default. type SortedIntArray struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex array []int - unique *gtype.Bool // Whether enable unique feature(false) + unique bool // Whether enable unique feature(false) comparator func(a, b int) int // Comparison function(it returns -1: a < b; 0: a == b; 1: a > b) } @@ -48,9 +47,8 @@ func NewSortedIntArrayComparator(comparator func(a, b int) int, safe ...bool) *S // which is false in default. func NewSortedIntArraySize(cap int, safe ...bool) *SortedIntArray { return &SortedIntArray{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), array: make([]int, 0, cap), - unique: gtype.NewBool(), comparator: defaultComparatorInt, } } @@ -94,7 +92,7 @@ func (a *SortedIntArray) SetArray(array []int) *SortedIntArray { a.mu.Lock() defer a.mu.Unlock() a.array = array - sort.Ints(a.array) + quickSortInt(a.array, a.getComparator()) return a } @@ -104,7 +102,7 @@ func (a *SortedIntArray) SetArray(array []int) *SortedIntArray { func (a *SortedIntArray) Sort() *SortedIntArray { a.mu.Lock() defer a.mu.Unlock() - sort.Ints(a.array) + quickSortInt(a.array, a.getComparator()) return a } @@ -117,7 +115,7 @@ func (a *SortedIntArray) Add(values ...int) *SortedIntArray { defer a.mu.Unlock() for _, value := range values { index, cmp := a.binSearch(value, false) - if a.unique.Val() && cmp == 0 { + if a.unique && cmp == 0 { continue } if index < 0 { @@ -436,8 +434,8 @@ func (a *SortedIntArray) binSearch(value int, lock bool) (index int, result int) mid := 0 cmp := -2 for min <= max { - mid = int((min + max) / 2) - cmp = a.comparator(value, a.array[mid]) + mid = (min + max) / 2 + cmp = a.getComparator()(value, a.array[mid]) switch { case cmp < 0: max = mid - 1 @@ -454,8 +452,8 @@ func (a *SortedIntArray) binSearch(value int, lock bool) (index int, result int) // which means it does not contain any repeated items. // It also do unique check, remove all repeated items. func (a *SortedIntArray) SetUnique(unique bool) *SortedIntArray { - oldUnique := a.unique.Val() - a.unique.Set(unique) + oldUnique := a.unique + a.unique = unique if unique && oldUnique != unique { a.Unique() } @@ -474,7 +472,7 @@ func (a *SortedIntArray) Unique() *SortedIntArray { if i == len(a.array)-1 { break } - if a.comparator(a.array[i], a.array[i+1]) == 0 { + if a.getComparator()(a.array[i], a.array[i+1]) == 0 { a.array = append(a.array[:i+1], a.array[i+1+1:]...) } else { i++ @@ -523,23 +521,7 @@ func (a *SortedIntArray) RLockFunc(f func(array []int)) *SortedIntArray { // The difference between Merge and Append is Append supports only specified slice type, // but Merge supports more parameter types. func (a *SortedIntArray) Merge(array interface{}) *SortedIntArray { - switch v := array.(type) { - case *Array: - a.Add(gconv.Ints(v.Slice())...) - case *IntArray: - a.Add(gconv.Ints(v.Slice())...) - case *StrArray: - a.Add(gconv.Ints(v.Slice())...) - case *SortedArray: - a.Add(gconv.Ints(v.Slice())...) - case *SortedIntArray: - a.Add(gconv.Ints(v.Slice())...) - case *SortedStrArray: - a.Add(gconv.Ints(v.Slice())...) - default: - a.Add(gconv.Ints(array)...) - } - return a + return a.Add(gconv.Ints(array)...) } // Chunk splits an array into multiple arrays, @@ -593,6 +575,9 @@ func (a *SortedIntArray) Rands(size int) []int { func (a *SortedIntArray) Join(glue string) string { a.mu.RLock() defer a.mu.RUnlock() + if len(a.array) == 0 { + return "" + } buffer := bytes.NewBuffer(nil) for k, v := range a.array { buffer.WriteString(gconv.String(v)) @@ -619,7 +604,7 @@ func (a *SortedIntArray) Iterator(f func(k int, v int) bool) { a.IteratorAsc(f) } -// IteratorAsc iterates the array in ascending order with given callback function . +// IteratorAsc iterates the array readonly in ascending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (a *SortedIntArray) IteratorAsc(f func(k int, v int) bool) { a.mu.RLock() @@ -631,7 +616,7 @@ func (a *SortedIntArray) IteratorAsc(f func(k int, v int) bool) { } } -// IteratorDesc iterates the array in descending order with given callback function . +// IteratorDesc iterates the array readonly in descending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (a *SortedIntArray) IteratorDesc(f func(k int, v int) bool) { a.mu.RLock() @@ -657,10 +642,8 @@ func (a *SortedIntArray) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (a *SortedIntArray) UnmarshalJSON(b []byte) error { - if a.mu == nil { - a.mu = rwmutex.New() + if a.comparator == nil { a.array = make([]int, 0) - a.unique = gtype.NewBool() a.comparator = defaultComparatorInt } a.mu.Lock() @@ -676,10 +659,7 @@ func (a *SortedIntArray) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for array. func (a *SortedIntArray) UnmarshalValue(value interface{}) (err error) { - if a.mu == nil { - a.mu = rwmutex.New() - a.unique = gtype.NewBool() - // Note that the comparator is string comparator in default. + if a.comparator == nil { a.comparator = defaultComparatorInt } a.mu.Lock() @@ -717,7 +697,30 @@ func (a *SortedIntArray) FilterEmpty() *SortedIntArray { return a } +// Walk applies a user supplied function to every item of array. +func (a *SortedIntArray) Walk(f func(value int) int) *SortedIntArray { + a.mu.Lock() + defer a.mu.Unlock() + + // Keep the array always sorted. + defer quickSortInt(a.array, a.getComparator()) + + for i, v := range a.array { + a.array[i] = f(v) + } + return a +} + // IsEmpty checks whether the array is empty. func (a *SortedIntArray) IsEmpty() bool { return a.Len() == 0 } + +// getComparator returns the comparator if it's previously set, +// or else it returns a default comparator. +func (a *SortedIntArray) getComparator() func(a, b int) int { + if a.comparator == nil { + return defaultComparatorInt + } + return a.comparator +} diff --git a/container/garray/garray_sorted_str.go b/container/garray/garray_sorted_str.go index aed2ad100..61ad49b1d 100644 --- a/container/garray/garray_sorted_str.go +++ b/container/garray/garray_sorted_str.go @@ -13,7 +13,6 @@ import ( "math" "sort" - "github.com/gogf/gf/container/gtype" "github.com/gogf/gf/internal/rwmutex" "github.com/gogf/gf/util/gconv" "github.com/gogf/gf/util/grand" @@ -22,9 +21,9 @@ import ( // SortedStrArray is a golang sorted string array with rich features. // It's using increasing order in default. type SortedStrArray struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex array []string - unique *gtype.Bool // Whether enable unique feature(false) + unique bool // Whether enable unique feature(false) comparator func(a, b string) int // Comparison function(it returns -1: a < b; 0: a == b; 1: a > b) } @@ -48,9 +47,8 @@ func NewSortedStrArrayComparator(comparator func(a, b string) int, safe ...bool) // which is false in default. func NewSortedStrArraySize(cap int, safe ...bool) *SortedStrArray { return &SortedStrArray{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), array: make([]string, 0, cap), - unique: gtype.NewBool(), comparator: defaultComparatorStr, } } @@ -61,7 +59,7 @@ func NewSortedStrArraySize(cap int, safe ...bool) *SortedStrArray { func NewSortedStrArrayFrom(array []string, safe ...bool) *SortedStrArray { a := NewSortedStrArraySize(0, safe...) a.array = array - quickSortStr(a.array, a.comparator) + quickSortStr(a.array, a.getComparator()) return a } @@ -79,7 +77,7 @@ func (a *SortedStrArray) SetArray(array []string) *SortedStrArray { a.mu.Lock() defer a.mu.Unlock() a.array = array - quickSortStr(a.array, a.comparator) + quickSortStr(a.array, a.getComparator()) return a } @@ -89,7 +87,7 @@ func (a *SortedStrArray) SetArray(array []string) *SortedStrArray { func (a *SortedStrArray) Sort() *SortedStrArray { a.mu.Lock() defer a.mu.Unlock() - quickSortStr(a.array, a.comparator) + quickSortStr(a.array, a.getComparator()) return a } @@ -102,7 +100,7 @@ func (a *SortedStrArray) Add(values ...string) *SortedStrArray { defer a.mu.Unlock() for _, value := range values { index, cmp := a.binSearch(value, false) - if a.unique.Val() && cmp == 0 { + if a.unique && cmp == 0 { continue } if index < 0 { @@ -421,8 +419,8 @@ func (a *SortedStrArray) binSearch(value string, lock bool) (index int, result i mid := 0 cmp := -2 for min <= max { - mid = int((min + max) / 2) - cmp = a.comparator(value, a.array[mid]) + mid = (min + max) / 2 + cmp = a.getComparator()(value, a.array[mid]) switch { case cmp < 0: max = mid - 1 @@ -439,8 +437,8 @@ func (a *SortedStrArray) binSearch(value string, lock bool) (index int, result i // which means it does not contain any repeated items. // It also do unique check, remove all repeated items. func (a *SortedStrArray) SetUnique(unique bool) *SortedStrArray { - oldUnique := a.unique.Val() - a.unique.Set(unique) + oldUnique := a.unique + a.unique = unique if unique && oldUnique != unique { a.Unique() } @@ -459,7 +457,7 @@ func (a *SortedStrArray) Unique() *SortedStrArray { if i == len(a.array)-1 { break } - if a.comparator(a.array[i], a.array[i+1]) == 0 { + if a.getComparator()(a.array[i], a.array[i+1]) == 0 { a.array = append(a.array[:i+1], a.array[i+1+1:]...) } else { i++ @@ -508,23 +506,7 @@ func (a *SortedStrArray) RLockFunc(f func(array []string)) *SortedStrArray { // The difference between Merge and Append is Append supports only specified slice type, // but Merge supports more parameter types. func (a *SortedStrArray) Merge(array interface{}) *SortedStrArray { - switch v := array.(type) { - case *Array: - a.Add(gconv.Strings(v.Slice())...) - case *IntArray: - a.Add(gconv.Strings(v.Slice())...) - case *StrArray: - a.Add(gconv.Strings(v.Slice())...) - case *SortedArray: - a.Add(gconv.Strings(v.Slice())...) - case *SortedIntArray: - a.Add(gconv.Strings(v.Slice())...) - case *SortedStrArray: - a.Add(gconv.Strings(v.Slice())...) - default: - a.Add(gconv.Strings(array)...) - } - return a + return a.Add(gconv.Strings(array)...) } // Chunk splits an array into multiple arrays, @@ -578,6 +560,9 @@ func (a *SortedStrArray) Rands(size int) []string { func (a *SortedStrArray) Join(glue string) string { a.mu.RLock() defer a.mu.RUnlock() + if len(a.array) == 0 { + return "" + } buffer := bytes.NewBuffer(nil) for k, v := range a.array { buffer.WriteString(v) @@ -604,7 +589,7 @@ func (a *SortedStrArray) Iterator(f func(k int, v string) bool) { a.IteratorAsc(f) } -// IteratorAsc iterates the array in ascending order with given callback function . +// IteratorAsc iterates the array readonly in ascending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (a *SortedStrArray) IteratorAsc(f func(k int, v string) bool) { a.mu.RLock() @@ -616,7 +601,7 @@ func (a *SortedStrArray) IteratorAsc(f func(k int, v string) bool) { } } -// IteratorDesc iterates the array in descending order with given callback function . +// IteratorDesc iterates the array readonly in descending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (a *SortedStrArray) IteratorDesc(f func(k int, v string) bool) { a.mu.RLock() @@ -653,10 +638,8 @@ func (a *SortedStrArray) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (a *SortedStrArray) UnmarshalJSON(b []byte) error { - if a.mu == nil { - a.mu = rwmutex.New() + if a.comparator == nil { a.array = make([]string, 0) - a.unique = gtype.NewBool() a.comparator = defaultComparatorStr } a.mu.Lock() @@ -672,10 +655,7 @@ func (a *SortedStrArray) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for array. func (a *SortedStrArray) UnmarshalValue(value interface{}) (err error) { - if a.mu == nil { - a.mu = rwmutex.New() - a.unique = gtype.NewBool() - // Note that the comparator is string comparator in default. + if a.comparator == nil { a.comparator = defaultComparatorStr } a.mu.Lock() @@ -713,7 +693,30 @@ func (a *SortedStrArray) FilterEmpty() *SortedStrArray { return a } +// Walk applies a user supplied function to every item of array. +func (a *SortedStrArray) Walk(f func(value string) string) *SortedStrArray { + a.mu.Lock() + defer a.mu.Unlock() + + // Keep the array always sorted. + defer quickSortStr(a.array, a.getComparator()) + + for i, v := range a.array { + a.array[i] = f(v) + } + return a +} + // IsEmpty checks whether the array is empty. func (a *SortedStrArray) IsEmpty() bool { return a.Len() == 0 } + +// getComparator returns the comparator if it's previously set, +// or else it returns a default comparator. +func (a *SortedStrArray) getComparator() func(a, b string) int { + if a.comparator == nil { + return defaultComparatorStr + } + return a.comparator +} diff --git a/container/garray/garray_z_unit_all_basic_test.go b/container/garray/garray_z_unit_all_basic_test.go index faeaa960d..b9da2fc95 100644 --- a/container/garray/garray_z_unit_all_basic_test.go +++ b/container/garray/garray_z_unit_all_basic_test.go @@ -9,6 +9,7 @@ package garray_test import ( + "github.com/gogf/gf/util/gutil" "strings" "testing" @@ -17,6 +18,55 @@ import ( "github.com/gogf/gf/util/gconv" ) +func Test_Array_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var array garray.Array + expect := []int{2, 3, 1} + array.Append(2, 3, 1) + t.Assert(array.Slice(), expect) + }) + gtest.C(t, func(t *gtest.T) { + var array garray.IntArray + expect := []int{2, 3, 1} + array.Append(2, 3, 1) + t.Assert(array.Slice(), expect) + }) + gtest.C(t, func(t *gtest.T) { + var array garray.StrArray + expect := []string{"b", "a"} + array.Append("b", "a") + t.Assert(array.Slice(), expect) + }) + gtest.C(t, func(t *gtest.T) { + var array garray.SortedArray + array.SetComparator(gutil.ComparatorInt) + expect := []int{1, 2, 3} + array.Add(2, 3, 1) + t.Assert(array.Slice(), expect) + }) + gtest.C(t, func(t *gtest.T) { + var array garray.SortedIntArray + expect := []int{1, 2, 3} + array.Add(2, 3, 1) + t.Assert(array.Slice(), expect) + }) + gtest.C(t, func(t *gtest.T) { + var array garray.SortedStrArray + expect := []string{"a", "b", "c"} + array.Add("c", "a", "b") + t.Assert(array.Slice(), expect) + }) +} + +func Test_SortedIntArray_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var array garray.SortedIntArray + expect := []int{1, 2, 3} + array.Add(2, 3, 1) + t.Assert(array.Slice(), expect) + }) +} + func Test_IntArray_Unique(t *testing.T) { gtest.C(t, func(t *gtest.T) { expect := []int{1, 2, 3, 4, 5, 6} diff --git a/container/garray/garray_z_unit_normal_any_array_test.go b/container/garray/garray_z_unit_normal_any_array_test.go index a1447c8ab..3cd7df5ff 100644 --- a/container/garray/garray_z_unit_normal_any_array_test.go +++ b/container/garray/garray_z_unit_normal_any_array_test.go @@ -650,3 +650,12 @@ func TestArray_FilterEmpty(t *testing.T) { t.Assert(array.FilterEmpty(), g.Slice{1, 2, 3, 4}) }) } + +func TestArray_Walk(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + array := garray.NewArrayFrom(g.Slice{"1", "2"}) + t.Assert(array.Walk(func(value interface{}) interface{} { + return "key-" + gconv.String(value) + }), g.Slice{"key-1", "key-2"}) + }) +} diff --git a/container/garray/garray_z_unit_normal_int_array_test.go b/container/garray/garray_z_unit_normal_int_array_test.go index 112de3db7..9ee47c380 100644 --- a/container/garray/garray_z_unit_normal_int_array_test.go +++ b/container/garray/garray_z_unit_normal_int_array_test.go @@ -683,3 +683,12 @@ func TestIntArray_FilterEmpty(t *testing.T) { t.Assert(array.FilterEmpty(), g.SliceInt{1, 2, 3, 4}) }) } + +func TestIntArray_Walk(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + array := garray.NewIntArrayFrom(g.SliceInt{1, 2}) + t.Assert(array.Walk(func(value int) int { + return 10 + value + }), g.Slice{11, 12}) + }) +} diff --git a/container/garray/garray_z_unit_normal_str_array_test.go b/container/garray/garray_z_unit_normal_str_array_test.go index 090705224..d0065f9df 100644 --- a/container/garray/garray_z_unit_normal_str_array_test.go +++ b/container/garray/garray_z_unit_normal_str_array_test.go @@ -671,3 +671,12 @@ func TestStrArray_FilterEmpty(t *testing.T) { t.Assert(array.FilterEmpty(), g.SliceStr{"1", "2"}) }) } + +func TestStrArray_Walk(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + array := garray.NewStrArrayFrom(g.SliceStr{"1", "2"}) + t.Assert(array.Walk(func(value string) string { + return "key-" + value + }), g.Slice{"key-1", "key-2"}) + }) +} diff --git a/container/garray/garray_z_unit_sorted_any_array_test.go b/container/garray/garray_z_unit_sorted_any_array_test.go index bd639235b..7003dd458 100644 --- a/container/garray/garray_z_unit_sorted_any_array_test.go +++ b/container/garray/garray_z_unit_sorted_any_array_test.go @@ -779,3 +779,12 @@ func TestSortedArray_FilterEmpty(t *testing.T) { t.Assert(array.FilterEmpty(), g.Slice{1, 2, 3, 4}) }) } + +func TestSortedArray_Walk(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + array := garray.NewSortedArrayFrom(g.Slice{"1", "2"}, gutil.ComparatorString) + t.Assert(array.Walk(func(value interface{}) interface{} { + return "key-" + gconv.String(value) + }), g.Slice{"key-1", "key-2"}) + }) +} diff --git a/container/garray/garray_z_unit_sorted_int_array_test.go b/container/garray/garray_z_unit_sorted_int_array_test.go index 4265ddb08..f195fcd63 100644 --- a/container/garray/garray_z_unit_sorted_int_array_test.go +++ b/container/garray/garray_z_unit_sorted_int_array_test.go @@ -642,3 +642,12 @@ func TestSortedIntArray_FilterEmpty(t *testing.T) { t.Assert(array.FilterEmpty(), g.SliceInt{1, 2, 3, 4}) }) } + +func TestSortedIntArray_Walk(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + array := garray.NewSortedIntArrayFrom(g.SliceInt{1, 2}) + t.Assert(array.Walk(func(value int) int { + return 10 + value + }), g.Slice{11, 12}) + }) +} diff --git a/container/garray/garray_z_unit_sorted_str_array_test.go b/container/garray/garray_z_unit_sorted_str_array_test.go index 0981add81..002c4e3b5 100644 --- a/container/garray/garray_z_unit_sorted_str_array_test.go +++ b/container/garray/garray_z_unit_sorted_str_array_test.go @@ -652,3 +652,12 @@ func TestSortedStrArray_FilterEmpty(t *testing.T) { t.Assert(array.FilterEmpty(), g.SliceStr{"1", "2"}) }) } + +func TestSortedStrArray_Walk(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + array := garray.NewSortedStrArrayFrom(g.SliceStr{"1", "2"}) + t.Assert(array.Walk(func(value string) string { + return "key-" + value + }), g.Slice{"key-1", "key-2"}) + }) +} diff --git a/container/glist/glist.go b/container/glist/glist.go index 9e8dce8c1..6bdd3af9d 100644 --- a/container/glist/glist.go +++ b/container/glist/glist.go @@ -20,8 +20,8 @@ import ( type ( List struct { - mu *rwmutex.RWMutex - list *list.List + mu rwmutex.RWMutex + list list.List } Element = list.Element @@ -30,8 +30,8 @@ type ( // New creates and returns a new empty doubly linked list. func New(safe ...bool) *List { return &List{ - mu: rwmutex.New(safe...), - list: list.New(), + mu: rwmutex.Create(safe...), + list: list.List{}, } } @@ -39,12 +39,12 @@ func New(safe ...bool) *List { // The parameter is used to specify whether using list in concurrent-safety, // which is false in default. func NewFrom(array []interface{}, safe ...bool) *List { - l := list.New() + l := list.List{} for _, v := range array { l.PushBack(v) } return &List{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), list: l, } } @@ -273,7 +273,7 @@ func (l *List) PushBackList(other *List) { defer other.mu.RUnlock() } l.mu.Lock() - l.list.PushBackList(other.list) + l.list.PushBackList(&other.list) l.mu.Unlock() } @@ -285,7 +285,7 @@ func (l *List) PushFrontList(other *List) { defer other.mu.RUnlock() } l.mu.Lock() - l.list.PushFrontList(other.list) + l.list.PushFrontList(&other.list) l.mu.Unlock() } @@ -332,7 +332,7 @@ func (l *List) Removes(es []*Element) { // RemoveAll removes all elements from list . func (l *List) RemoveAll() { l.mu.Lock() - l.list = list.New() + l.list = list.List{} l.mu.Unlock() } @@ -345,14 +345,14 @@ func (l *List) Clear() { func (l *List) RLockFunc(f func(list *list.List)) { l.mu.RLock() defer l.mu.RUnlock() - f(l.list) + f(&l.list) } // LockFunc locks writing with given callback function within RWMutex.Lock. func (l *List) LockFunc(f func(list *list.List)) { l.mu.Lock() defer l.mu.Unlock() - f(l.list) + f(&l.list) } // Iterator is alias of IteratorAsc. @@ -360,7 +360,7 @@ func (l *List) Iterator(f func(e *Element) bool) { l.IteratorAsc(f) } -// IteratorAsc iterates the list in ascending order with given callback function . +// IteratorAsc iterates the list readonly in ascending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (l *List) IteratorAsc(f func(e *Element) bool) { l.mu.RLock() @@ -375,7 +375,7 @@ func (l *List) IteratorAsc(f func(e *Element) bool) { } } -// IteratorDesc iterates the list in descending order with given callback function . +// IteratorDesc iterates the list readonly in descending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (l *List) IteratorDesc(f func(e *Element) bool) { l.mu.RLock() @@ -425,10 +425,6 @@ func (l *List) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (l *List) UnmarshalJSON(b []byte) error { - if l.mu == nil { - l.mu = rwmutex.New() - l.list = list.New() - } l.mu.Lock() defer l.mu.Unlock() var array []interface{} @@ -441,10 +437,6 @@ func (l *List) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for list. func (l *List) UnmarshalValue(value interface{}) (err error) { - if l.mu == nil { - l.mu = rwmutex.New() - l.list = list.New() - } l.mu.Lock() defer l.mu.Unlock() var array []interface{} diff --git a/container/glist/glist_z_unit_test.go b/container/glist/glist_z_unit_test.go index 9d145d796..cd8b662cc 100644 --- a/container/glist/glist_z_unit_test.go +++ b/container/glist/glist_z_unit_test.go @@ -41,6 +41,44 @@ func checkListPointers(t *gtest.T, l *List, es []*Element) { }) } +func TestVar(t *testing.T) { + var l List + l.PushFront(1) + l.PushFront(2) + if v := l.PopBack(); v != 1 { + t.Errorf("EXPECT %v, GOT %v", 1, v) + } else { + //fmt.Println(v) + } + if v := l.PopBack(); v != 2 { + t.Errorf("EXPECT %v, GOT %v", 2, v) + } else { + //fmt.Println(v) + } + if v := l.PopBack(); v != nil { + t.Errorf("EXPECT %v, GOT %v", nil, v) + } else { + //fmt.Println(v) + } + l.PushBack(1) + l.PushBack(2) + if v := l.PopFront(); v != 1 { + t.Errorf("EXPECT %v, GOT %v", 1, v) + } else { + //fmt.Println(v) + } + if v := l.PopFront(); v != 2 { + t.Errorf("EXPECT %v, GOT %v", 2, v) + } else { + //fmt.Println(v) + } + if v := l.PopFront(); v != nil { + t.Errorf("EXPECT %v, GOT %v", nil, v) + } else { + //fmt.Println(v) + } +} + func TestBasic(t *testing.T) { l := New() l.PushFront(1) diff --git a/container/gmap/gmap.go b/container/gmap/gmap.go index 1108262b1..4d0d28183 100644 --- a/container/gmap/gmap.go +++ b/container/gmap/gmap.go @@ -7,9 +7,10 @@ // Package gmap provides concurrent-safe/unsafe map containers. package gmap -// Map based on hash table, alias of AnyAnyMap. -type Map = AnyAnyMap -type HashMap = AnyAnyMap +type ( + Map = AnyAnyMap // Map is alias of AnyAnyMap. + HashMap = AnyAnyMap // HashMap is alias of AnyAnyMap. +) // New creates and returns an empty hash map. // The parameter is used to specify whether using map in concurrent-safety, diff --git a/container/gmap/gmap_hash_any_any_map.go b/container/gmap/gmap_hash_any_any_map.go index 2a2b3c279..a26c486c7 100644 --- a/container/gmap/gmap_hash_any_any_map.go +++ b/container/gmap/gmap_hash_any_any_map.go @@ -18,7 +18,7 @@ import ( ) type AnyAnyMap struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex data map[interface{}]interface{} } @@ -27,7 +27,7 @@ type AnyAnyMap struct { // which is false in default. func NewAnyAnyMap(safe ...bool) *AnyAnyMap { return &AnyAnyMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: make(map[interface{}]interface{}), } } @@ -37,12 +37,12 @@ func NewAnyAnyMap(safe ...bool) *AnyAnyMap { // there might be some concurrent-safe issues when changing the map outside. func NewAnyAnyMapFrom(data map[interface{}]interface{}, safe ...bool) *AnyAnyMap { return &AnyAnyMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: data, } } -// Iterator iterates the hash map with custom callback function . +// Iterator iterates the hash map readonly with custom callback function . // If returns true, then it continues iterating; or false to stop. func (m *AnyAnyMap) Iterator(f func(k interface{}, v interface{}) bool) { m.mu.RLock() @@ -89,37 +89,44 @@ func (m *AnyAnyMap) MapCopy() map[interface{}]interface{} { // MapStrAny returns a copy of the underlying data of the map as map[string]interface{}. func (m *AnyAnyMap) MapStrAny() map[string]interface{} { m.mu.RLock() + defer m.mu.RUnlock() data := make(map[string]interface{}, len(m.data)) for k, v := range m.data { data[gconv.String(k)] = v } - m.mu.RUnlock() return data } // FilterEmpty deletes all key-value pair of which the value is empty. func (m *AnyAnyMap) FilterEmpty() { m.mu.Lock() + defer m.mu.Unlock() for k, v := range m.data { if empty.IsEmpty(v) { delete(m.data, k) } } - m.mu.Unlock() } // Set sets key-value to the hash map. -func (m *AnyAnyMap) Set(key interface{}, val interface{}) { +func (m *AnyAnyMap) Set(key interface{}, value interface{}) { m.mu.Lock() - m.data[key] = val + if m.data == nil { + m.data = make(map[interface{}]interface{}) + } + m.data[key] = value m.mu.Unlock() } // Sets batch sets key-values to the hash map. func (m *AnyAnyMap) Sets(data map[interface{}]interface{}) { m.mu.Lock() - for k, v := range data { - m.data[k] = v + if m.data == nil { + m.data = data + } else { + for k, v := range data { + m.data[k] = v + } } m.mu.Unlock() } @@ -128,17 +135,21 @@ func (m *AnyAnyMap) Sets(data map[interface{}]interface{}) { // Second return parameter is true if key was found, otherwise false. func (m *AnyAnyMap) Search(key interface{}) (value interface{}, found bool) { m.mu.RLock() - value, found = m.data[key] + if m.data != nil { + value, found = m.data[key] + } m.mu.RUnlock() return } // Get returns the value by given . -func (m *AnyAnyMap) Get(key interface{}) interface{} { +func (m *AnyAnyMap) Get(key interface{}) (value interface{}) { m.mu.RLock() - val, _ := m.data[key] + if m.data != nil { + value, _ = m.data[key] + } m.mu.RUnlock() - return val + return } // Pop retrieves and deletes an item from the map. @@ -163,8 +174,10 @@ func (m *AnyAnyMap) Pops(size int) map[interface{}]interface{} { if size == 0 { return nil } - index := 0 - newMap := make(map[interface{}]interface{}, size) + var ( + index = 0 + newMap = make(map[interface{}]interface{}, size) + ) for k, v := range m.data { delete(m.data, k) newMap[k] = v @@ -188,6 +201,9 @@ func (m *AnyAnyMap) Pops(size int) map[interface{}]interface{} { func (m *AnyAnyMap) doSetWithLockCheck(key interface{}, value interface{}) interface{} { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[interface{}]interface{}) + } if v, ok := m.data[key]; ok { return v } @@ -235,26 +251,26 @@ func (m *AnyAnyMap) GetOrSetFuncLock(key interface{}, f func() interface{}) inte } } -// GetVar returns a gvar.Var with the value by given . -// The returned gvar.Var is un-concurrent safe. +// GetVar returns a Var with the value by given . +// The returned Var is un-concurrent safe. func (m *AnyAnyMap) GetVar(key interface{}) *gvar.Var { return gvar.New(m.Get(key)) } -// GetVarOrSet returns a gvar.Var with result from GetVarOrSet. -// The returned gvar.Var is un-concurrent safe. +// GetVarOrSet returns a Var with result from GetVarOrSet. +// The returned Var is un-concurrent safe. func (m *AnyAnyMap) GetVarOrSet(key interface{}, value interface{}) *gvar.Var { return gvar.New(m.GetOrSet(key, value)) } -// GetVarOrSetFunc returns a gvar.Var with result from GetOrSetFunc. -// The returned gvar.Var is un-concurrent safe. +// GetVarOrSetFunc returns a Var with result from GetOrSetFunc. +// The returned Var is un-concurrent safe. func (m *AnyAnyMap) GetVarOrSetFunc(key interface{}, f func() interface{}) *gvar.Var { return gvar.New(m.GetOrSetFunc(key, f)) } -// GetVarOrSetFuncLock returns a gvar.Var with result from GetOrSetFuncLock. -// The returned gvar.Var is un-concurrent safe. +// GetVarOrSetFuncLock returns a Var with result from GetOrSetFuncLock. +// The returned Var is un-concurrent safe. func (m *AnyAnyMap) GetVarOrSetFuncLock(key interface{}, f func() interface{}) *gvar.Var { return gvar.New(m.GetOrSetFuncLock(key, f)) } @@ -293,21 +309,25 @@ func (m *AnyAnyMap) SetIfNotExistFuncLock(key interface{}, f func() interface{}) } // Remove deletes value from map by given , and return this deleted value. -func (m *AnyAnyMap) Remove(key interface{}) interface{} { +func (m *AnyAnyMap) Remove(key interface{}) (value interface{}) { m.mu.Lock() - val, exists := m.data[key] - if exists { - delete(m.data, key) + if m.data != nil { + var ok bool + if value, ok = m.data[key]; ok { + delete(m.data, key) + } } m.mu.Unlock() - return val + return } // Removes batch deletes values of the map by keys. func (m *AnyAnyMap) Removes(keys []interface{}) { m.mu.Lock() - for _, key := range keys { - delete(m.data, key) + if m.data != nil { + for _, key := range keys { + delete(m.data, key) + } } m.mu.Unlock() } @@ -315,36 +335,43 @@ func (m *AnyAnyMap) Removes(keys []interface{}) { // Keys returns all keys of the map as a slice. func (m *AnyAnyMap) Keys() []interface{} { m.mu.RLock() - keys := make([]interface{}, len(m.data)) - index := 0 + defer m.mu.RUnlock() + var ( + keys = make([]interface{}, len(m.data)) + index = 0 + ) for key := range m.data { keys[index] = key index++ } - m.mu.RUnlock() return keys } // Values returns all values of the map as a slice. func (m *AnyAnyMap) Values() []interface{} { m.mu.RLock() - values := make([]interface{}, len(m.data)) - index := 0 + defer m.mu.RUnlock() + var ( + values = make([]interface{}, len(m.data)) + index = 0 + ) for _, value := range m.data { values[index] = value index++ } - m.mu.RUnlock() return values } // Contains checks whether a key exists. // It returns true if the exists, or else false. func (m *AnyAnyMap) Contains(key interface{}) bool { + var ok bool m.mu.RLock() - _, exists := m.data[key] + if m.data != nil { + _, ok = m.data[key] + } m.mu.RUnlock() - return exists + return ok } // Size returns the size of the map. @@ -405,6 +432,10 @@ func (m *AnyAnyMap) Flip() { func (m *AnyAnyMap) Merge(other *AnyAnyMap) { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = other.MapCopy() + return + } if other != m { other.mu.RLock() defer other.mu.RUnlock() @@ -421,12 +452,11 @@ func (m *AnyAnyMap) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (m *AnyAnyMap) UnmarshalJSON(b []byte) error { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[interface{}]interface{}) - } m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[interface{}]interface{}) + } var data map[string]interface{} if err := json.Unmarshal(b, &data); err != nil { return err @@ -439,12 +469,11 @@ func (m *AnyAnyMap) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for map. func (m *AnyAnyMap) UnmarshalValue(value interface{}) (err error) { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[interface{}]interface{}) - } m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[interface{}]interface{}) + } for k, v := range gconv.Map(value) { m.data[k] = v } diff --git a/container/gmap/gmap_hash_int_any_map.go b/container/gmap/gmap_hash_int_any_map.go index 8c3479add..cb72e812c 100644 --- a/container/gmap/gmap_hash_int_any_map.go +++ b/container/gmap/gmap_hash_int_any_map.go @@ -18,7 +18,7 @@ import ( ) type IntAnyMap struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex data map[int]interface{} } @@ -27,7 +27,7 @@ type IntAnyMap struct { // which is false in default. func NewIntAnyMap(safe ...bool) *IntAnyMap { return &IntAnyMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: make(map[int]interface{}), } } @@ -37,12 +37,12 @@ func NewIntAnyMap(safe ...bool) *IntAnyMap { // there might be some concurrent-safe issues when changing the map outside. func NewIntAnyMapFrom(data map[int]interface{}, safe ...bool) *IntAnyMap { return &IntAnyMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: data, } } -// Iterator iterates the hash map with custom callback function . +// Iterator iterates the hash map readonly with custom callback function . // If returns true, then it continues iterating; or false to stop. func (m *IntAnyMap) Iterator(f func(k int, v interface{}) bool) { m.mu.RLock() @@ -111,6 +111,9 @@ func (m *IntAnyMap) FilterEmpty() { // Set sets key-value to the hash map. func (m *IntAnyMap) Set(key int, val interface{}) { m.mu.Lock() + if m.data == nil { + m.data = make(map[int]interface{}) + } m.data[key] = val m.mu.Unlock() } @@ -118,8 +121,12 @@ func (m *IntAnyMap) Set(key int, val interface{}) { // Sets batch sets key-values to the hash map. func (m *IntAnyMap) Sets(data map[int]interface{}) { m.mu.Lock() - for k, v := range data { - m.data[k] = v + if m.data == nil { + m.data = data + } else { + for k, v := range data { + m.data[k] = v + } } m.mu.Unlock() } @@ -128,17 +135,21 @@ func (m *IntAnyMap) Sets(data map[int]interface{}) { // Second return parameter is true if key was found, otherwise false. func (m *IntAnyMap) Search(key int) (value interface{}, found bool) { m.mu.RLock() - value, found = m.data[key] + if m.data != nil { + value, found = m.data[key] + } m.mu.RUnlock() return } // Get returns the value by given . -func (m *IntAnyMap) Get(key int) interface{} { +func (m *IntAnyMap) Get(key int) (value interface{}) { m.mu.RLock() - val, _ := m.data[key] + if m.data != nil { + value, _ = m.data[key] + } m.mu.RUnlock() - return val + return } // Pop retrieves and deletes an item from the map. @@ -163,8 +174,10 @@ func (m *IntAnyMap) Pops(size int) map[int]interface{} { if size == 0 { return nil } - index := 0 - newMap := make(map[int]interface{}, size) + var ( + index = 0 + newMap = make(map[int]interface{}, size) + ) for k, v := range m.data { delete(m.data, k) newMap[k] = v @@ -188,6 +201,9 @@ func (m *IntAnyMap) Pops(size int) map[int]interface{} { func (m *IntAnyMap) doSetWithLockCheck(key int, value interface{}) interface{} { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]interface{}) + } if v, ok := m.data[key]; ok { return v } @@ -233,26 +249,26 @@ func (m *IntAnyMap) GetOrSetFuncLock(key int, f func() interface{}) interface{} } } -// GetVar returns a gvar.Var with the value by given . -// The returned gvar.Var is un-concurrent safe. +// GetVar returns a Var with the value by given . +// The returned Var is un-concurrent safe. func (m *IntAnyMap) GetVar(key int) *gvar.Var { return gvar.New(m.Get(key)) } -// GetVarOrSet returns a gvar.Var with result from GetVarOrSet. -// The returned gvar.Var is un-concurrent safe. +// GetVarOrSet returns a Var with result from GetVarOrSet. +// The returned Var is un-concurrent safe. func (m *IntAnyMap) GetVarOrSet(key int, value interface{}) *gvar.Var { return gvar.New(m.GetOrSet(key, value)) } -// GetVarOrSetFunc returns a gvar.Var with result from GetOrSetFunc. -// The returned gvar.Var is un-concurrent safe. +// GetVarOrSetFunc returns a Var with result from GetOrSetFunc. +// The returned Var is un-concurrent safe. func (m *IntAnyMap) GetVarOrSetFunc(key int, f func() interface{}) *gvar.Var { return gvar.New(m.GetOrSetFunc(key, f)) } -// GetVarOrSetFuncLock returns a gvar.Var with result from GetOrSetFuncLock. -// The returned gvar.Var is un-concurrent safe. +// GetVarOrSetFuncLock returns a Var with result from GetOrSetFuncLock. +// The returned Var is un-concurrent safe. func (m *IntAnyMap) GetVarOrSetFuncLock(key int, f func() interface{}) *gvar.Var { return gvar.New(m.GetOrSetFuncLock(key, f)) } @@ -293,28 +309,34 @@ func (m *IntAnyMap) SetIfNotExistFuncLock(key int, f func() interface{}) bool { // Removes batch deletes values of the map by keys. func (m *IntAnyMap) Removes(keys []int) { m.mu.Lock() - for _, key := range keys { - delete(m.data, key) + if m.data != nil { + for _, key := range keys { + delete(m.data, key) + } } m.mu.Unlock() } // Remove deletes value from map by given , and return this deleted value. -func (m *IntAnyMap) Remove(key int) interface{} { +func (m *IntAnyMap) Remove(key int) (value interface{}) { m.mu.Lock() - val, exists := m.data[key] - if exists { - delete(m.data, key) + if m.data != nil { + var ok bool + if value, ok = m.data[key]; ok { + delete(m.data, key) + } } m.mu.Unlock() - return val + return } // Keys returns all keys of the map as a slice. func (m *IntAnyMap) Keys() []int { m.mu.RLock() - keys := make([]int, len(m.data)) - index := 0 + var ( + keys = make([]int, len(m.data)) + index = 0 + ) for key := range m.data { keys[index] = key index++ @@ -326,8 +348,10 @@ func (m *IntAnyMap) Keys() []int { // Values returns all values of the map as a slice. func (m *IntAnyMap) Values() []interface{} { m.mu.RLock() - values := make([]interface{}, len(m.data)) - index := 0 + var ( + values = make([]interface{}, len(m.data)) + index = 0 + ) for _, value := range m.data { values[index] = value index++ @@ -339,10 +363,13 @@ func (m *IntAnyMap) Values() []interface{} { // Contains checks whether a key exists. // It returns true if the exists, or else false. func (m *IntAnyMap) Contains(key int) bool { + var ok bool m.mu.RLock() - _, exists := m.data[key] + if m.data != nil { + _, ok = m.data[key] + } m.mu.RUnlock() - return exists + return ok } // Size returns the size of the map. @@ -403,6 +430,10 @@ func (m *IntAnyMap) Flip() { func (m *IntAnyMap) Merge(other *IntAnyMap) { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = other.MapCopy() + return + } if other != m { other.mu.RLock() defer other.mu.RUnlock() @@ -421,12 +452,11 @@ func (m *IntAnyMap) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (m *IntAnyMap) UnmarshalJSON(b []byte) error { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[int]interface{}) - } m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]interface{}) + } if err := json.Unmarshal(b, &m.data); err != nil { return err } @@ -435,12 +465,11 @@ func (m *IntAnyMap) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for map. func (m *IntAnyMap) UnmarshalValue(value interface{}) (err error) { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[int]interface{}) - } m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]interface{}) + } switch value.(type) { case string, []byte: return json.Unmarshal(gconv.Bytes(value), &m.data) diff --git a/container/gmap/gmap_hash_int_int_map.go b/container/gmap/gmap_hash_int_int_map.go index 826ac3d59..eb6449831 100644 --- a/container/gmap/gmap_hash_int_int_map.go +++ b/container/gmap/gmap_hash_int_int_map.go @@ -16,7 +16,7 @@ import ( ) type IntIntMap struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex data map[int]int } @@ -25,7 +25,7 @@ type IntIntMap struct { // which is false in default. func NewIntIntMap(safe ...bool) *IntIntMap { return &IntIntMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: make(map[int]int), } } @@ -35,12 +35,12 @@ func NewIntIntMap(safe ...bool) *IntIntMap { // there might be some concurrent-safe issues when changing the map outside. func NewIntIntMapFrom(data map[int]int, safe ...bool) *IntIntMap { return &IntIntMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: data, } } -// Iterator iterates the hash map with custom callback function . +// Iterator iterates the hash map readonly with custom callback function . // If returns true, then it continues iterating; or false to stop. func (m *IntIntMap) Iterator(f func(k int, v int) bool) { m.mu.RLock() @@ -109,6 +109,9 @@ func (m *IntIntMap) FilterEmpty() { // Set sets key-value to the hash map. func (m *IntIntMap) Set(key int, val int) { m.mu.Lock() + if m.data == nil { + m.data = make(map[int]int) + } m.data[key] = val m.mu.Unlock() } @@ -116,8 +119,12 @@ func (m *IntIntMap) Set(key int, val int) { // Sets batch sets key-values to the hash map. func (m *IntIntMap) Sets(data map[int]int) { m.mu.Lock() - for k, v := range data { - m.data[k] = v + if m.data == nil { + m.data = data + } else { + for k, v := range data { + m.data[k] = v + } } m.mu.Unlock() } @@ -126,17 +133,21 @@ func (m *IntIntMap) Sets(data map[int]int) { // Second return parameter is true if key was found, otherwise false. func (m *IntIntMap) Search(key int) (value int, found bool) { m.mu.RLock() - value, found = m.data[key] + if m.data != nil { + value, found = m.data[key] + } m.mu.RUnlock() return } // Get returns the value by given . -func (m *IntIntMap) Get(key int) int { +func (m *IntIntMap) Get(key int) (value int) { m.mu.RLock() - val, _ := m.data[key] + if m.data != nil { + value, _ = m.data[key] + } m.mu.RUnlock() - return val + return } // Pop retrieves and deletes an item from the map. @@ -161,8 +172,10 @@ func (m *IntIntMap) Pops(size int) map[int]int { if size == 0 { return nil } - index := 0 - newMap := make(map[int]int, size) + var ( + index = 0 + newMap = make(map[int]int, size) + ) for k, v := range m.data { delete(m.data, k) newMap[k] = v @@ -181,12 +194,14 @@ func (m *IntIntMap) Pops(size int) map[int]int { // It returns value with given . func (m *IntIntMap) doSetWithLockCheck(key int, value int) int { m.mu.Lock() + defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]int) + } if v, ok := m.data[key]; ok { - m.mu.Unlock() return v } m.data[key] = value - m.mu.Unlock() return value } @@ -219,6 +234,9 @@ func (m *IntIntMap) GetOrSetFuncLock(key int, f func() int) int { if v, ok := m.Search(key); !ok { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]int) + } if v, ok = m.data[key]; ok { return v } @@ -259,6 +277,9 @@ func (m *IntIntMap) SetIfNotExistFuncLock(key int, f func() int) bool { if !m.Contains(key) { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]int) + } if _, ok := m.data[key]; !ok { m.data[key] = f() } @@ -270,28 +291,34 @@ func (m *IntIntMap) SetIfNotExistFuncLock(key int, f func() int) bool { // Removes batch deletes values of the map by keys. func (m *IntIntMap) Removes(keys []int) { m.mu.Lock() - for _, key := range keys { - delete(m.data, key) + if m.data != nil { + for _, key := range keys { + delete(m.data, key) + } } m.mu.Unlock() } // Remove deletes value from map by given , and return this deleted value. -func (m *IntIntMap) Remove(key int) int { +func (m *IntIntMap) Remove(key int) (value int) { m.mu.Lock() - val, exists := m.data[key] - if exists { - delete(m.data, key) + if m.data != nil { + var ok bool + if value, ok = m.data[key]; ok { + delete(m.data, key) + } } m.mu.Unlock() - return val + return } // Keys returns all keys of the map as a slice. func (m *IntIntMap) Keys() []int { m.mu.RLock() - keys := make([]int, len(m.data)) - index := 0 + var ( + keys = make([]int, len(m.data)) + index = 0 + ) for key := range m.data { keys[index] = key index++ @@ -303,8 +330,10 @@ func (m *IntIntMap) Keys() []int { // Values returns all values of the map as a slice. func (m *IntIntMap) Values() []int { m.mu.RLock() - values := make([]int, len(m.data)) - index := 0 + var ( + values = make([]int, len(m.data)) + index = 0 + ) for _, value := range m.data { values[index] = value index++ @@ -316,10 +345,13 @@ func (m *IntIntMap) Values() []int { // Contains checks whether a key exists. // It returns true if the exists, or else false. func (m *IntIntMap) Contains(key int) bool { + var ok bool m.mu.RLock() - _, exists := m.data[key] + if m.data != nil { + _, ok = m.data[key] + } m.mu.RUnlock() - return exists + return ok } // Size returns the size of the map. @@ -380,6 +412,10 @@ func (m *IntIntMap) Flip() { func (m *IntIntMap) Merge(other *IntIntMap) { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = other.MapCopy() + return + } if other != m { other.mu.RLock() defer other.mu.RUnlock() @@ -398,12 +434,11 @@ func (m *IntIntMap) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (m *IntIntMap) UnmarshalJSON(b []byte) error { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[int]int) - } m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]int) + } if err := json.Unmarshal(b, &m.data); err != nil { return err } @@ -412,12 +447,11 @@ func (m *IntIntMap) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for map. func (m *IntIntMap) UnmarshalValue(value interface{}) (err error) { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[int]int) - } m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]int) + } switch value.(type) { case string, []byte: return json.Unmarshal(gconv.Bytes(value), &m.data) diff --git a/container/gmap/gmap_hash_int_str_map.go b/container/gmap/gmap_hash_int_str_map.go index 5db6b23f8..842d288e8 100644 --- a/container/gmap/gmap_hash_int_str_map.go +++ b/container/gmap/gmap_hash_int_str_map.go @@ -16,7 +16,7 @@ import ( ) type IntStrMap struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex data map[int]string } @@ -25,7 +25,7 @@ type IntStrMap struct { // which is false in default. func NewIntStrMap(safe ...bool) *IntStrMap { return &IntStrMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: make(map[int]string), } } @@ -35,12 +35,12 @@ func NewIntStrMap(safe ...bool) *IntStrMap { // there might be some concurrent-safe issues when changing the map outside. func NewIntStrMapFrom(data map[int]string, safe ...bool) *IntStrMap { return &IntStrMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: data, } } -// Iterator iterates the hash map with custom callback function . +// Iterator iterates the hash map readonly with custom callback function . // If returns true, then it continues iterating; or false to stop. func (m *IntStrMap) Iterator(f func(k int, v string) bool) { m.mu.RLock() @@ -109,6 +109,9 @@ func (m *IntStrMap) FilterEmpty() { // Set sets key-value to the hash map. func (m *IntStrMap) Set(key int, val string) { m.mu.Lock() + if m.data == nil { + m.data = make(map[int]string) + } m.data[key] = val m.mu.Unlock() } @@ -116,8 +119,12 @@ func (m *IntStrMap) Set(key int, val string) { // Sets batch sets key-values to the hash map. func (m *IntStrMap) Sets(data map[int]string) { m.mu.Lock() - for k, v := range data { - m.data[k] = v + if m.data == nil { + m.data = data + } else { + for k, v := range data { + m.data[k] = v + } } m.mu.Unlock() } @@ -126,17 +133,21 @@ func (m *IntStrMap) Sets(data map[int]string) { // Second return parameter is true if key was found, otherwise false. func (m *IntStrMap) Search(key int) (value string, found bool) { m.mu.RLock() - value, found = m.data[key] + if m.data != nil { + value, found = m.data[key] + } m.mu.RUnlock() return } // Get returns the value by given . -func (m *IntStrMap) Get(key int) string { +func (m *IntStrMap) Get(key int) (value string) { m.mu.RLock() - val, _ := m.data[key] + if m.data != nil { + value, _ = m.data[key] + } m.mu.RUnlock() - return val + return } // Pop retrieves and deletes an item from the map. @@ -161,8 +172,10 @@ func (m *IntStrMap) Pops(size int) map[int]string { if size == 0 { return nil } - index := 0 - newMap := make(map[int]string, size) + var ( + index = 0 + newMap = make(map[int]string, size) + ) for k, v := range m.data { delete(m.data, k) newMap[k] = v @@ -182,6 +195,9 @@ func (m *IntStrMap) Pops(size int) map[int]string { func (m *IntStrMap) doSetWithLockCheck(key int, value string) string { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]string) + } if v, ok := m.data[key]; ok { return v } @@ -218,13 +234,14 @@ func (m *IntStrMap) GetOrSetFuncLock(key int, f func() string) string { if v, ok := m.Search(key); !ok { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]string) + } if v, ok = m.data[key]; ok { return v } v = f() - if v != "" { - m.data[key] = v - } + m.data[key] = v return v } else { return v @@ -260,6 +277,9 @@ func (m *IntStrMap) SetIfNotExistFuncLock(key int, f func() string) bool { if !m.Contains(key) { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]string) + } if _, ok := m.data[key]; !ok { m.data[key] = f() } @@ -271,28 +291,34 @@ func (m *IntStrMap) SetIfNotExistFuncLock(key int, f func() string) bool { // Removes batch deletes values of the map by keys. func (m *IntStrMap) Removes(keys []int) { m.mu.Lock() - for _, key := range keys { - delete(m.data, key) + if m.data != nil { + for _, key := range keys { + delete(m.data, key) + } } m.mu.Unlock() } // Remove deletes value from map by given , and return this deleted value. -func (m *IntStrMap) Remove(key int) string { +func (m *IntStrMap) Remove(key int) (value string) { m.mu.Lock() - val, exists := m.data[key] - if exists { - delete(m.data, key) + if m.data != nil { + var ok bool + if value, ok = m.data[key]; ok { + delete(m.data, key) + } } m.mu.Unlock() - return val + return } // Keys returns all keys of the map as a slice. func (m *IntStrMap) Keys() []int { m.mu.RLock() - keys := make([]int, len(m.data)) - index := 0 + var ( + keys = make([]int, len(m.data)) + index = 0 + ) for key := range m.data { keys[index] = key index++ @@ -304,8 +330,10 @@ func (m *IntStrMap) Keys() []int { // Values returns all values of the map as a slice. func (m *IntStrMap) Values() []string { m.mu.RLock() - values := make([]string, len(m.data)) - index := 0 + var ( + values = make([]string, len(m.data)) + index = 0 + ) for _, value := range m.data { values[index] = value index++ @@ -317,10 +345,13 @@ func (m *IntStrMap) Values() []string { // Contains checks whether a key exists. // It returns true if the exists, or else false. func (m *IntStrMap) Contains(key int) bool { + var ok bool m.mu.RLock() - _, exists := m.data[key] + if m.data != nil { + _, ok = m.data[key] + } m.mu.RUnlock() - return exists + return ok } // Size returns the size of the map. @@ -381,6 +412,10 @@ func (m *IntStrMap) Flip() { func (m *IntStrMap) Merge(other *IntStrMap) { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = other.MapCopy() + return + } if other != m { other.mu.RLock() defer other.mu.RUnlock() @@ -399,12 +434,11 @@ func (m *IntStrMap) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (m *IntStrMap) UnmarshalJSON(b []byte) error { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[int]string) - } m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]string) + } if err := json.Unmarshal(b, &m.data); err != nil { return err } @@ -413,12 +447,11 @@ func (m *IntStrMap) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for map. func (m *IntStrMap) UnmarshalValue(value interface{}) (err error) { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[int]string) - } m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[int]string) + } switch value.(type) { case string, []byte: return json.Unmarshal(gconv.Bytes(value), &m.data) diff --git a/container/gmap/gmap_hash_str_any_map.go b/container/gmap/gmap_hash_str_any_map.go index cd7925c08..08b466018 100644 --- a/container/gmap/gmap_hash_str_any_map.go +++ b/container/gmap/gmap_hash_str_any_map.go @@ -18,7 +18,7 @@ import ( ) type StrAnyMap struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex data map[string]interface{} } @@ -27,7 +27,7 @@ type StrAnyMap struct { // which is false in default. func NewStrAnyMap(safe ...bool) *StrAnyMap { return &StrAnyMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: make(map[string]interface{}), } } @@ -37,12 +37,12 @@ func NewStrAnyMap(safe ...bool) *StrAnyMap { // there might be some concurrent-safe issues when changing the map outside. func NewStrAnyMapFrom(data map[string]interface{}, safe ...bool) *StrAnyMap { return &StrAnyMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: data, } } -// Iterator iterates the hash map with custom callback function . +// Iterator iterates the hash map readonly with custom callback function . // If returns true, then it continues iterating; or false to stop. func (m *StrAnyMap) Iterator(f func(k string, v interface{}) bool) { m.mu.RLock() @@ -105,6 +105,9 @@ func (m *StrAnyMap) FilterEmpty() { // Set sets key-value to the hash map. func (m *StrAnyMap) Set(key string, val interface{}) { m.mu.Lock() + if m.data == nil { + m.data = make(map[string]interface{}) + } m.data[key] = val m.mu.Unlock() } @@ -112,8 +115,12 @@ func (m *StrAnyMap) Set(key string, val interface{}) { // Sets batch sets key-values to the hash map. func (m *StrAnyMap) Sets(data map[string]interface{}) { m.mu.Lock() - for k, v := range data { - m.data[k] = v + if m.data == nil { + m.data = data + } else { + for k, v := range data { + m.data[k] = v + } } m.mu.Unlock() } @@ -122,17 +129,21 @@ func (m *StrAnyMap) Sets(data map[string]interface{}) { // Second return parameter is true if key was found, otherwise false. func (m *StrAnyMap) Search(key string) (value interface{}, found bool) { m.mu.RLock() - value, found = m.data[key] + if m.data != nil { + value, found = m.data[key] + } m.mu.RUnlock() return } // Get returns the value by given . -func (m *StrAnyMap) Get(key string) interface{} { +func (m *StrAnyMap) Get(key string) (value interface{}) { m.mu.RLock() - val, _ := m.data[key] + if m.data != nil { + value, _ = m.data[key] + } m.mu.RUnlock() - return val + return } // Pop retrieves and deletes an item from the map. @@ -157,8 +168,10 @@ func (m *StrAnyMap) Pops(size int) map[string]interface{} { if size == 0 { return nil } - index := 0 - newMap := make(map[string]interface{}, size) + var ( + index = 0 + newMap = make(map[string]interface{}, size) + ) for k, v := range m.data { delete(m.data, k) newMap[k] = v @@ -182,6 +195,9 @@ func (m *StrAnyMap) Pops(size int) map[string]interface{} { func (m *StrAnyMap) doSetWithLockCheck(key string, value interface{}) interface{} { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string]interface{}) + } if v, ok := m.data[key]; ok { return v } @@ -229,26 +245,26 @@ func (m *StrAnyMap) GetOrSetFuncLock(key string, f func() interface{}) interface } } -// GetVar returns a gvar.Var with the value by given . -// The returned gvar.Var is un-concurrent safe. +// GetVar returns a Var with the value by given . +// The returned Var is un-concurrent safe. func (m *StrAnyMap) GetVar(key string) *gvar.Var { return gvar.New(m.Get(key)) } -// GetVarOrSet returns a gvar.Var with result from GetVarOrSet. -// The returned gvar.Var is un-concurrent safe. +// GetVarOrSet returns a Var with result from GetVarOrSet. +// The returned Var is un-concurrent safe. func (m *StrAnyMap) GetVarOrSet(key string, value interface{}) *gvar.Var { return gvar.New(m.GetOrSet(key, value)) } -// GetVarOrSetFunc returns a gvar.Var with result from GetOrSetFunc. -// The returned gvar.Var is un-concurrent safe. +// GetVarOrSetFunc returns a Var with result from GetOrSetFunc. +// The returned Var is un-concurrent safe. func (m *StrAnyMap) GetVarOrSetFunc(key string, f func() interface{}) *gvar.Var { return gvar.New(m.GetOrSetFunc(key, f)) } -// GetVarOrSetFuncLock returns a gvar.Var with result from GetOrSetFuncLock. -// The returned gvar.Var is un-concurrent safe. +// GetVarOrSetFuncLock returns a Var with result from GetOrSetFuncLock. +// The returned Var is un-concurrent safe. func (m *StrAnyMap) GetVarOrSetFuncLock(key string, f func() interface{}) *gvar.Var { return gvar.New(m.GetOrSetFuncLock(key, f)) } @@ -289,28 +305,34 @@ func (m *StrAnyMap) SetIfNotExistFuncLock(key string, f func() interface{}) bool // Removes batch deletes values of the map by keys. func (m *StrAnyMap) Removes(keys []string) { m.mu.Lock() - for _, key := range keys { - delete(m.data, key) + if m.data != nil { + for _, key := range keys { + delete(m.data, key) + } } m.mu.Unlock() } // Remove deletes value from map by given , and return this deleted value. -func (m *StrAnyMap) Remove(key string) interface{} { +func (m *StrAnyMap) Remove(key string) (value interface{}) { m.mu.Lock() - val, exists := m.data[key] - if exists { - delete(m.data, key) + if m.data != nil { + var ok bool + if value, ok = m.data[key]; ok { + delete(m.data, key) + } } m.mu.Unlock() - return val + return } // Keys returns all keys of the map as a slice. func (m *StrAnyMap) Keys() []string { m.mu.RLock() - keys := make([]string, len(m.data)) - index := 0 + var ( + keys = make([]string, len(m.data)) + index = 0 + ) for key := range m.data { keys[index] = key index++ @@ -322,8 +344,10 @@ func (m *StrAnyMap) Keys() []string { // Values returns all values of the map as a slice. func (m *StrAnyMap) Values() []interface{} { m.mu.RLock() - values := make([]interface{}, len(m.data)) - index := 0 + var ( + values = make([]interface{}, len(m.data)) + index = 0 + ) for _, value := range m.data { values[index] = value index++ @@ -335,10 +359,13 @@ func (m *StrAnyMap) Values() []interface{} { // Contains checks whether a key exists. // It returns true if the exists, or else false. func (m *StrAnyMap) Contains(key string) bool { + var ok bool m.mu.RLock() - _, exists := m.data[key] + if m.data != nil { + _, ok = m.data[key] + } m.mu.RUnlock() - return exists + return ok } // Size returns the size of the map. @@ -399,6 +426,10 @@ func (m *StrAnyMap) Flip() { func (m *StrAnyMap) Merge(other *StrAnyMap) { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = other.MapCopy() + return + } if other != m { other.mu.RLock() defer other.mu.RUnlock() @@ -417,12 +448,11 @@ func (m *StrAnyMap) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (m *StrAnyMap) UnmarshalJSON(b []byte) error { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[string]interface{}) - } m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string]interface{}) + } if err := json.Unmarshal(b, &m.data); err != nil { return err } @@ -431,10 +461,6 @@ func (m *StrAnyMap) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for map. func (m *StrAnyMap) UnmarshalValue(value interface{}) (err error) { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[string]interface{}) - } m.mu.Lock() defer m.mu.Unlock() m.data = gconv.Map(value) diff --git a/container/gmap/gmap_hash_str_int_map.go b/container/gmap/gmap_hash_str_int_map.go index 2f98237c3..df19a1296 100644 --- a/container/gmap/gmap_hash_str_int_map.go +++ b/container/gmap/gmap_hash_str_int_map.go @@ -16,7 +16,7 @@ import ( ) type StrIntMap struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex data map[string]int } @@ -25,7 +25,7 @@ type StrIntMap struct { // which is false in default. func NewStrIntMap(safe ...bool) *StrIntMap { return &StrIntMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: make(map[string]int), } } @@ -35,12 +35,12 @@ func NewStrIntMap(safe ...bool) *StrIntMap { // there might be some concurrent-safe issues when changing the map outside. func NewStrIntMapFrom(data map[string]int, safe ...bool) *StrIntMap { return &StrIntMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: data, } } -// Iterator iterates the hash map with custom callback function . +// Iterator iterates the hash map readonly with custom callback function . // If returns true, then it continues iterating; or false to stop. func (m *StrIntMap) Iterator(f func(k string, v int) bool) { m.mu.RLock() @@ -76,11 +76,11 @@ func (m *StrIntMap) Map() map[string]int { // MapStrAny returns a copy of the underlying data of the map as map[string]interface{}. func (m *StrIntMap) MapStrAny() map[string]interface{} { m.mu.RLock() + defer m.mu.RUnlock() data := make(map[string]interface{}, len(m.data)) for k, v := range m.data { data[k] = v } - m.mu.RUnlock() return data } @@ -109,6 +109,9 @@ func (m *StrIntMap) FilterEmpty() { // Set sets key-value to the hash map. func (m *StrIntMap) Set(key string, val int) { m.mu.Lock() + if m.data == nil { + m.data = make(map[string]int) + } m.data[key] = val m.mu.Unlock() } @@ -116,8 +119,12 @@ func (m *StrIntMap) Set(key string, val int) { // Sets batch sets key-values to the hash map. func (m *StrIntMap) Sets(data map[string]int) { m.mu.Lock() - for k, v := range data { - m.data[k] = v + if m.data == nil { + m.data = data + } else { + for k, v := range data { + m.data[k] = v + } } m.mu.Unlock() } @@ -126,17 +133,21 @@ func (m *StrIntMap) Sets(data map[string]int) { // Second return parameter is true if key was found, otherwise false. func (m *StrIntMap) Search(key string) (value int, found bool) { m.mu.RLock() - value, found = m.data[key] + if m.data != nil { + value, found = m.data[key] + } m.mu.RUnlock() return } // Get returns the value by given . -func (m *StrIntMap) Get(key string) int { +func (m *StrIntMap) Get(key string) (value int) { m.mu.RLock() - val, _ := m.data[key] + if m.data != nil { + value, _ = m.data[key] + } m.mu.RUnlock() - return val + return } // Pop retrieves and deletes an item from the map. @@ -161,8 +172,10 @@ func (m *StrIntMap) Pops(size int) map[string]int { if size == 0 { return nil } - index := 0 - newMap := make(map[string]int, size) + var ( + index = 0 + newMap = make(map[string]int, size) + ) for k, v := range m.data { delete(m.data, k) newMap[k] = v @@ -181,6 +194,9 @@ func (m *StrIntMap) Pops(size int) map[string]int { // It returns value with given . func (m *StrIntMap) doSetWithLockCheck(key string, value int) int { m.mu.Lock() + if m.data == nil { + m.data = make(map[string]int) + } if v, ok := m.data[key]; ok { m.mu.Unlock() return v @@ -221,6 +237,9 @@ func (m *StrIntMap) GetOrSetFuncLock(key string, f func() int) int { if v, ok := m.Search(key); !ok { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string]int) + } if v, ok = m.data[key]; ok { return v } @@ -261,6 +280,9 @@ func (m *StrIntMap) SetIfNotExistFuncLock(key string, f func() int) bool { if !m.Contains(key) { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string]int) + } if _, ok := m.data[key]; !ok { m.data[key] = f() } @@ -272,28 +294,34 @@ func (m *StrIntMap) SetIfNotExistFuncLock(key string, f func() int) bool { // Removes batch deletes values of the map by keys. func (m *StrIntMap) Removes(keys []string) { m.mu.Lock() - for _, key := range keys { - delete(m.data, key) + if m.data != nil { + for _, key := range keys { + delete(m.data, key) + } } m.mu.Unlock() } // Remove deletes value from map by given , and return this deleted value. -func (m *StrIntMap) Remove(key string) int { +func (m *StrIntMap) Remove(key string) (value int) { m.mu.Lock() - val, exists := m.data[key] - if exists { - delete(m.data, key) + if m.data != nil { + var ok bool + if value, ok = m.data[key]; ok { + delete(m.data, key) + } } m.mu.Unlock() - return val + return } // Keys returns all keys of the map as a slice. func (m *StrIntMap) Keys() []string { m.mu.RLock() - keys := make([]string, len(m.data)) - index := 0 + var ( + keys = make([]string, len(m.data)) + index = 0 + ) for key := range m.data { keys[index] = key index++ @@ -305,8 +333,10 @@ func (m *StrIntMap) Keys() []string { // Values returns all values of the map as a slice. func (m *StrIntMap) Values() []int { m.mu.RLock() - values := make([]int, len(m.data)) - index := 0 + var ( + values = make([]int, len(m.data)) + index = 0 + ) for _, value := range m.data { values[index] = value index++ @@ -318,10 +348,13 @@ func (m *StrIntMap) Values() []int { // Contains checks whether a key exists. // It returns true if the exists, or else false. func (m *StrIntMap) Contains(key string) bool { + var ok bool m.mu.RLock() - _, exists := m.data[key] + if m.data != nil { + _, ok = m.data[key] + } m.mu.RUnlock() - return exists + return ok } // Size returns the size of the map. @@ -382,6 +415,10 @@ func (m *StrIntMap) Flip() { func (m *StrIntMap) Merge(other *StrIntMap) { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = other.MapCopy() + return + } if other != m { other.mu.RLock() defer other.mu.RUnlock() @@ -400,12 +437,11 @@ func (m *StrIntMap) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (m *StrIntMap) UnmarshalJSON(b []byte) error { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[string]int) - } m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string]int) + } if err := json.Unmarshal(b, &m.data); err != nil { return err } @@ -414,12 +450,11 @@ func (m *StrIntMap) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for map. func (m *StrIntMap) UnmarshalValue(value interface{}) (err error) { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[string]int) - } m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string]int) + } switch value.(type) { case string, []byte: return json.Unmarshal(gconv.Bytes(value), &m.data) diff --git a/container/gmap/gmap_hash_str_str_map.go b/container/gmap/gmap_hash_str_str_map.go index 669b48033..c1ed1de01 100644 --- a/container/gmap/gmap_hash_str_str_map.go +++ b/container/gmap/gmap_hash_str_str_map.go @@ -17,7 +17,7 @@ import ( ) type StrStrMap struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex data map[string]string } @@ -27,7 +27,7 @@ type StrStrMap struct { func NewStrStrMap(safe ...bool) *StrStrMap { return &StrStrMap{ data: make(map[string]string), - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), } } @@ -36,12 +36,12 @@ func NewStrStrMap(safe ...bool) *StrStrMap { // there might be some concurrent-safe issues when changing the map outside. func NewStrStrMapFrom(data map[string]string, safe ...bool) *StrStrMap { return &StrStrMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: data, } } -// Iterator iterates the hash map with custom callback function . +// Iterator iterates the hash map readonly with custom callback function . // If returns true, then it continues iterating; or false to stop. func (m *StrStrMap) Iterator(f func(k string, v string) bool) { m.mu.RLock() @@ -110,6 +110,9 @@ func (m *StrStrMap) FilterEmpty() { // Set sets key-value to the hash map. func (m *StrStrMap) Set(key string, val string) { m.mu.Lock() + if m.data == nil { + m.data = make(map[string]string) + } m.data[key] = val m.mu.Unlock() } @@ -117,8 +120,12 @@ func (m *StrStrMap) Set(key string, val string) { // Sets batch sets key-values to the hash map. func (m *StrStrMap) Sets(data map[string]string) { m.mu.Lock() - for k, v := range data { - m.data[k] = v + if m.data == nil { + m.data = data + } else { + for k, v := range data { + m.data[k] = v + } } m.mu.Unlock() } @@ -127,17 +134,21 @@ func (m *StrStrMap) Sets(data map[string]string) { // Second return parameter is true if key was found, otherwise false. func (m *StrStrMap) Search(key string) (value string, found bool) { m.mu.RLock() - value, found = m.data[key] + if m.data != nil { + value, found = m.data[key] + } m.mu.RUnlock() return } // Get returns the value by given . -func (m *StrStrMap) Get(key string) string { +func (m *StrStrMap) Get(key string) (value string) { m.mu.RLock() - val, _ := m.data[key] + if m.data != nil { + value, _ = m.data[key] + } m.mu.RUnlock() - return val + return } // Pop retrieves and deletes an item from the map. @@ -162,8 +173,10 @@ func (m *StrStrMap) Pops(size int) map[string]string { if size == 0 { return nil } - index := 0 - newMap := make(map[string]string, size) + var ( + index = 0 + newMap = make(map[string]string, size) + ) for k, v := range m.data { delete(m.data, k) newMap[k] = v @@ -183,6 +196,9 @@ func (m *StrStrMap) Pops(size int) map[string]string { func (m *StrStrMap) doSetWithLockCheck(key string, value string) string { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string]string) + } if v, ok := m.data[key]; ok { return v } @@ -221,13 +237,14 @@ func (m *StrStrMap) GetOrSetFuncLock(key string, f func() string) string { if v, ok := m.Search(key); !ok { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string]string) + } if v, ok = m.data[key]; ok { return v } v = f() - if v != "" { - m.data[key] = v - } + m.data[key] = v return v } else { return v @@ -263,6 +280,9 @@ func (m *StrStrMap) SetIfNotExistFuncLock(key string, f func() string) bool { if !m.Contains(key) { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string]string) + } if _, ok := m.data[key]; !ok { m.data[key] = f() } @@ -274,28 +294,34 @@ func (m *StrStrMap) SetIfNotExistFuncLock(key string, f func() string) bool { // Removes batch deletes values of the map by keys. func (m *StrStrMap) Removes(keys []string) { m.mu.Lock() - for _, key := range keys { - delete(m.data, key) + if m.data != nil { + for _, key := range keys { + delete(m.data, key) + } } m.mu.Unlock() } // Remove deletes value from map by given , and return this deleted value. -func (m *StrStrMap) Remove(key string) string { +func (m *StrStrMap) Remove(key string) (value string) { m.mu.Lock() - val, exists := m.data[key] - if exists { - delete(m.data, key) + if m.data != nil { + var ok bool + if value, ok = m.data[key]; ok { + delete(m.data, key) + } } m.mu.Unlock() - return val + return } // Keys returns all keys of the map as a slice. func (m *StrStrMap) Keys() []string { m.mu.RLock() - keys := make([]string, len(m.data)) - index := 0 + var ( + keys = make([]string, len(m.data)) + index = 0 + ) for key := range m.data { keys[index] = key index++ @@ -307,8 +333,10 @@ func (m *StrStrMap) Keys() []string { // Values returns all values of the map as a slice. func (m *StrStrMap) Values() []string { m.mu.RLock() - values := make([]string, len(m.data)) - index := 0 + var ( + values = make([]string, len(m.data)) + index = 0 + ) for _, value := range m.data { values[index] = value index++ @@ -320,10 +348,13 @@ func (m *StrStrMap) Values() []string { // Contains checks whether a key exists. // It returns true if the exists, or else false. func (m *StrStrMap) Contains(key string) bool { + var ok bool m.mu.RLock() - _, exists := m.data[key] + if m.data != nil { + _, ok = m.data[key] + } m.mu.RUnlock() - return exists + return ok } // Size returns the size of the map. @@ -384,6 +415,10 @@ func (m *StrStrMap) Flip() { func (m *StrStrMap) Merge(other *StrStrMap) { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = other.MapCopy() + return + } if other != m { other.mu.RLock() defer other.mu.RUnlock() @@ -402,12 +437,11 @@ func (m *StrStrMap) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (m *StrStrMap) UnmarshalJSON(b []byte) error { - if m.mu == nil { - m.mu = rwmutex.New() - m.data = make(map[string]string) - } m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[string]string) + } if err := json.Unmarshal(b, &m.data); err != nil { return err } @@ -416,9 +450,6 @@ func (m *StrStrMap) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for map. func (m *StrStrMap) UnmarshalValue(value interface{}) (err error) { - if m.mu == nil { - m.mu = rwmutex.New() - } m.mu.Lock() defer m.mu.Unlock() m.data = gconv.MapStrStr(value) diff --git a/container/gmap/gmap_list_map.go b/container/gmap/gmap_list_map.go index 349a50f56..9a6e97ae2 100644 --- a/container/gmap/gmap_list_map.go +++ b/container/gmap/gmap_list_map.go @@ -19,7 +19,7 @@ import ( ) type ListMap struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex data map[interface{}]*glist.Element list *glist.List } @@ -35,7 +35,7 @@ type gListMapNode struct { // which is false in default. func NewListMap(safe ...bool) *ListMap { return &ListMap{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: make(map[interface{}]*glist.Element), list: glist.New(), } @@ -55,28 +55,32 @@ func (m *ListMap) Iterator(f func(key, value interface{}) bool) { m.IteratorAsc(f) } -// IteratorAsc iterates the map in ascending order with given callback function . +// IteratorAsc iterates the map readonly in ascending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (m *ListMap) IteratorAsc(f func(key interface{}, value interface{}) bool) { m.mu.RLock() defer m.mu.RUnlock() - node := (*gListMapNode)(nil) - m.list.IteratorAsc(func(e *glist.Element) bool { - node = e.Value.(*gListMapNode) - return f(node.key, node.value) - }) + if m.list != nil { + node := (*gListMapNode)(nil) + m.list.IteratorAsc(func(e *glist.Element) bool { + node = e.Value.(*gListMapNode) + return f(node.key, node.value) + }) + } } -// IteratorDesc iterates the map in descending order with given callback function . +// IteratorDesc iterates the map readonly in descending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (m *ListMap) IteratorDesc(f func(key interface{}, value interface{}) bool) { m.mu.RLock() defer m.mu.RUnlock() - node := (*gListMapNode)(nil) - m.list.IteratorDesc(func(e *glist.Element) bool { - node = e.Value.(*gListMapNode) - return f(node.key, node.value) - }) + if m.list != nil { + node := (*gListMapNode)(nil) + m.list.IteratorDesc(func(e *glist.Element) bool { + node = e.Value.(*gListMapNode) + return f(node.key, node.value) + }) + } } // Clone returns a new link map with copy of current map data. @@ -110,13 +114,16 @@ func (m *ListMap) Replace(data map[interface{}]interface{}) { // Map returns a copy of the underlying data of the map. func (m *ListMap) Map() map[interface{}]interface{} { m.mu.RLock() - node := (*gListMapNode)(nil) - data := make(map[interface{}]interface{}, len(m.data)) - m.list.IteratorAsc(func(e *glist.Element) bool { - node = e.Value.(*gListMapNode) - data[node.key] = node.value - return true - }) + var node *gListMapNode + var data map[interface{}]interface{} + if m.list != nil { + data = make(map[interface{}]interface{}, len(m.data)) + m.list.IteratorAsc(func(e *glist.Element) bool { + node = e.Value.(*gListMapNode) + data[node.key] = node.value + return true + }) + } m.mu.RUnlock() return data } @@ -124,13 +131,16 @@ func (m *ListMap) Map() map[interface{}]interface{} { // MapStrAny returns a copy of the underlying data of the map as map[string]interface{}. func (m *ListMap) MapStrAny() map[string]interface{} { m.mu.RLock() - node := (*gListMapNode)(nil) - data := make(map[string]interface{}, len(m.data)) - m.list.IteratorAsc(func(e *glist.Element) bool { - node = e.Value.(*gListMapNode) - data[gconv.String(node.key)] = node.value - return true - }) + var node *gListMapNode + var data map[string]interface{} + if m.list != nil { + data = make(map[string]interface{}, len(m.data)) + m.list.IteratorAsc(func(e *glist.Element) bool { + node = e.Value.(*gListMapNode) + data[gconv.String(node.key)] = node.value + return true + }) + } m.mu.RUnlock() return data } @@ -138,20 +148,22 @@ func (m *ListMap) MapStrAny() map[string]interface{} { // FilterEmpty deletes all key-value pair of which the value is empty. func (m *ListMap) FilterEmpty() { m.mu.Lock() - keys := make([]interface{}, 0) - node := (*gListMapNode)(nil) - m.list.IteratorAsc(func(e *glist.Element) bool { - node = e.Value.(*gListMapNode) - if empty.IsEmpty(node.value) { - keys = append(keys, node.key) - } - return true - }) - if len(keys) > 0 { - for _, key := range keys { - if e, ok := m.data[key]; ok { - delete(m.data, key) - m.list.Remove(e) + if m.list != nil { + keys := make([]interface{}, 0) + node := (*gListMapNode)(nil) + m.list.IteratorAsc(func(e *glist.Element) bool { + node = e.Value.(*gListMapNode) + if empty.IsEmpty(node.value) { + keys = append(keys, node.key) + } + return true + }) + if len(keys) > 0 { + for _, key := range keys { + if e, ok := m.data[key]; ok { + delete(m.data, key) + m.list.Remove(e) + } } } } @@ -161,6 +173,10 @@ func (m *ListMap) FilterEmpty() { // Set sets key-value to the map. func (m *ListMap) Set(key interface{}, value interface{}) { m.mu.Lock() + if m.data == nil { + m.data = make(map[interface{}]*glist.Element) + m.list = glist.New() + } if e, ok := m.data[key]; !ok { m.data[key] = m.list.PushBack(&gListMapNode{key, value}) } else { @@ -172,6 +188,10 @@ func (m *ListMap) Set(key interface{}, value interface{}) { // Sets batch sets key-values to the map. func (m *ListMap) Sets(data map[interface{}]interface{}) { m.mu.Lock() + if m.data == nil { + m.data = make(map[interface{}]*glist.Element) + m.list = glist.New() + } for key, value := range data { if e, ok := m.data[key]; !ok { m.data[key] = m.list.PushBack(&gListMapNode{key, value}) @@ -186,9 +206,11 @@ func (m *ListMap) Sets(data map[interface{}]interface{}) { // Second return parameter is true if key was found, otherwise false. func (m *ListMap) Search(key interface{}) (value interface{}, found bool) { m.mu.RLock() - if e, ok := m.data[key]; ok { - value = e.Value.(*gListMapNode).value - found = ok + if m.data != nil { + if e, ok := m.data[key]; ok { + value = e.Value.(*gListMapNode).value + found = ok + } } m.mu.RUnlock() return @@ -197,8 +219,10 @@ func (m *ListMap) Search(key interface{}) (value interface{}, found bool) { // Get returns the value by given . func (m *ListMap) Get(key interface{}) (value interface{}) { m.mu.RLock() - if e, ok := m.data[key]; ok { - value = e.Value.(*gListMapNode).value + if m.data != nil { + if e, ok := m.data[key]; ok { + value = e.Value.(*gListMapNode).value + } } m.mu.RUnlock() return @@ -255,6 +279,10 @@ func (m *ListMap) Pops(size int) map[interface{}]interface{} { func (m *ListMap) doSetWithLockCheck(key interface{}, value interface{}) interface{} { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[interface{}]*glist.Element) + m.list = glist.New() + } if e, ok := m.data[key]; ok { return e.Value.(*gListMapNode).value } @@ -302,26 +330,26 @@ func (m *ListMap) GetOrSetFuncLock(key interface{}, f func() interface{}) interf } } -// GetVar returns a gvar.Var with the value by given . -// The returned gvar.Var is un-concurrent safe. +// GetVar returns a Var with the value by given . +// The returned Var is un-concurrent safe. func (m *ListMap) GetVar(key interface{}) *gvar.Var { return gvar.New(m.Get(key)) } -// GetVarOrSet returns a gvar.Var with result from GetVarOrSet. -// The returned gvar.Var is un-concurrent safe. +// GetVarOrSet returns a Var with result from GetVarOrSet. +// The returned Var is un-concurrent safe. func (m *ListMap) GetVarOrSet(key interface{}, value interface{}) *gvar.Var { return gvar.New(m.GetOrSet(key, value)) } -// GetVarOrSetFunc returns a gvar.Var with result from GetOrSetFunc. -// The returned gvar.Var is un-concurrent safe. +// GetVarOrSetFunc returns a Var with result from GetOrSetFunc. +// The returned Var is un-concurrent safe. func (m *ListMap) GetVarOrSetFunc(key interface{}, f func() interface{}) *gvar.Var { return gvar.New(m.GetOrSetFunc(key, f)) } -// GetVarOrSetFuncLock returns a gvar.Var with result from GetOrSetFuncLock. -// The returned gvar.Var is un-concurrent safe. +// GetVarOrSetFuncLock returns a Var with result from GetOrSetFuncLock. +// The returned Var is un-concurrent safe. func (m *ListMap) GetVarOrSetFuncLock(key interface{}, f func() interface{}) *gvar.Var { return gvar.New(m.GetOrSetFuncLock(key, f)) } @@ -362,10 +390,12 @@ func (m *ListMap) SetIfNotExistFuncLock(key interface{}, f func() interface{}) b // Remove deletes value from map by given , and return this deleted value. func (m *ListMap) Remove(key interface{}) (value interface{}) { m.mu.Lock() - if e, ok := m.data[key]; ok { - value = e.Value.(*gListMapNode).value - delete(m.data, key) - m.list.Remove(e) + if m.data != nil { + if e, ok := m.data[key]; ok { + value = e.Value.(*gListMapNode).value + delete(m.data, key) + m.list.Remove(e) + } } m.mu.Unlock() return @@ -374,10 +404,12 @@ func (m *ListMap) Remove(key interface{}) (value interface{}) { // Removes batch deletes values of the map by keys. func (m *ListMap) Removes(keys []interface{}) { m.mu.Lock() - for _, key := range keys { - if e, ok := m.data[key]; ok { - delete(m.data, key) - m.list.Remove(e) + if m.data != nil { + for _, key := range keys { + if e, ok := m.data[key]; ok { + delete(m.data, key) + m.list.Remove(e) + } } } m.mu.Unlock() @@ -386,13 +418,17 @@ func (m *ListMap) Removes(keys []interface{}) { // Keys returns all keys of the map as a slice in ascending order. func (m *ListMap) Keys() []interface{} { m.mu.RLock() - keys := make([]interface{}, m.list.Len()) - index := 0 - m.list.IteratorAsc(func(e *glist.Element) bool { - keys[index] = e.Value.(*gListMapNode).key - index++ - return true - }) + var ( + keys = make([]interface{}, m.list.Len()) + index = 0 + ) + if m.list != nil { + m.list.IteratorAsc(func(e *glist.Element) bool { + keys[index] = e.Value.(*gListMapNode).key + index++ + return true + }) + } m.mu.RUnlock() return keys } @@ -400,13 +436,17 @@ func (m *ListMap) Keys() []interface{} { // Values returns all values of the map as a slice. func (m *ListMap) Values() []interface{} { m.mu.RLock() - values := make([]interface{}, m.list.Len()) - index := 0 - m.list.IteratorAsc(func(e *glist.Element) bool { - values[index] = e.Value.(*gListMapNode).value - index++ - return true - }) + var ( + values = make([]interface{}, m.list.Len()) + index = 0 + ) + if m.list != nil { + m.list.IteratorAsc(func(e *glist.Element) bool { + values[index] = e.Value.(*gListMapNode).value + index++ + return true + }) + } m.mu.RUnlock() return values } @@ -415,7 +455,9 @@ func (m *ListMap) Values() []interface{} { // It returns true if the exists, or else false. func (m *ListMap) Contains(key interface{}) (ok bool) { m.mu.RLock() - _, ok = m.data[key] + if m.data != nil { + _, ok = m.data[key] + } m.mu.RUnlock() return } @@ -448,6 +490,10 @@ func (m *ListMap) Flip() { func (m *ListMap) Merge(other *ListMap) { m.mu.Lock() defer m.mu.Unlock() + if m.data == nil { + m.data = make(map[interface{}]*glist.Element) + m.list = glist.New() + } if other != m { other.mu.RLock() defer other.mu.RUnlock() @@ -471,13 +517,12 @@ func (m *ListMap) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (m *ListMap) UnmarshalJSON(b []byte) error { - if m.mu == nil { - m.mu = rwmutex.New() + m.mu.Lock() + defer m.mu.Unlock() + if m.data == nil { m.data = make(map[interface{}]*glist.Element) m.list = glist.New() } - m.mu.Lock() - defer m.mu.Unlock() var data map[string]interface{} if err := json.Unmarshal(b, &data); err != nil { return err @@ -494,13 +539,12 @@ func (m *ListMap) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for map. func (m *ListMap) UnmarshalValue(value interface{}) (err error) { - if m.mu == nil { - m.mu = rwmutex.New() + m.mu.Lock() + defer m.mu.Unlock() + if m.data == nil { m.data = make(map[interface{}]*glist.Element) m.list = glist.New() } - m.mu.Lock() - defer m.mu.Unlock() for k, v := range gconv.Map(value) { if e, ok := m.data[k]; !ok { m.data[k] = m.list.PushBack(&gListMapNode{k, v}) diff --git a/container/gmap/gmap_z_basic_test.go b/container/gmap/gmap_z_basic_test.go index 383a3f4f5..534ccc610 100644 --- a/container/gmap/gmap_z_basic_test.go +++ b/container/gmap/gmap_z_basic_test.go @@ -7,6 +7,7 @@ package gmap_test import ( + "github.com/gogf/gf/util/gutil" "testing" "github.com/gogf/gf/container/gmap" @@ -17,6 +18,55 @@ func getValue() interface{} { return 3 } +func Test_Map_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var m gmap.Map + m.Set(1, 11) + t.Assert(m.Get(1), 11) + }) + gtest.C(t, func(t *gtest.T) { + var m gmap.IntAnyMap + m.Set(1, 11) + t.Assert(m.Get(1), 11) + }) + gtest.C(t, func(t *gtest.T) { + var m gmap.IntIntMap + m.Set(1, 11) + t.Assert(m.Get(1), 11) + }) + gtest.C(t, func(t *gtest.T) { + var m gmap.IntStrMap + m.Set(1, "11") + t.Assert(m.Get(1), "11") + }) + gtest.C(t, func(t *gtest.T) { + var m gmap.StrAnyMap + m.Set("1", "11") + t.Assert(m.Get("1"), "11") + }) + gtest.C(t, func(t *gtest.T) { + var m gmap.StrStrMap + m.Set("1", "11") + t.Assert(m.Get("1"), "11") + }) + gtest.C(t, func(t *gtest.T) { + var m gmap.StrIntMap + m.Set("1", 11) + t.Assert(m.Get("1"), 11) + }) + gtest.C(t, func(t *gtest.T) { + var m gmap.ListMap + m.Set("1", 11) + t.Assert(m.Get("1"), 11) + }) + gtest.C(t, func(t *gtest.T) { + var m gmap.TreeMap + m.SetComparator(gutil.ComparatorString) + m.Set("1", 11) + t.Assert(m.Get("1"), 11) + }) +} + func Test_Map_Basic(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.New() diff --git a/container/gmap/gmap_z_unit_any_any_test.go b/container/gmap/gmap_z_unit_any_any_test.go index 37520c728..88fd41daa 100644 --- a/container/gmap/gmap_z_unit_any_any_test.go +++ b/container/gmap/gmap_z_unit_any_any_test.go @@ -17,8 +17,34 @@ import ( "github.com/gogf/gf/test/gtest" ) -func anyAnyCallBack(int, interface{}) bool { - return true +func Test_AnyAnyMap_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var m gmap.AnyAnyMap + m.Set(1, 1) + + t.Assert(m.Get(1), 1) + t.Assert(m.Size(), 1) + t.Assert(m.IsEmpty(), false) + + t.Assert(m.GetOrSet(2, "2"), "2") + t.Assert(m.SetIfNotExist(2, "2"), false) + + t.Assert(m.SetIfNotExist(3, 3), true) + + t.Assert(m.Remove(2), "2") + t.Assert(m.Contains(2), false) + + t.AssertIN(3, m.Keys()) + t.AssertIN(1, m.Keys()) + t.AssertIN(3, m.Values()) + t.AssertIN(1, m.Values()) + m.Flip() + t.Assert(m.Map(), map[interface{}]int{1: 1, 3: 3}) + + m.Clear() + t.Assert(m.Size(), 0) + t.Assert(m.IsEmpty(), true) + }) } func Test_AnyAnyMap_Basic(t *testing.T) { diff --git a/container/gmap/gmap_z_unit_int_any_test.go b/container/gmap/gmap_z_unit_int_any_test.go index 9bf922ed7..125596992 100644 --- a/container/gmap/gmap_z_unit_int_any_test.go +++ b/container/gmap/gmap_z_unit_int_any_test.go @@ -20,9 +20,37 @@ import ( func getAny() interface{} { return 123 } -func intAnyCallBack(int, interface{}) bool { - return true + +func Test_IntAnyMap_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var m gmap.IntAnyMap + m.Set(1, 1) + + t.Assert(m.Get(1), 1) + t.Assert(m.Size(), 1) + t.Assert(m.IsEmpty(), false) + + t.Assert(m.GetOrSet(2, "2"), "2") + t.Assert(m.SetIfNotExist(2, "2"), false) + + t.Assert(m.SetIfNotExist(3, 3), true) + + t.Assert(m.Remove(2), "2") + t.Assert(m.Contains(2), false) + + t.AssertIN(3, m.Keys()) + t.AssertIN(1, m.Keys()) + t.AssertIN(3, m.Values()) + t.AssertIN(1, m.Values()) + m.Flip() + t.Assert(m.Map(), map[interface{}]int{1: 1, 3: 3}) + + m.Clear() + t.Assert(m.Size(), 0) + t.Assert(m.IsEmpty(), true) + }) } + func Test_IntAnyMap_Basic(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewIntAnyMap() @@ -55,6 +83,7 @@ func Test_IntAnyMap_Basic(t *testing.T) { t.Assert(m2.Map(), map[int]interface{}{1: 1, 2: "2"}) }) } + func Test_IntAnyMap_Set_Fun(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewIntAnyMap() diff --git a/container/gmap/gmap_z_unit_int_int_test.go b/container/gmap/gmap_z_unit_int_int_test.go index 5a37ddcc2..ae7092c14 100644 --- a/container/gmap/gmap_z_unit_int_int_test.go +++ b/container/gmap/gmap_z_unit_int_int_test.go @@ -20,9 +20,41 @@ import ( func getInt() int { return 123 } + func intIntCallBack(int, int) bool { return true } + +func Test_IntIntMap_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var m gmap.IntIntMap + m.Set(1, 1) + + t.Assert(m.Get(1), 1) + t.Assert(m.Size(), 1) + t.Assert(m.IsEmpty(), false) + + t.Assert(m.GetOrSet(2, 2), 2) + t.Assert(m.SetIfNotExist(2, 2), false) + + t.Assert(m.SetIfNotExist(3, 3), true) + + t.Assert(m.Remove(2), 2) + t.Assert(m.Contains(2), false) + + t.AssertIN(3, m.Keys()) + t.AssertIN(1, m.Keys()) + t.AssertIN(3, m.Values()) + t.AssertIN(1, m.Values()) + m.Flip() + t.Assert(m.Map(), map[int]int{1: 1, 3: 3}) + + m.Clear() + t.Assert(m.Size(), 0) + t.Assert(m.IsEmpty(), true) + }) +} + func Test_IntIntMap_Basic(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewIntIntMap() @@ -55,6 +87,7 @@ func Test_IntIntMap_Basic(t *testing.T) { t.Assert(m2.Map(), map[int]int{1: 1, 2: 2}) }) } + func Test_IntIntMap_Set_Fun(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewIntIntMap() diff --git a/container/gmap/gmap_z_unit_int_str_test.go b/container/gmap/gmap_z_unit_int_str_test.go index 47622c24b..1312684ea 100644 --- a/container/gmap/gmap_z_unit_int_str_test.go +++ b/container/gmap/gmap_z_unit_int_str_test.go @@ -20,9 +20,40 @@ import ( func getStr() string { return "z" } -func intStrCallBack(int, string) bool { - return true + +func Test_IntStrMap_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var m gmap.IntStrMap + m.Set(1, "a") + + t.Assert(m.Get(1), "a") + t.Assert(m.Size(), 1) + t.Assert(m.IsEmpty(), false) + + t.Assert(m.GetOrSet(2, "b"), "b") + t.Assert(m.SetIfNotExist(2, "b"), false) + + t.Assert(m.SetIfNotExist(3, "c"), true) + + t.Assert(m.Remove(2), "b") + t.Assert(m.Contains(2), false) + + t.AssertIN(3, m.Keys()) + t.AssertIN(1, m.Keys()) + t.AssertIN("a", m.Values()) + t.AssertIN("c", m.Values()) + + m_f := gmap.NewIntStrMap() + m_f.Set(1, "2") + m_f.Flip() + t.Assert(m_f.Map(), map[int]string{2: "1"}) + + m.Clear() + t.Assert(m.Size(), 0) + t.Assert(m.IsEmpty(), true) + }) } + func Test_IntStrMap_Basic(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewIntStrMap() @@ -60,6 +91,7 @@ func Test_IntStrMap_Basic(t *testing.T) { t.Assert(m2.Map(), map[int]string{1: "a", 2: "b"}) }) } + func Test_IntStrMap_Set_Fun(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewIntStrMap() diff --git a/container/gmap/gmap_z_unit_list_map_test.go b/container/gmap/gmap_z_unit_list_map_test.go index ec8feb4c2..564321c1c 100644 --- a/container/gmap/gmap_z_unit_list_map_test.go +++ b/container/gmap/gmap_z_unit_list_map_test.go @@ -17,6 +17,38 @@ import ( "github.com/gogf/gf/test/gtest" ) +func Test_ListMap_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var m gmap.ListMap + m.Set("key1", "val1") + t.Assert(m.Keys(), []interface{}{"key1"}) + + t.Assert(m.Get("key1"), "val1") + t.Assert(m.Size(), 1) + t.Assert(m.IsEmpty(), false) + + t.Assert(m.GetOrSet("key2", "val2"), "val2") + t.Assert(m.SetIfNotExist("key2", "val2"), false) + + t.Assert(m.SetIfNotExist("key3", "val3"), true) + t.Assert(m.Remove("key2"), "val2") + t.Assert(m.Contains("key2"), false) + + t.AssertIN("key3", m.Keys()) + t.AssertIN("key1", m.Keys()) + t.AssertIN("val3", m.Values()) + t.AssertIN("val1", m.Values()) + + m.Flip() + + t.Assert(m.Map(), map[interface{}]interface{}{"val3": "key3", "val1": "key1"}) + + m.Clear() + t.Assert(m.Size(), 0) + t.Assert(m.IsEmpty(), true) + }) +} + func Test_ListMap_Basic(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewListMap() @@ -51,6 +83,7 @@ func Test_ListMap_Basic(t *testing.T) { t.Assert(m2.Map(), map[interface{}]interface{}{1: 1, "key1": "val1"}) }) } + func Test_ListMap_Set_Fun(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewListMap() diff --git a/container/gmap/gmap_z_unit_str_any_test.go b/container/gmap/gmap_z_unit_str_any_test.go index 04965342f..a2159a38c 100644 --- a/container/gmap/gmap_z_unit_str_any_test.go +++ b/container/gmap/gmap_z_unit_str_any_test.go @@ -17,9 +17,37 @@ import ( "github.com/gogf/gf/test/gtest" ) -func stringAnyCallBack(string, interface{}) bool { - return true +func Test_StrAnyMap_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var m gmap.StrAnyMap + m.Set("a", 1) + + t.Assert(m.Get("a"), 1) + t.Assert(m.Size(), 1) + t.Assert(m.IsEmpty(), false) + + t.Assert(m.GetOrSet("b", "2"), "2") + t.Assert(m.SetIfNotExist("b", "2"), false) + + t.Assert(m.SetIfNotExist("c", 3), true) + + t.Assert(m.Remove("b"), "2") + t.Assert(m.Contains("b"), false) + + t.AssertIN("c", m.Keys()) + t.AssertIN("a", m.Keys()) + t.AssertIN(3, m.Values()) + t.AssertIN(1, m.Values()) + + m.Flip() + t.Assert(m.Map(), map[string]interface{}{"1": "a", "3": "c"}) + + m.Clear() + t.Assert(m.Size(), 0) + t.Assert(m.IsEmpty(), true) + }) } + func Test_StrAnyMap_Basic(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewStrAnyMap() @@ -53,6 +81,7 @@ func Test_StrAnyMap_Basic(t *testing.T) { t.Assert(m2.Map(), map[string]interface{}{"a": 1, "b": "2"}) }) } + func Test_StrAnyMap_Set_Fun(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewStrAnyMap() diff --git a/container/gmap/gmap_z_unit_str_int_test.go b/container/gmap/gmap_z_unit_str_int_test.go index b20f23040..22a87e94f 100644 --- a/container/gmap/gmap_z_unit_str_int_test.go +++ b/container/gmap/gmap_z_unit_str_int_test.go @@ -17,9 +17,39 @@ import ( "github.com/gogf/gf/test/gtest" ) -func stringIntCallBack(string, int) bool { - return true +func Test_StrIntMap_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var m gmap.StrIntMap + m.Set("a", 1) + + t.Assert(m.Get("a"), 1) + t.Assert(m.Size(), 1) + t.Assert(m.IsEmpty(), false) + + t.Assert(m.GetOrSet("b", 2), 2) + t.Assert(m.SetIfNotExist("b", 2), false) + + t.Assert(m.SetIfNotExist("c", 3), true) + + t.Assert(m.Remove("b"), 2) + t.Assert(m.Contains("b"), false) + + t.AssertIN("c", m.Keys()) + t.AssertIN("a", m.Keys()) + t.AssertIN(3, m.Values()) + t.AssertIN(1, m.Values()) + + m_f := gmap.NewStrIntMap() + m_f.Set("1", 2) + m_f.Flip() + t.Assert(m_f.Map(), map[string]int{"2": 1}) + + m.Clear() + t.Assert(m.Size(), 0) + t.Assert(m.IsEmpty(), true) + }) } + func Test_StrIntMap_Basic(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewStrIntMap() @@ -55,6 +85,7 @@ func Test_StrIntMap_Basic(t *testing.T) { t.Assert(m2.Map(), map[string]int{"a": 1, "b": 2}) }) } + func Test_StrIntMap_Set_Fun(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewStrIntMap() diff --git a/container/gmap/gmap_z_unit_str_str_test.go b/container/gmap/gmap_z_unit_str_str_test.go index f18ecc98e..d452b1648 100644 --- a/container/gmap/gmap_z_unit_str_str_test.go +++ b/container/gmap/gmap_z_unit_str_str_test.go @@ -17,9 +17,38 @@ import ( "github.com/gogf/gf/test/gtest" ) -func stringStrCallBack(string, string) bool { - return true +func Test_StrStrMap_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var m gmap.StrStrMap + m.Set("a", "a") + + t.Assert(m.Get("a"), "a") + t.Assert(m.Size(), 1) + t.Assert(m.IsEmpty(), false) + + t.Assert(m.GetOrSet("b", "b"), "b") + t.Assert(m.SetIfNotExist("b", "b"), false) + + t.Assert(m.SetIfNotExist("c", "c"), true) + + t.Assert(m.Remove("b"), "b") + t.Assert(m.Contains("b"), false) + + t.AssertIN("c", m.Keys()) + t.AssertIN("a", m.Keys()) + t.AssertIN("a", m.Values()) + t.AssertIN("c", m.Values()) + + m.Flip() + + t.Assert(m.Map(), map[string]string{"a": "a", "c": "c"}) + + m.Clear() + t.Assert(m.Size(), 0) + t.Assert(m.IsEmpty(), true) + }) } + func Test_StrStrMap_Basic(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewStrStrMap() @@ -54,6 +83,7 @@ func Test_StrStrMap_Basic(t *testing.T) { t.Assert(m2.Map(), map[string]string{"a": "a", "b": "b"}) }) } + func Test_StrStrMap_Set_Fun(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewStrStrMap() diff --git a/container/gmap/gmap_z_unit_tree_map_test.go b/container/gmap/gmap_z_unit_tree_map_test.go index 6e623722f..611482d31 100644 --- a/container/gmap/gmap_z_unit_tree_map_test.go +++ b/container/gmap/gmap_z_unit_tree_map_test.go @@ -17,6 +17,39 @@ import ( "github.com/gogf/gf/util/gutil" ) +func Test_TreeMap_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var m gmap.TreeMap + m.SetComparator(gutil.ComparatorString) + m.Set("key1", "val1") + t.Assert(m.Keys(), []interface{}{"key1"}) + + t.Assert(m.Get("key1"), "val1") + t.Assert(m.Size(), 1) + t.Assert(m.IsEmpty(), false) + + t.Assert(m.GetOrSet("key2", "val2"), "val2") + t.Assert(m.SetIfNotExist("key2", "val2"), false) + + t.Assert(m.SetIfNotExist("key3", "val3"), true) + + t.Assert(m.Remove("key2"), "val2") + t.Assert(m.Contains("key2"), false) + + t.AssertIN("key3", m.Keys()) + t.AssertIN("key1", m.Keys()) + t.AssertIN("val3", m.Values()) + t.AssertIN("val1", m.Values()) + + m.Flip() + t.Assert(m.Map(), map[interface{}]interface{}{"val3": "key3", "val1": "key1"}) + + m.Clear() + t.Assert(m.Size(), 0) + t.Assert(m.IsEmpty(), true) + }) +} + func Test_TreeMap_Basic(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewTreeMap(gutil.ComparatorString) @@ -51,6 +84,7 @@ func Test_TreeMap_Basic(t *testing.T) { t.Assert(m2.Map(), map[interface{}]interface{}{1: 1, "key1": "val1"}) }) } + func Test_TreeMap_Set_Fun(t *testing.T) { gtest.C(t, func(t *gtest.T) { m := gmap.NewTreeMap(gutil.ComparatorString) diff --git a/container/gring/gring.go b/container/gring/gring.go index 884be1159..fc75335a5 100644 --- a/container/gring/gring.go +++ b/container/gring/gring.go @@ -14,6 +14,7 @@ import ( "github.com/gogf/gf/internal/rwmutex" ) +// Ring is a struct of ring structure. type Ring struct { mu *rwmutex.RWMutex ring *ring.Ring // Underlying ring. @@ -22,6 +23,9 @@ type Ring struct { dirty *gtype.Bool // Dirty, which means the len and cap should be recalculated. It's marked dirty when the size of ring changes. } +// New creates and returns a Ring structure of elements. +// The optional parameter specifies whether using this structure in concurrent safety, +// which is false in default. func New(cap int, safe ...bool) *Ring { return &Ring{ mu: rwmutex.New(safe...), diff --git a/container/gring/gring_unit_test.go b/container/gring/gring_unit_test.go index 9032d9222..144a0d565 100644 --- a/container/gring/gring_unit_test.go +++ b/container/gring/gring_unit_test.go @@ -1,3 +1,9 @@ +// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + package gring_test import ( diff --git a/container/gset/gset_any_set.go b/container/gset/gset_any_set.go index 175329546..d1826bcb3 100644 --- a/container/gset/gset_any_set.go +++ b/container/gset/gset_any_set.go @@ -16,7 +16,7 @@ import ( ) type Set struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex data map[interface{}]struct{} } @@ -31,7 +31,7 @@ func New(safe ...bool) *Set { func NewSet(safe ...bool) *Set { return &Set{ data: make(map[interface{}]struct{}), - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), } } @@ -44,13 +44,13 @@ func NewFrom(items interface{}, safe ...bool) *Set { } return &Set{ data: m, - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), } } -// Iterator iterates the set with given callback function , +// Iterator iterates the set readonly with given callback function , // if returns true then continue iterating; or false to stop. -func (set *Set) Iterator(f func(v interface{}) bool) *Set { +func (set *Set) Iterator(f func(v interface{}) bool) { set.mu.RLock() defer set.mu.RUnlock() for k, _ := range set.data { @@ -58,77 +58,113 @@ func (set *Set) Iterator(f func(v interface{}) bool) *Set { break } } - return set } // Add adds one or multiple items to the set. -func (set *Set) Add(item ...interface{}) *Set { +func (set *Set) Add(item ...interface{}) { set.mu.Lock() + if set.data == nil { + set.data = make(map[interface{}]struct{}) + } for _, v := range item { set.data[v] = struct{}{} } set.mu.Unlock() - return set } -// AddIfNotExistFunc adds the returned value of callback function to the set -// if does not exit in the set. -func (set *Set) AddIfNotExistFunc(item interface{}, f func() interface{}) *Set { - if !set.Contains(item) { - set.doAddWithLockCheck(item, f()) +// AddIfNotExist checks whether item exists in the set, +// it adds the item to set and returns true if it does not exists in the set, +// or else it does nothing and returns false. +// +// Note that, if is nil, it does nothing and returns false. +func (set *Set) AddIfNotExist(item interface{}) bool { + if item == nil { + return false } - return set -} - -// AddIfNotExistFuncLock adds the returned value of callback function to the set -// if does not exit in the set. -// -// Note that the callback function is executed in the mutex.Lock of the set. -func (set *Set) AddIfNotExistFuncLock(item interface{}, f func() interface{}) *Set { if !set.Contains(item) { - set.doAddWithLockCheck(item, f) - } - return set -} - -// doAddWithLockCheck checks whether item exists with mutex.Lock, -// if not exists, it adds item to the set or else just returns the existing value. -// -// If is type of , -// it will be executed with mutex.Lock of the set, -// and its return value will be added to the set. -// -// It returns item successfully added.. -func (set *Set) doAddWithLockCheck(item interface{}, value interface{}) interface{} { - set.mu.Lock() - defer set.mu.Unlock() - if _, ok := set.data[item]; !ok && value != nil { - if f, ok := value.(func() interface{}); ok { - item = f() - } else { - item = value + set.mu.Lock() + defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[interface{}]struct{}) + } + if _, ok := set.data[item]; !ok { + set.data[item] = struct{}{} + return true } } - if item != nil { - set.data[item] = struct{}{} + return false +} + +// AddIfNotExistFunc checks whether item exists in the set, +// it adds the item to set and returns true if it does not exists in the set and +// function returns true, or else it does nothing and returns false. +// +// Note that, if is nil, it does nothing and returns false. The function +// is executed without writing lock. +func (set *Set) AddIfNotExistFunc(item interface{}, f func() bool) bool { + if item == nil { + return false } - return item + if !set.Contains(item) { + if f() { + set.mu.Lock() + defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[interface{}]struct{}) + } + if _, ok := set.data[item]; !ok { + set.data[item] = struct{}{} + return true + } + } + } + return false +} + +// AddIfNotExistFunc checks whether item exists in the set, +// it adds the item to set and returns true if it does not exists in the set and +// function returns true, or else it does nothing and returns false. +// +// Note that, if is nil, it does nothing and returns false. The function +// is executed within writing lock. +func (set *Set) AddIfNotExistFuncLock(item interface{}, f func() bool) bool { + if item == nil { + return false + } + if !set.Contains(item) { + set.mu.Lock() + defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[interface{}]struct{}) + } + if f() { + if _, ok := set.data[item]; !ok { + set.data[item] = struct{}{} + return true + } + } + } + return false } // Contains checks whether the set contains . func (set *Set) Contains(item interface{}) bool { + var ok bool set.mu.RLock() - _, exists := set.data[item] + if set.data != nil { + _, ok = set.data[item] + } set.mu.RUnlock() - return exists + return ok } // Remove deletes from set. -func (set *Set) Remove(item interface{}) *Set { +func (set *Set) Remove(item interface{}) { set.mu.Lock() - delete(set.data, item) + if set.data != nil { + delete(set.data, item) + } set.mu.Unlock() - return set } // Size returns the size of the set. @@ -140,18 +176,19 @@ func (set *Set) Size() int { } // Clear deletes all items of the set. -func (set *Set) Clear() *Set { +func (set *Set) Clear() { set.mu.Lock() set.data = make(map[interface{}]struct{}) set.mu.Unlock() - return set } // Slice returns the a of items of the set as slice. func (set *Set) Slice() []interface{} { set.mu.RLock() - i := 0 - ret := make([]interface{}, len(set.data)) + var ( + i = 0 + ret = make([]interface{}, len(set.data)) + ) for item := range set.data { ret[i] = item i++ @@ -164,9 +201,14 @@ func (set *Set) Slice() []interface{} { func (set *Set) Join(glue string) string { set.mu.RLock() defer set.mu.RUnlock() - buffer := bytes.NewBuffer(nil) - l := len(set.data) - i := 0 + if len(set.data) == 0 { + return "" + } + var ( + l = len(set.data) + i = 0 + buffer = bytes.NewBuffer(nil) + ) for k, _ := range set.data { buffer.WriteString(gconv.String(k)) if i != l-1 { @@ -181,11 +223,13 @@ func (set *Set) Join(glue string) string { func (set *Set) String() string { set.mu.RLock() defer set.mu.RUnlock() - buffer := bytes.NewBuffer(nil) + var ( + s = "" + l = len(set.data) + i = 0 + buffer = bytes.NewBuffer(nil) + ) buffer.WriteByte('[') - s := "" - l := len(set.data) - i := 0 for k, _ := range set.data { s = gconv.String(k) if gstr.IsNumeric(s) { @@ -256,7 +300,7 @@ func (set *Set) IsSubsetOf(other *Set) bool { // Union returns a new set which is the union of and . // Which means, all the items in are in or in . func (set *Set) Union(others ...*Set) (newSet *Set) { - newSet = NewSet(true) + newSet = NewSet() set.mu.RLock() defer set.mu.RUnlock() for _, other := range others { @@ -282,7 +326,7 @@ func (set *Set) Union(others ...*Set) (newSet *Set) { // Diff returns a new set which is the difference set from to . // Which means, all the items in are in but not in . func (set *Set) Diff(others ...*Set) (newSet *Set) { - newSet = NewSet(true) + newSet = NewSet() set.mu.RLock() defer set.mu.RUnlock() for _, other := range others { @@ -303,7 +347,7 @@ func (set *Set) Diff(others ...*Set) (newSet *Set) { // Intersect returns a new set which is the intersection from to . // Which means, all the items in are in and also in . func (set *Set) Intersect(others ...*Set) (newSet *Set) { - newSet = NewSet(true) + newSet = NewSet() set.mu.RLock() defer set.mu.RUnlock() for _, other := range others { @@ -328,7 +372,7 @@ func (set *Set) Intersect(others ...*Set) (newSet *Set) { // It returns the difference between and // if the given set is not the full set of . func (set *Set) Complement(full *Set) (newSet *Set) { - newSet = NewSet(true) + newSet = NewSet() set.mu.RLock() defer set.mu.RUnlock() if set != full { @@ -415,12 +459,11 @@ func (set *Set) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (set *Set) UnmarshalJSON(b []byte) error { - if set.mu == nil { - set.mu = rwmutex.New() - set.data = make(map[interface{}]struct{}) - } set.mu.Lock() defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[interface{}]struct{}) + } var array []interface{} if err := json.Unmarshal(b, &array); err != nil { return err @@ -433,12 +476,11 @@ func (set *Set) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for set. func (set *Set) UnmarshalValue(value interface{}) (err error) { - if set.mu == nil { - set.mu = rwmutex.New() - set.data = make(map[interface{}]struct{}) - } set.mu.Lock() defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[interface{}]struct{}) + } var array []interface{} switch value.(type) { case string, []byte: diff --git a/container/gset/gset_int_set.go b/container/gset/gset_int_set.go index f853700a3..70d810edc 100644 --- a/container/gset/gset_int_set.go +++ b/container/gset/gset_int_set.go @@ -15,7 +15,7 @@ import ( ) type IntSet struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex data map[int]struct{} } @@ -24,7 +24,7 @@ type IntSet struct { // which is false in default. func NewIntSet(safe ...bool) *IntSet { return &IntSet{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: make(map[int]struct{}), } } @@ -36,14 +36,14 @@ func NewIntSetFrom(items []int, safe ...bool) *IntSet { m[v] = struct{}{} } return &IntSet{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: m, } } -// Iterator iterates the set with given callback function , +// Iterator iterates the set readonly with given callback function , // if returns true then continue iterating; or false to stop. -func (set *IntSet) Iterator(f func(v int) bool) *IntSet { +func (set *IntSet) Iterator(f func(v int) bool) { set.mu.RLock() defer set.mu.RUnlock() for k, _ := range set.data { @@ -51,75 +51,102 @@ func (set *IntSet) Iterator(f func(v int) bool) *IntSet { break } } - return set } // Add adds one or multiple items to the set. -func (set *IntSet) Add(item ...int) *IntSet { +func (set *IntSet) Add(item ...int) { set.mu.Lock() + if set.data == nil { + set.data = make(map[int]struct{}) + } for _, v := range item { set.data[v] = struct{}{} } set.mu.Unlock() - return set } -// AddIfNotExistFunc adds the returned value of callback function to the set -// if does not exit in the set. -func (set *IntSet) AddIfNotExistFunc(item int, f func() int) *IntSet { +// AddIfNotExist checks whether item exists in the set, +// it adds the item to set and returns true if it does not exists in the set, +// or else it does nothing and returns false. +// +// Note that, if is nil, it does nothing and returns false. +func (set *IntSet) AddIfNotExist(item int) bool { if !set.Contains(item) { - set.doAddWithLockCheck(item, f()) - } - return set -} - -// AddIfNotExistFuncLock adds the returned value of callback function to the set -// if does not exit in the set. -// -// Note that the callback function is executed in the mutex.Lock of the set. -func (set *IntSet) AddIfNotExistFuncLock(item int, f func() int) *IntSet { - if !set.Contains(item) { - set.doAddWithLockCheck(item, f) - } - return set -} - -// doAddWithLockCheck checks whether item exists with mutex.Lock, -// if not exists, it adds item to the set or else just returns the existing value. -// -// If is type of , -// it will be executed with mutex.Lock of the set, -// and its return value will be added to the set. -// -// It returns item successfully added.. -func (set *IntSet) doAddWithLockCheck(item int, value interface{}) int { - set.mu.Lock() - defer set.mu.Unlock() - if _, ok := set.data[item]; !ok && value != nil { - if f, ok := value.(func() int); ok { - item = f() - } else { - item = value.(int) + set.mu.Lock() + defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[int]struct{}) + } + if _, ok := set.data[item]; !ok { + set.data[item] = struct{}{} + return true } } - set.data[item] = struct{}{} - return item + return false +} + +// AddIfNotExistFunc checks whether item exists in the set, +// it adds the item to set and returns true if it does not exists in the set and +// function returns true, or else it does nothing and returns false. +// +// Note that, the function is executed without writing lock. +func (set *IntSet) AddIfNotExistFunc(item int, f func() bool) bool { + if !set.Contains(item) { + if f() { + set.mu.Lock() + defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[int]struct{}) + } + if _, ok := set.data[item]; !ok { + set.data[item] = struct{}{} + return true + } + } + } + return false +} + +// AddIfNotExistFunc checks whether item exists in the set, +// it adds the item to set and returns true if it does not exists in the set and +// function returns true, or else it does nothing and returns false. +// +// Note that, the function is executed without writing lock. +func (set *IntSet) AddIfNotExistFuncLock(item int, f func() bool) bool { + if !set.Contains(item) { + set.mu.Lock() + defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[int]struct{}) + } + if f() { + if _, ok := set.data[item]; !ok { + set.data[item] = struct{}{} + return true + } + } + } + return false } // Contains checks whether the set contains . func (set *IntSet) Contains(item int) bool { + var ok bool set.mu.RLock() - _, exists := set.data[item] + if set.data != nil { + _, ok = set.data[item] + } set.mu.RUnlock() - return exists + return ok } // Remove deletes from set. -func (set *IntSet) Remove(item int) *IntSet { +func (set *IntSet) Remove(item int) { set.mu.Lock() - delete(set.data, item) + if set.data != nil { + delete(set.data, item) + } set.mu.Unlock() - return set } // Size returns the size of the set. @@ -131,18 +158,19 @@ func (set *IntSet) Size() int { } // Clear deletes all items of the set. -func (set *IntSet) Clear() *IntSet { +func (set *IntSet) Clear() { set.mu.Lock() set.data = make(map[int]struct{}) set.mu.Unlock() - return set } // Slice returns the a of items of the set as slice. func (set *IntSet) Slice() []int { set.mu.RLock() - ret := make([]int, len(set.data)) - i := 0 + var ( + i = 0 + ret = make([]int, len(set.data)) + ) for k, _ := range set.data { ret[i] = k i++ @@ -155,9 +183,14 @@ func (set *IntSet) Slice() []int { func (set *IntSet) Join(glue string) string { set.mu.RLock() defer set.mu.RUnlock() - buffer := bytes.NewBuffer(nil) - l := len(set.data) - i := 0 + if len(set.data) == 0 { + return "" + } + var ( + l = len(set.data) + i = 0 + buffer = bytes.NewBuffer(nil) + ) for k, _ := range set.data { buffer.WriteString(gconv.String(k)) if i != l-1 { @@ -227,7 +260,7 @@ func (set *IntSet) IsSubsetOf(other *IntSet) bool { // Union returns a new set which is the union of and . // Which means, all the items in are in or in . func (set *IntSet) Union(others ...*IntSet) (newSet *IntSet) { - newSet = NewIntSet(true) + newSet = NewIntSet() set.mu.RLock() defer set.mu.RUnlock() for _, other := range others { @@ -253,7 +286,7 @@ func (set *IntSet) Union(others ...*IntSet) (newSet *IntSet) { // Diff returns a new set which is the difference set from to . // Which means, all the items in are in but not in . func (set *IntSet) Diff(others ...*IntSet) (newSet *IntSet) { - newSet = NewIntSet(true) + newSet = NewIntSet() set.mu.RLock() defer set.mu.RUnlock() for _, other := range others { @@ -274,7 +307,7 @@ func (set *IntSet) Diff(others ...*IntSet) (newSet *IntSet) { // Intersect returns a new set which is the intersection from to . // Which means, all the items in are in and also in . func (set *IntSet) Intersect(others ...*IntSet) (newSet *IntSet) { - newSet = NewIntSet(true) + newSet = NewIntSet() set.mu.RLock() defer set.mu.RUnlock() for _, other := range others { @@ -299,7 +332,7 @@ func (set *IntSet) Intersect(others ...*IntSet) (newSet *IntSet) { // It returns the difference between and // if the given set is not the full set of . func (set *IntSet) Complement(full *IntSet) (newSet *IntSet) { - newSet = NewIntSet(true) + newSet = NewIntSet() set.mu.RLock() defer set.mu.RUnlock() if set != full { @@ -386,12 +419,11 @@ func (set *IntSet) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (set *IntSet) UnmarshalJSON(b []byte) error { - if set.mu == nil { - set.mu = rwmutex.New() - set.data = make(map[int]struct{}) - } set.mu.Lock() defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[int]struct{}) + } var array []int if err := json.Unmarshal(b, &array); err != nil { return err @@ -404,12 +436,11 @@ func (set *IntSet) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for set. func (set *IntSet) UnmarshalValue(value interface{}) (err error) { - if set.mu == nil { - set.mu = rwmutex.New() - set.data = make(map[int]struct{}) - } set.mu.Lock() defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[int]struct{}) + } var array []int switch value.(type) { case string, []byte: diff --git a/container/gset/gset_str_set.go b/container/gset/gset_str_set.go index cd4ce4723..fe02c817f 100644 --- a/container/gset/gset_str_set.go +++ b/container/gset/gset_str_set.go @@ -16,7 +16,7 @@ import ( ) type StrSet struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex data map[string]struct{} } @@ -25,7 +25,7 @@ type StrSet struct { // which is false in default. func NewStrSet(safe ...bool) *StrSet { return &StrSet{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: make(map[string]struct{}), } } @@ -37,14 +37,14 @@ func NewStrSetFrom(items []string, safe ...bool) *StrSet { m[v] = struct{}{} } return &StrSet{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), data: m, } } -// Iterator iterates the set with given callback function , +// Iterator iterates the set readonly with given callback function , // if returns true then continue iterating; or false to stop. -func (set *StrSet) Iterator(f func(v string) bool) *StrSet { +func (set *StrSet) Iterator(f func(v string) bool) { set.mu.RLock() defer set.mu.RUnlock() for k, _ := range set.data { @@ -52,77 +52,100 @@ func (set *StrSet) Iterator(f func(v string) bool) *StrSet { break } } - return set } // Add adds one or multiple items to the set. -func (set *StrSet) Add(item ...string) *StrSet { +func (set *StrSet) Add(item ...string) { set.mu.Lock() + if set.data == nil { + set.data = make(map[string]struct{}) + } for _, v := range item { set.data[v] = struct{}{} } set.mu.Unlock() - return set } -// AddIfNotExistFunc adds the returned value of callback function to the set -// if does not exit in the set. -func (set *StrSet) AddIfNotExistFunc(item string, f func() string) *StrSet { +// AddIfNotExist checks whether item exists in the set, +// it adds the item to set and returns true if it does not exists in the set, +// or else it does nothing and returns false. +func (set *StrSet) AddIfNotExist(item string) bool { if !set.Contains(item) { - set.doAddWithLockCheck(item, f()) - } - return set -} - -// AddIfNotExistFuncLock adds the returned value of callback function to the set -// if does not exit in the set. -// -// Note that the callback function is executed in the mutex.Lock of the set. -func (set *StrSet) AddIfNotExistFuncLock(item string, f func() string) *StrSet { - if !set.Contains(item) { - set.doAddWithLockCheck(item, f) - } - return set -} - -// doAddWithLockCheck checks whether item exists with mutex.Lock, -// if not exists, it adds item to the set or else just returns the existing value. -// -// If is type of , -// it will be executed with mutex.Lock of the set, -// and its return value will be added to the set. -// -// It returns item successfully added.. -func (set *StrSet) doAddWithLockCheck(item string, value interface{}) string { - set.mu.Lock() - defer set.mu.Unlock() - if _, ok := set.data[item]; !ok && value != nil { - if f, ok := value.(func() string); ok { - item = f() - } else { - item = value.(string) + set.mu.Lock() + defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[string]struct{}) + } + if _, ok := set.data[item]; !ok { + set.data[item] = struct{}{} + return true } } - if item != "" { - set.data[item] = struct{}{} + return false +} + +// AddIfNotExistFunc checks whether item exists in the set, +// it adds the item to set and returns true if it does not exists in the set and +// function returns true, or else it does nothing and returns false. +// +// Note that, the function is executed without writing lock. +func (set *StrSet) AddIfNotExistFunc(item string, f func() bool) bool { + if !set.Contains(item) { + if f() { + set.mu.Lock() + defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[string]struct{}) + } + if _, ok := set.data[item]; !ok { + set.data[item] = struct{}{} + return true + } + } } - return item + return false +} + +// AddIfNotExistFunc checks whether item exists in the set, +// it adds the item to set and returns true if it does not exists in the set and +// function returns true, or else it does nothing and returns false. +// +// Note that, the function is executed without writing lock. +func (set *StrSet) AddIfNotExistFuncLock(item string, f func() bool) bool { + if !set.Contains(item) { + set.mu.Lock() + defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[string]struct{}) + } + if f() { + if _, ok := set.data[item]; !ok { + set.data[item] = struct{}{} + return true + } + } + } + return false } // Contains checks whether the set contains . func (set *StrSet) Contains(item string) bool { + var ok bool set.mu.RLock() - _, exists := set.data[item] + if set.data != nil { + _, ok = set.data[item] + } set.mu.RUnlock() - return exists + return ok } // Remove deletes from set. -func (set *StrSet) Remove(item string) *StrSet { +func (set *StrSet) Remove(item string) { set.mu.Lock() - delete(set.data, item) + if set.data != nil { + delete(set.data, item) + } set.mu.Unlock() - return set } // Size returns the size of the set. @@ -134,18 +157,19 @@ func (set *StrSet) Size() int { } // Clear deletes all items of the set. -func (set *StrSet) Clear() *StrSet { +func (set *StrSet) Clear() { set.mu.Lock() set.data = make(map[string]struct{}) set.mu.Unlock() - return set } // Slice returns the a of items of the set as slice. func (set *StrSet) Slice() []string { set.mu.RLock() - ret := make([]string, len(set.data)) - i := 0 + var ( + i = 0 + ret = make([]string, len(set.data)) + ) for item := range set.data { ret[i] = item i++ @@ -159,9 +183,14 @@ func (set *StrSet) Slice() []string { func (set *StrSet) Join(glue string) string { set.mu.RLock() defer set.mu.RUnlock() - buffer := bytes.NewBuffer(nil) - l := len(set.data) - i := 0 + if len(set.data) == 0 { + return "" + } + var ( + l = len(set.data) + i = 0 + buffer = bytes.NewBuffer(nil) + ) for k, _ := range set.data { buffer.WriteString(k) if i != l-1 { @@ -176,9 +205,11 @@ func (set *StrSet) Join(glue string) string { func (set *StrSet) String() string { set.mu.RLock() defer set.mu.RUnlock() - buffer := bytes.NewBuffer(nil) - l := len(set.data) - i := 0 + var ( + l = len(set.data) + i = 0 + buffer = bytes.NewBuffer(nil) + ) for k, _ := range set.data { buffer.WriteString(`"` + gstr.QuoteMeta(k, `"\`) + `"`) if i != l-1 { @@ -243,7 +274,7 @@ func (set *StrSet) IsSubsetOf(other *StrSet) bool { // Union returns a new set which is the union of and . // Which means, all the items in are in or in . func (set *StrSet) Union(others ...*StrSet) (newSet *StrSet) { - newSet = NewStrSet(true) + newSet = NewStrSet() set.mu.RLock() defer set.mu.RUnlock() for _, other := range others { @@ -269,7 +300,7 @@ func (set *StrSet) Union(others ...*StrSet) (newSet *StrSet) { // Diff returns a new set which is the difference set from to . // Which means, all the items in are in but not in . func (set *StrSet) Diff(others ...*StrSet) (newSet *StrSet) { - newSet = NewStrSet(true) + newSet = NewStrSet() set.mu.RLock() defer set.mu.RUnlock() for _, other := range others { @@ -290,7 +321,7 @@ func (set *StrSet) Diff(others ...*StrSet) (newSet *StrSet) { // Intersect returns a new set which is the intersection from to . // Which means, all the items in are in and also in . func (set *StrSet) Intersect(others ...*StrSet) (newSet *StrSet) { - newSet = NewStrSet(true) + newSet = NewStrSet() set.mu.RLock() defer set.mu.RUnlock() for _, other := range others { @@ -315,7 +346,7 @@ func (set *StrSet) Intersect(others ...*StrSet) (newSet *StrSet) { // It returns the difference between and // if the given set is not the full set of . func (set *StrSet) Complement(full *StrSet) (newSet *StrSet) { - newSet = NewStrSet(true) + newSet = NewStrSet() set.mu.RLock() defer set.mu.RUnlock() if set != full { @@ -402,12 +433,11 @@ func (set *StrSet) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (set *StrSet) UnmarshalJSON(b []byte) error { - if set.mu == nil { - set.mu = rwmutex.New() - set.data = make(map[string]struct{}) - } set.mu.Lock() defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[string]struct{}) + } var array []string if err := json.Unmarshal(b, &array); err != nil { return err @@ -420,12 +450,11 @@ func (set *StrSet) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for set. func (set *StrSet) UnmarshalValue(value interface{}) (err error) { - if set.mu == nil { - set.mu = rwmutex.New() - set.data = make(map[string]struct{}) - } set.mu.Lock() defer set.mu.Unlock() + if set.data == nil { + set.data = make(map[string]struct{}) + } var array []string switch value.(type) { case string, []byte: diff --git a/container/gset/gset_z_unit_any_test.go b/container/gset/gset_z_unit_any_test.go index 583d17332..64db0e4ad 100644 --- a/container/gset/gset_z_unit_any_test.go +++ b/container/gset/gset_z_unit_any_test.go @@ -13,6 +13,8 @@ import ( "github.com/gogf/gf/frame/g" "github.com/gogf/gf/util/gconv" "strings" + "sync" + "time" "github.com/gogf/gf/container/garray" "github.com/gogf/gf/container/gset" @@ -21,10 +23,30 @@ import ( "testing" ) +func TestSet_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var s gset.Set + s.Add(1, 1, 2) + s.Add([]interface{}{3, 4}...) + t.Assert(s.Size(), 4) + t.AssertIN(1, s.Slice()) + t.AssertIN(2, s.Slice()) + t.AssertIN(3, s.Slice()) + t.AssertIN(4, s.Slice()) + t.AssertNI(0, s.Slice()) + t.Assert(s.Contains(4), true) + t.Assert(s.Contains(5), false) + s.Remove(1) + t.Assert(s.Size(), 3) + s.Clear() + t.Assert(s.Size(), 0) + }) +} + func TestSet_New(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.New() - s.Add(1).Add(1).Add(2) + s.Add(1, 1, 2) s.Add([]interface{}{3, 4}...) t.Assert(s.Size(), 4) t.AssertIN(1, s.Slice()) @@ -44,7 +66,7 @@ func TestSet_New(t *testing.T) { func TestSet_Basic(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.NewSet() - s.Add(1).Add(1).Add(2) + s.Add(1, 1, 2) s.Add([]interface{}{3, 4}...) t.Assert(s.Size(), 4) t.AssertIN(1, s.Slice()) @@ -64,7 +86,7 @@ func TestSet_Basic(t *testing.T) { func TestSet_Iterator(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.NewSet() - s.Add(1).Add(2).Add(3) + s.Add(1, 2, 3) t.Assert(s.Size(), 3) a1 := garray.New(true) @@ -85,7 +107,7 @@ func TestSet_Iterator(t *testing.T) { func TestSet_LockFunc(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.NewSet() - s.Add(1).Add(2).Add(3) + s.Add(1, 2, 3) t.Assert(s.Size(), 3) s.LockFunc(func(m map[interface{}]struct{}) { delete(m, 1) @@ -105,9 +127,9 @@ func TestSet_Equal(t *testing.T) { s1 := gset.NewSet() s2 := gset.NewSet() s3 := gset.NewSet() - s1.Add(1).Add(2).Add(3) - s2.Add(1).Add(2).Add(3) - s3.Add(1).Add(2).Add(3).Add(4) + s1.Add(1, 2, 3) + s2.Add(1, 2, 3) + s3.Add(1, 2, 3, 4) t.Assert(s1.Equal(s2), true) t.Assert(s1.Equal(s3), false) }) @@ -118,9 +140,9 @@ func TestSet_IsSubsetOf(t *testing.T) { s1 := gset.NewSet() s2 := gset.NewSet() s3 := gset.NewSet() - s1.Add(1).Add(2) - s2.Add(1).Add(2).Add(3) - s3.Add(1).Add(2).Add(3).Add(4) + s1.Add(1, 2) + s2.Add(1, 2, 3) + s3.Add(1, 2, 3, 4) t.Assert(s1.IsSubsetOf(s2), true) t.Assert(s2.IsSubsetOf(s3), true) t.Assert(s1.IsSubsetOf(s3), true) @@ -133,8 +155,8 @@ func TestSet_Union(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewSet() s2 := gset.NewSet() - s1.Add(1).Add(2) - s2.Add(3).Add(4) + s1.Add(1, 2) + s2.Add(3, 4) s3 := s1.Union(s2) t.Assert(s3.Contains(1), true) t.Assert(s3.Contains(2), true) @@ -147,8 +169,8 @@ func TestSet_Diff(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewSet() s2 := gset.NewSet() - s1.Add(1).Add(2).Add(3) - s2.Add(3).Add(4).Add(5) + s1.Add(1, 2, 3) + s2.Add(3, 4, 5) s3 := s1.Diff(s2) t.Assert(s3.Contains(1), true) t.Assert(s3.Contains(2), true) @@ -161,8 +183,8 @@ func TestSet_Intersect(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewSet() s2 := gset.NewSet() - s1.Add(1).Add(2).Add(3) - s2.Add(3).Add(4).Add(5) + s1.Add(1, 2, 3) + s2.Add(3, 4, 5) s3 := s1.Intersect(s2) t.Assert(s3.Contains(1), false) t.Assert(s3.Contains(2), false) @@ -175,8 +197,8 @@ func TestSet_Complement(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewSet() s2 := gset.NewSet() - s1.Add(1).Add(2).Add(3) - s2.Add(3).Add(4).Add(5) + s1.Add(1, 2, 3) + s2.Add(3, 4, 5) s3 := s1.Complement(s2) t.Assert(s3.Contains(1), false) t.Assert(s3.Contains(2), false) @@ -203,9 +225,9 @@ func TestNewFrom(t *testing.T) { func TestNew(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.New() - s1.Add("a").Add(2) + s1.Add("a", 2) s2 := gset.New(true) - s2.Add("b").Add(3) + s2.Add("b", 3) t.Assert(s1.Contains("a"), true) }) @@ -214,13 +236,13 @@ func TestNew(t *testing.T) { func TestSet_Join(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.New(true) - s1.Add("a").Add("a1").Add("b").Add("c") + s1.Add("a", "a1", "b", "c") str1 := s1.Join(",") t.Assert(strings.Contains(str1, "a1"), true) }) gtest.C(t, func(t *gtest.T) { s1 := gset.New(true) - s1.Add("a").Add(`"b"`).Add(`\c`) + s1.Add("a", `"b"`, `\c`) str1 := s1.Join(",") t.Assert(strings.Contains(str1, `"b"`), true) t.Assert(strings.Contains(str1, `\c`), true) @@ -231,7 +253,7 @@ func TestSet_Join(t *testing.T) { func TestSet_String(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.New(true) - s1.Add("a").Add("a2").Add("b").Add("c") + s1.Add("a", "a2", "b", "c") str1 := s1.String() t.Assert(strings.Contains(str1, "["), true) t.Assert(strings.Contains(str1, "]"), true) @@ -243,8 +265,8 @@ func TestSet_Merge(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.New(true) s2 := gset.New(true) - s1.Add("a").Add("a2").Add("b").Add("c") - s2.Add("b").Add("b1").Add("e").Add("f") + s1.Add("a", "a2", "b", "c") + s2.Add("b", "b1", "e", "f") ss := s1.Merge(s2) t.Assert(ss.Contains("a2"), true) t.Assert(ss.Contains("b1"), true) @@ -255,7 +277,7 @@ func TestSet_Merge(t *testing.T) { func TestSet_Sum(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.New(true) - s1.Add(1).Add(2).Add(3).Add(4) + s1.Add(1, 2, 3, 4) t.Assert(s1.Sum(), int(10)) }) @@ -264,7 +286,7 @@ func TestSet_Sum(t *testing.T) { func TestSet_Pop(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.New(true) - s.Add(1).Add(2).Add(3).Add(4) + s.Add(1, 2, 3, 4) t.Assert(s.Size(), 4) t.AssertIN(s.Pop(), []int{1, 2, 3, 4}) t.Assert(s.Size(), 3) @@ -274,7 +296,7 @@ func TestSet_Pop(t *testing.T) { func TestSet_Pops(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.New(true) - s.Add(1).Add(2).Add(3).Add(4) + s.Add(1, 2, 3, 4) t.Assert(s.Size(), 4) t.Assert(s.Pops(0), nil) t.AssertIN(s.Pops(1), []int{1, 2, 3, 4}) @@ -324,43 +346,71 @@ func TestSet_Json(t *testing.T) { }) } +func TestSet_AddIfNotExist(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s := gset.New(true) + s.Add(1) + t.Assert(s.Contains(1), true) + t.Assert(s.AddIfNotExist(1), false) + t.Assert(s.AddIfNotExist(2), true) + t.Assert(s.Contains(2), true) + t.Assert(s.AddIfNotExist(2), false) + t.Assert(s.Contains(2), true) + }) +} + func TestSet_AddIfNotExistFunc(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.New(true) s.Add(1) t.Assert(s.Contains(1), true) t.Assert(s.Contains(2), false) - - s.AddIfNotExistFunc(2, func() interface{} { - return 3 - }) + t.Assert(s.AddIfNotExistFunc(2, func() bool { return false }), false) t.Assert(s.Contains(2), false) - t.Assert(s.Contains(3), true) - - s.AddIfNotExistFunc(3, func() interface{} { - return 4 - }) - t.Assert(s.Contains(3), true) - t.Assert(s.Contains(4), false) + t.Assert(s.AddIfNotExistFunc(2, func() bool { return true }), true) + t.Assert(s.Contains(2), true) + t.Assert(s.AddIfNotExistFunc(2, func() bool { return true }), false) + t.Assert(s.Contains(2), true) }) - gtest.C(t, func(t *gtest.T) { s := gset.New(true) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + r := s.AddIfNotExistFunc(1, func() bool { + time.Sleep(100 * time.Millisecond) + return true + }) + t.Assert(r, false) + }() s.Add(1) - t.Assert(s.Contains(1), true) - t.Assert(s.Contains(2), false) + wg.Wait() + }) +} - s.AddIfNotExistFuncLock(2, func() interface{} { - return 3 - }) - t.Assert(s.Contains(2), false) - t.Assert(s.Contains(3), true) - - s.AddIfNotExistFuncLock(3, func() interface{} { - return 4 - }) - t.Assert(s.Contains(3), true) - t.Assert(s.Contains(4), false) +func TestSet_AddIfNotExistFuncLock(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s := gset.New(true) + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + r := s.AddIfNotExistFuncLock(1, func() bool { + time.Sleep(500 * time.Millisecond) + return true + }) + t.Assert(r, true) + }() + time.Sleep(100 * time.Millisecond) + go func() { + defer wg.Done() + r := s.AddIfNotExistFuncLock(1, func() bool { + return true + }) + t.Assert(r, false) + }() + wg.Wait() }) } diff --git a/container/gset/gset_z_unit_int_test.go b/container/gset/gset_z_unit_int_test.go index ee36c7805..a749fa27d 100644 --- a/container/gset/gset_z_unit_int_test.go +++ b/container/gset/gset_z_unit_int_test.go @@ -13,17 +13,39 @@ import ( "github.com/gogf/gf/frame/g" "github.com/gogf/gf/util/gconv" "strings" + "sync" "testing" + "time" "github.com/gogf/gf/container/garray" "github.com/gogf/gf/container/gset" "github.com/gogf/gf/test/gtest" ) +func TestIntSet_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var s gset.IntSet + s.Add(1, 1, 2) + s.Add([]int{3, 4}...) + t.Assert(s.Size(), 4) + t.AssertIN(1, s.Slice()) + t.AssertIN(2, s.Slice()) + t.AssertIN(3, s.Slice()) + t.AssertIN(4, s.Slice()) + t.AssertNI(0, s.Slice()) + t.Assert(s.Contains(4), true) + t.Assert(s.Contains(5), false) + s.Remove(1) + t.Assert(s.Size(), 3) + s.Clear() + t.Assert(s.Size(), 0) + }) +} + func TestIntSet_Basic(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.NewIntSet() - s.Add(1).Add(1).Add(2) + s.Add(1, 1, 2) s.Add([]int{3, 4}...) t.Assert(s.Size(), 4) t.AssertIN(1, s.Slice()) @@ -43,7 +65,7 @@ func TestIntSet_Basic(t *testing.T) { func TestIntSet_Iterator(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.NewIntSet() - s.Add(1).Add(2).Add(3) + s.Add(1, 2, 3) t.Assert(s.Size(), 3) a1 := garray.New(true) @@ -64,7 +86,7 @@ func TestIntSet_Iterator(t *testing.T) { func TestIntSet_LockFunc(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.NewIntSet() - s.Add(1).Add(2).Add(3) + s.Add(1, 2, 3) t.Assert(s.Size(), 3) s.LockFunc(func(m map[int]struct{}) { delete(m, 1) @@ -84,9 +106,9 @@ func TestIntSet_Equal(t *testing.T) { s1 := gset.NewIntSet() s2 := gset.NewIntSet() s3 := gset.NewIntSet() - s1.Add(1).Add(2).Add(3) - s2.Add(1).Add(2).Add(3) - s3.Add(1).Add(2).Add(3).Add(4) + s1.Add(1, 2, 3) + s2.Add(1, 2, 3) + s3.Add(1, 2, 3, 4) t.Assert(s1.Equal(s2), true) t.Assert(s1.Equal(s3), false) }) @@ -97,9 +119,9 @@ func TestIntSet_IsSubsetOf(t *testing.T) { s1 := gset.NewIntSet() s2 := gset.NewIntSet() s3 := gset.NewIntSet() - s1.Add(1).Add(2) - s2.Add(1).Add(2).Add(3) - s3.Add(1).Add(2).Add(3).Add(4) + s1.Add(1, 2) + s2.Add(1, 2, 3) + s3.Add(1, 2, 3, 4) t.Assert(s1.IsSubsetOf(s2), true) t.Assert(s2.IsSubsetOf(s3), true) t.Assert(s1.IsSubsetOf(s3), true) @@ -112,8 +134,8 @@ func TestIntSet_Union(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewIntSet() s2 := gset.NewIntSet() - s1.Add(1).Add(2) - s2.Add(3).Add(4) + s1.Add(1, 2) + s2.Add(3, 4) s3 := s1.Union(s2) t.Assert(s3.Contains(1), true) t.Assert(s3.Contains(2), true) @@ -126,8 +148,8 @@ func TestIntSet_Diff(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewIntSet() s2 := gset.NewIntSet() - s1.Add(1).Add(2).Add(3) - s2.Add(3).Add(4).Add(5) + s1.Add(1, 2, 3) + s2.Add(3, 4, 5) s3 := s1.Diff(s2) t.Assert(s3.Contains(1), true) t.Assert(s3.Contains(2), true) @@ -140,8 +162,8 @@ func TestIntSet_Intersect(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewIntSet() s2 := gset.NewIntSet() - s1.Add(1).Add(2).Add(3) - s2.Add(3).Add(4).Add(5) + s1.Add(1, 2, 3) + s2.Add(3, 4, 5) s3 := s1.Intersect(s2) t.Assert(s3.Contains(1), false) t.Assert(s3.Contains(2), false) @@ -154,8 +176,8 @@ func TestIntSet_Complement(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewIntSet() s2 := gset.NewIntSet() - s1.Add(1).Add(2).Add(3) - s2.Add(3).Add(4).Add(5) + s1.Add(1, 2, 3) + s2.Add(3, 4, 5) s3 := s1.Complement(s2) t.Assert(s3.Contains(1), false) t.Assert(s3.Contains(2), false) @@ -167,7 +189,7 @@ func TestIntSet_Complement(t *testing.T) { func TestIntSet_Size(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewIntSet(true) - s1.Add(1).Add(2).Add(3) + s1.Add(1, 2, 3) t.Assert(s1.Size(), 3) }) @@ -178,8 +200,8 @@ func TestIntSet_Merge(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewIntSet() s2 := gset.NewIntSet() - s1.Add(1).Add(2).Add(3) - s2.Add(3).Add(4).Add(5) + s1.Add(1, 2, 3) + s2.Add(3, 4, 5) s3 := s1.Merge(s2) t.Assert(s3.Contains(1), true) t.Assert(s3.Contains(5), true) @@ -190,7 +212,7 @@ func TestIntSet_Merge(t *testing.T) { func TestIntSet_Join(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewIntSet() - s1.Add(1).Add(2).Add(3) + s1.Add(1, 2, 3) s3 := s1.Join(",") t.Assert(strings.Contains(s3, "1"), true) t.Assert(strings.Contains(s3, "2"), true) @@ -201,7 +223,7 @@ func TestIntSet_Join(t *testing.T) { func TestIntSet_String(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewIntSet() - s1.Add(1).Add(2).Add(3) + s1.Add(1, 2, 3) s3 := s1.String() t.Assert(strings.Contains(s3, "["), true) t.Assert(strings.Contains(s3, "]"), true) @@ -214,9 +236,9 @@ func TestIntSet_String(t *testing.T) { func TestIntSet_Sum(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewIntSet() - s1.Add(1).Add(2).Add(3) + s1.Add(1, 2, 3) s2 := gset.NewIntSet() - s2.Add(5).Add(6).Add(7) + s2.Add(5, 6, 7) t.Assert(s2.Sum(), 18) }) @@ -226,7 +248,7 @@ func TestIntSet_Sum(t *testing.T) { func TestIntSet_Pop(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.NewIntSet() - s.Add(4).Add(2).Add(3) + s.Add(4, 2, 3) t.Assert(s.Size(), 3) t.AssertIN(s.Pop(), []int{4, 2, 3}) t.AssertIN(s.Pop(), []int{4, 2, 3}) @@ -237,7 +259,7 @@ func TestIntSet_Pop(t *testing.T) { func TestIntSet_Pops(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.NewIntSet() - s.Add(1).Add(4).Add(2).Add(3) + s.Add(1, 4, 2, 3) t.Assert(s.Size(), 4) t.Assert(s.Pops(0), nil) t.AssertIN(s.Pops(1), []int{1, 4, 2, 3}) @@ -258,6 +280,74 @@ func TestIntSet_Pops(t *testing.T) { }) } +func TestIntSet_AddIfNotExist(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s := gset.NewIntSet(true) + s.Add(1) + t.Assert(s.Contains(1), true) + t.Assert(s.AddIfNotExist(1), false) + t.Assert(s.AddIfNotExist(2), true) + t.Assert(s.Contains(2), true) + t.Assert(s.AddIfNotExist(2), false) + t.Assert(s.Contains(2), true) + }) +} + +func TestIntSet_AddIfNotExistFunc(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s := gset.NewIntSet(true) + s.Add(1) + t.Assert(s.Contains(1), true) + t.Assert(s.Contains(2), false) + t.Assert(s.AddIfNotExistFunc(2, func() bool { return false }), false) + t.Assert(s.Contains(2), false) + t.Assert(s.AddIfNotExistFunc(2, func() bool { return true }), true) + t.Assert(s.Contains(2), true) + t.Assert(s.AddIfNotExistFunc(2, func() bool { return true }), false) + t.Assert(s.Contains(2), true) + }) + gtest.C(t, func(t *gtest.T) { + s := gset.NewIntSet(true) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + r := s.AddIfNotExistFunc(1, func() bool { + time.Sleep(100 * time.Millisecond) + return true + }) + t.Assert(r, false) + }() + s.Add(1) + wg.Wait() + }) +} + +func TestIntSet_AddIfNotExistFuncLock(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s := gset.NewIntSet(true) + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + r := s.AddIfNotExistFuncLock(1, func() bool { + time.Sleep(500 * time.Millisecond) + return true + }) + t.Assert(r, true) + }() + time.Sleep(100 * time.Millisecond) + go func() { + defer wg.Done() + r := s.AddIfNotExistFuncLock(1, func() bool { + return true + }) + t.Assert(r, false) + }() + wg.Wait() + }) +} + func TestIntSet_Json(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := []int{1, 3, 2, 4} @@ -287,46 +377,6 @@ func TestIntSet_Json(t *testing.T) { }) } -func TestIntSet_AddIfNotExistFunc(t *testing.T) { - gtest.C(t, func(t *gtest.T) { - s := gset.NewIntSet(true) - s.Add(1) - t.Assert(s.Contains(1), true) - t.Assert(s.Contains(2), false) - - s.AddIfNotExistFunc(2, func() int { - return 3 - }) - t.Assert(s.Contains(2), false) - t.Assert(s.Contains(3), true) - - s.AddIfNotExistFunc(3, func() int { - return 4 - }) - t.Assert(s.Contains(3), true) - t.Assert(s.Contains(4), false) - }) - - gtest.C(t, func(t *gtest.T) { - s := gset.NewIntSet(true) - s.Add(1) - t.Assert(s.Contains(1), true) - t.Assert(s.Contains(2), false) - - s.AddIfNotExistFuncLock(2, func() int { - return 3 - }) - t.Assert(s.Contains(2), false) - t.Assert(s.Contains(3), true) - - s.AddIfNotExistFuncLock(3, func() int { - return 4 - }) - t.Assert(s.Contains(3), true) - t.Assert(s.Contains(4), false) - }) -} - func TestIntSet_UnmarshalValue(t *testing.T) { type V struct { Name string diff --git a/container/gset/gset_z_unit_str_test.go b/container/gset/gset_z_unit_str_test.go index 7f0fbb3a6..ce654c3fa 100644 --- a/container/gset/gset_z_unit_str_test.go +++ b/container/gset/gset_z_unit_str_test.go @@ -13,17 +13,39 @@ import ( "github.com/gogf/gf/frame/g" "github.com/gogf/gf/util/gconv" "strings" + "sync" "testing" + "time" "github.com/gogf/gf/container/garray" "github.com/gogf/gf/container/gset" "github.com/gogf/gf/test/gtest" ) +func TestStrSet_Var(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var s gset.StrSet + s.Add("1", "1", "2") + s.Add([]string{"3", "4"}...) + t.Assert(s.Size(), 4) + t.AssertIN("1", s.Slice()) + t.AssertIN("2", s.Slice()) + t.AssertIN("3", s.Slice()) + t.AssertIN("4", s.Slice()) + t.AssertNI("0", s.Slice()) + t.Assert(s.Contains("4"), true) + t.Assert(s.Contains("5"), false) + s.Remove("1") + t.Assert(s.Size(), 3) + s.Clear() + t.Assert(s.Size(), 0) + }) +} + func TestStrSet_Basic(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.NewStrSet() - s.Add("1").Add("1").Add("2") + s.Add("1", "1", "2") s.Add([]string{"3", "4"}...) t.Assert(s.Size(), 4) t.AssertIN("1", s.Slice()) @@ -43,7 +65,7 @@ func TestStrSet_Basic(t *testing.T) { func TestStrSet_Iterator(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.NewStrSet() - s.Add("1").Add("2").Add("3") + s.Add("1", "2", "3") t.Assert(s.Size(), 3) a1 := garray.New(true) @@ -64,7 +86,7 @@ func TestStrSet_Iterator(t *testing.T) { func TestStrSet_LockFunc(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := gset.NewStrSet() - s.Add("1").Add("2").Add("3") + s.Add("1", "2", "3") t.Assert(s.Size(), 3) s.LockFunc(func(m map[string]struct{}) { delete(m, "1") @@ -84,9 +106,9 @@ func TestStrSet_Equal(t *testing.T) { s1 := gset.NewStrSet() s2 := gset.NewStrSet() s3 := gset.NewStrSet() - s1.Add("1").Add("2").Add("3") - s2.Add("1").Add("2").Add("3") - s3.Add("1").Add("2").Add("3").Add("4") + s1.Add("1", "2", "3") + s2.Add("1", "2", "3") + s3.Add("1", "2", "3", "4") t.Assert(s1.Equal(s2), true) t.Assert(s1.Equal(s3), false) }) @@ -97,9 +119,9 @@ func TestStrSet_IsSubsetOf(t *testing.T) { s1 := gset.NewStrSet() s2 := gset.NewStrSet() s3 := gset.NewStrSet() - s1.Add("1").Add("2") - s2.Add("1").Add("2").Add("3") - s3.Add("1").Add("2").Add("3").Add("4") + s1.Add("1", "2") + s2.Add("1", "2", "3") + s3.Add("1", "2", "3", "4") t.Assert(s1.IsSubsetOf(s2), true) t.Assert(s2.IsSubsetOf(s3), true) t.Assert(s1.IsSubsetOf(s3), true) @@ -112,8 +134,8 @@ func TestStrSet_Union(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewStrSet() s2 := gset.NewStrSet() - s1.Add("1").Add("2") - s2.Add("3").Add("4") + s1.Add("1", "2") + s2.Add("3", "4") s3 := s1.Union(s2) t.Assert(s3.Contains("1"), true) t.Assert(s3.Contains("2"), true) @@ -126,8 +148,8 @@ func TestStrSet_Diff(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewStrSet() s2 := gset.NewStrSet() - s1.Add("1").Add("2").Add("3") - s2.Add("3").Add("4").Add("5") + s1.Add("1", "2", "3") + s2.Add("3", "4", "5") s3 := s1.Diff(s2) t.Assert(s3.Contains("1"), true) t.Assert(s3.Contains("2"), true) @@ -140,8 +162,8 @@ func TestStrSet_Intersect(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewStrSet() s2 := gset.NewStrSet() - s1.Add("1").Add("2").Add("3") - s2.Add("3").Add("4").Add("5") + s1.Add("1", "2", "3") + s2.Add("3", "4", "5") s3 := s1.Intersect(s2) t.Assert(s3.Contains("1"), false) t.Assert(s3.Contains("2"), false) @@ -154,8 +176,8 @@ func TestStrSet_Complement(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewStrSet() s2 := gset.NewStrSet() - s1.Add("1").Add("2").Add("3") - s2.Add("3").Add("4").Add("5") + s1.Add("1", "2", "3") + s2.Add("3", "4", "5") s3 := s1.Complement(s2) t.Assert(s3.Contains("1"), false) t.Assert(s3.Contains("2"), false) @@ -179,8 +201,8 @@ func TestStrSet_Merge(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewStrSet() s2 := gset.NewStrSet() - s1.Add("1").Add("2").Add("3") - s2.Add("3").Add("4").Add("5") + s1.Add("1", "2", "3") + s2.Add("3", "4", "5") s3 := s1.Merge(s2) t.Assert(s3.Contains("1"), true) t.Assert(s3.Contains("6"), false) @@ -207,7 +229,7 @@ func TestStrSet_Join(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewStrSet() - s1.Add("a").Add(`"b"`).Add(`\c`) + s1.Add("a", `"b"`, `\c`) str1 := s1.Join(",") t.Assert(strings.Contains(str1, `"b"`), true) t.Assert(strings.Contains(str1, `\c`), true) @@ -225,7 +247,7 @@ func TestStrSet_String(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.New(true) - s1.Add("a").Add("a2").Add("b").Add("c") + s1.Add("a", "a2", "b", "c") str1 := s1.String() t.Assert(strings.Contains(str1, "["), true) t.Assert(strings.Contains(str1, "]"), true) @@ -253,7 +275,7 @@ func TestStrSet_Size(t *testing.T) { func TestStrSet_Remove(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := gset.NewStrSetFrom([]string{"a", "b", "c"}, true) - s1 = s1.Remove("b") + s1.Remove("b") t.Assert(s1.Contains("b"), false) t.Assert(s1.Contains("c"), true) }) @@ -294,6 +316,74 @@ func TestStrSet_Pops(t *testing.T) { }) } +func TestStrSet_AddIfNotExist(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s := gset.NewStrSet(true) + s.Add("1") + t.Assert(s.Contains("1"), true) + t.Assert(s.AddIfNotExist("1"), false) + t.Assert(s.AddIfNotExist("2"), true) + t.Assert(s.Contains("2"), true) + t.Assert(s.AddIfNotExist("2"), false) + t.Assert(s.Contains("2"), true) + }) +} + +func TestStrSet_AddIfNotExistFunc(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s := gset.NewStrSet(true) + s.Add("1") + t.Assert(s.Contains("1"), true) + t.Assert(s.Contains("2"), false) + t.Assert(s.AddIfNotExistFunc("2", func() bool { return false }), false) + t.Assert(s.Contains("2"), false) + t.Assert(s.AddIfNotExistFunc("2", func() bool { return true }), true) + t.Assert(s.Contains("2"), true) + t.Assert(s.AddIfNotExistFunc("2", func() bool { return true }), false) + t.Assert(s.Contains("2"), true) + }) + gtest.C(t, func(t *gtest.T) { + s := gset.NewStrSet(true) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + r := s.AddIfNotExistFunc("1", func() bool { + time.Sleep(100 * time.Millisecond) + return true + }) + t.Assert(r, false) + }() + s.Add("1") + wg.Wait() + }) +} + +func TestStrSet_AddIfNotExistFuncLock(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s := gset.NewStrSet(true) + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + r := s.AddIfNotExistFuncLock("1", func() bool { + time.Sleep(500 * time.Millisecond) + return true + }) + t.Assert(r, true) + }() + time.Sleep(100 * time.Millisecond) + go func() { + defer wg.Done() + r := s.AddIfNotExistFuncLock("1", func() bool { + return true + }) + t.Assert(r, false) + }() + wg.Wait() + }) +} + func TestStrSet_Json(t *testing.T) { gtest.C(t, func(t *gtest.T) { s1 := []string{"a", "b", "d", "c"} @@ -323,46 +413,6 @@ func TestStrSet_Json(t *testing.T) { }) } -func TestStrSet_AddIfNotExistFunc(t *testing.T) { - gtest.C(t, func(t *gtest.T) { - s := gset.NewStrSet(true) - s.Add("1") - t.Assert(s.Contains("1"), true) - t.Assert(s.Contains("2"), false) - - s.AddIfNotExistFunc("2", func() string { - return "3" - }) - t.Assert(s.Contains("2"), false) - t.Assert(s.Contains("3"), true) - - s.AddIfNotExistFunc("3", func() string { - return "4" - }) - t.Assert(s.Contains("3"), true) - t.Assert(s.Contains("4"), false) - }) - - gtest.C(t, func(t *gtest.T) { - s := gset.NewStrSet(true) - s.Add("1") - t.Assert(s.Contains("1"), true) - t.Assert(s.Contains("2"), false) - - s.AddIfNotExistFuncLock("2", func() string { - return "3" - }) - t.Assert(s.Contains("2"), false) - t.Assert(s.Contains("3"), true) - - s.AddIfNotExistFuncLock("3", func() string { - return "4" - }) - t.Assert(s.Contains("3"), true) - t.Assert(s.Contains("4"), false) - }) -} - func TestStrSet_UnmarshalValue(t *testing.T) { type V struct { Name string diff --git a/container/gtree/gtree_avltree.go b/container/gtree/gtree_avltree.go index ba41adce4..aa195aa1c 100644 --- a/container/gtree/gtree_avltree.go +++ b/container/gtree/gtree_avltree.go @@ -18,7 +18,7 @@ import ( // AVLTree holds elements of the AVL tree. type AVLTree struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex root *AVLTreeNode comparator func(v1, v2 interface{}) int size int @@ -38,7 +38,7 @@ type AVLTreeNode struct { // which is false in default. func NewAVLTree(comparator func(v1, v2 interface{}) int, safe ...bool) *AVLTree { return &AVLTree{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), comparator: comparator, } } @@ -55,7 +55,7 @@ func NewAVLTreeFrom(comparator func(v1, v2 interface{}) int, data map[interface{ } // Clone returns a new tree with a copy of current tree. -func (tree *AVLTree) Clone(safe ...bool) *AVLTree { +func (tree *AVLTree) Clone() *AVLTree { newTree := NewAVLTree(tree.comparator, !tree.mu.IsSafe()) newTree.Sets(tree.Map()) return newTree @@ -93,7 +93,7 @@ func (tree *AVLTree) Search(key interface{}) (value interface{}, found bool) { func (tree *AVLTree) doSearch(key interface{}) (node *AVLTreeNode, found bool) { node = tree.root for node != nil { - cmp := tree.comparator(key, node.Key) + cmp := tree.getComparator()(key, node.Key) switch { case cmp == 0: return node, true @@ -331,7 +331,7 @@ func (tree *AVLTree) Floor(key interface{}) (floor *AVLTreeNode, found bool) { defer tree.mu.RUnlock() n := tree.root for n != nil { - c := tree.comparator(key, n.Key) + c := tree.getComparator()(key, n.Key) switch { case c == 0: return n, true @@ -361,7 +361,7 @@ func (tree *AVLTree) Ceiling(key interface{}) (ceiling *AVLTreeNode, found bool) defer tree.mu.RUnlock() n := tree.root for n != nil { - c := tree.comparator(key, n.Key) + c := tree.getComparator()(key, n.Key) switch { case c == 0: return n, true @@ -465,7 +465,7 @@ func (tree *AVLTree) IteratorFrom(key interface{}, match bool, f func(key, value tree.IteratorAscFrom(key, match, f) } -// IteratorAsc iterates the tree in ascending order with given callback function . +// IteratorAsc iterates the tree readonly in ascending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (tree *AVLTree) IteratorAsc(f func(key, value interface{}) bool) { tree.mu.RLock() @@ -473,7 +473,7 @@ func (tree *AVLTree) IteratorAsc(f func(key, value interface{}) bool) { tree.doIteratorAsc(tree.bottom(0), f) } -// IteratorAscFrom iterates the tree in ascending order with given callback function . +// IteratorAscFrom iterates the tree readonly in ascending order with given callback function . // The parameter specifies the start entry for iterating. The specifies whether // starting iterating if the is fully matched, or else using index searching iterating. // If returns true, then it continues iterating; or false to stop. @@ -499,7 +499,7 @@ func (tree *AVLTree) doIteratorAsc(node *AVLTreeNode, f func(key, value interfac } } -// IteratorDesc iterates the tree in descending order with given callback function . +// IteratorDesc iterates the tree readonly in descending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (tree *AVLTree) IteratorDesc(f func(key, value interface{}) bool) { tree.mu.RLock() @@ -507,7 +507,7 @@ func (tree *AVLTree) IteratorDesc(f func(key, value interface{}) bool) { tree.doIteratorDesc(tree.bottom(1), f) } -// IteratorDescFrom iterates the tree in descending order with given callback function . +// IteratorDescFrom iterates the tree readonly in descending order with given callback function . // The parameter specifies the start entry for iterating. The specifies whether // starting iterating if the is fully matched, or else using index searching iterating. // If returns true, then it continues iterating; or false to stop. @@ -541,7 +541,7 @@ func (tree *AVLTree) put(key interface{}, value interface{}, p *AVLTreeNode, qp return true } - c := tree.comparator(key, q.Key) + c := tree.getComparator()(key, q.Key) if c == 0 { q.Key = key q.Value = value @@ -566,7 +566,7 @@ func (tree *AVLTree) remove(key interface{}, qp **AVLTreeNode) (value interface{ return nil, false } - c := tree.comparator(key, q.Key) + c := tree.getComparator()(key, q.Key) if c == 0 { tree.size-- value = q.Value @@ -784,3 +784,12 @@ func output(node *AVLTreeNode, prefix string, isTail bool, str *string) { func (tree *AVLTree) MarshalJSON() ([]byte, error) { return json.Marshal(tree.Map()) } + +// getComparator returns the comparator if it's previously set, +// or else it panics. +func (tree *AVLTree) getComparator() func(a, b interface{}) int { + if tree.comparator == nil { + panic("comparator is missing for tree") + } + return tree.comparator +} diff --git a/container/gtree/gtree_btree.go b/container/gtree/gtree_btree.go index b174cdd28..05c567591 100644 --- a/container/gtree/gtree_btree.go +++ b/container/gtree/gtree_btree.go @@ -20,7 +20,7 @@ import ( // BTree holds elements of the B-tree. type BTree struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex root *BTreeNode comparator func(v1, v2 interface{}) int size int // Total number of keys in the tree @@ -50,7 +50,7 @@ func NewBTree(m int, comparator func(v1, v2 interface{}) int, safe ...bool) *BTr } return &BTree{ comparator: comparator, - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), m: m, } } @@ -67,7 +67,7 @@ func NewBTreeFrom(m int, comparator func(v1, v2 interface{}) int, data map[inter } // Clone returns a new tree with a copy of current tree. -func (tree *BTree) Clone(safe ...bool) *BTree { +func (tree *BTree) Clone() *BTree { newTree := NewBTree(tree.m, tree.comparator, !tree.mu.IsSafe()) newTree.Sets(tree.Map()) return newTree @@ -406,7 +406,7 @@ func (tree *BTree) IteratorFrom(key interface{}, match bool, f func(key, value i tree.IteratorAscFrom(key, match, f) } -// IteratorAsc iterates the tree in ascending order with given callback function . +// IteratorAsc iterates the tree readonly in ascending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (tree *BTree) IteratorAsc(f func(key, value interface{}) bool) { tree.mu.RLock() @@ -418,7 +418,7 @@ func (tree *BTree) IteratorAsc(f func(key, value interface{}) bool) { tree.doIteratorAsc(node, node.Entries[0], 0, f) } -// IteratorAscFrom iterates the tree in ascending order with given callback function . +// IteratorAscFrom iterates the tree readonly in ascending order with given callback function . // The parameter specifies the start entry for iterating. The specifies whether // starting iterating if the is fully matched, or else using index searching iterating. // If returns true, then it continues iterating; or false to stop. @@ -479,7 +479,7 @@ loop: } } -// IteratorDesc iterates the tree in descending order with given callback function . +// IteratorDesc iterates the tree readonly in descending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (tree *BTree) IteratorDesc(f func(key, value interface{}) bool) { tree.mu.RLock() @@ -493,7 +493,7 @@ func (tree *BTree) IteratorDesc(f func(key, value interface{}) bool) { tree.doIteratorDesc(node, entry, index, f) } -// IteratorDescFrom iterates the tree in descending order with given callback function . +// IteratorDescFrom iterates the tree readonly in descending order with given callback function . // The parameter specifies the start entry for iterating. The specifies whether // starting iterating if the is fully matched, or else using index searching iterating. // If returns true, then it continues iterating; or false to stop. @@ -510,7 +510,7 @@ func (tree *BTree) IteratorDescFrom(key interface{}, match bool, f func(key, val } } -// IteratorDesc iterates the tree in descending order with given callback function . +// IteratorDesc iterates the tree readonly in descending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (tree *BTree) doIteratorDesc(node *BTreeNode, entry *BTreeEntry, index int, f func(key, value interface{}) bool) { first := true @@ -621,7 +621,7 @@ func (tree *BTree) search(node *BTreeNode, key interface{}) (index int, found bo low, mid, high := 0, 0, len(node.Entries)-1 for low <= high { mid = (high + low) / 2 - compare := tree.comparator(key, node.Entries[mid].Key) + compare := tree.getComparator()(key, node.Entries[mid].Key) switch { case compare > 0: low = mid + 1 @@ -934,3 +934,12 @@ func (tree *BTree) deleteChild(node *BTreeNode, index int) { func (tree *BTree) MarshalJSON() ([]byte, error) { return json.Marshal(tree.Map()) } + +// getComparator returns the comparator if it's previously set, +// or else it panics. +func (tree *BTree) getComparator() func(a, b interface{}) int { + if tree.comparator == nil { + panic("comparator is missing for tree") + } + return tree.comparator +} diff --git a/container/gtree/gtree_redblacktree.go b/container/gtree/gtree_redblacktree.go index 0089f5a51..f23d7838d 100644 --- a/container/gtree/gtree_redblacktree.go +++ b/container/gtree/gtree_redblacktree.go @@ -24,7 +24,7 @@ const ( // RedBlackTree holds elements of the red-black tree. type RedBlackTree struct { - mu *rwmutex.RWMutex + mu rwmutex.RWMutex root *RedBlackTreeNode size int comparator func(v1, v2 interface{}) int @@ -45,7 +45,7 @@ type RedBlackTreeNode struct { // which is false in default. func NewRedBlackTree(comparator func(v1, v2 interface{}) int, safe ...bool) *RedBlackTree { return &RedBlackTree{ - mu: rwmutex.New(safe...), + mu: rwmutex.Create(safe...), comparator: comparator, } } @@ -82,7 +82,7 @@ func (tree *RedBlackTree) SetComparator(comparator func(a, b interface{}) int) { } // Clone returns a new tree with a copy of current tree. -func (tree *RedBlackTree) Clone(safe ...bool) *RedBlackTree { +func (tree *RedBlackTree) Clone() *RedBlackTree { newTree := NewRedBlackTree(tree.comparator, !tree.mu.IsSafe()) newTree.Sets(tree.Map()) return newTree @@ -109,14 +109,14 @@ func (tree *RedBlackTree) doSet(key interface{}, value interface{}) { insertedNode := (*RedBlackTreeNode)(nil) if tree.root == nil { // Assert key is of comparator's type for initial tree - tree.comparator(key, key) + tree.getComparator()(key, key) tree.root = &RedBlackTreeNode{Key: key, Value: value, color: red} insertedNode = tree.root } else { node := tree.root loop := true for loop { - compare := tree.comparator(key, node.Key) + compare := tree.getComparator()(key, node.Key) switch { case compare == 0: //node.Key = key @@ -337,8 +337,10 @@ func (tree *RedBlackTree) Size() int { // Keys returns all keys in asc order. func (tree *RedBlackTree) Keys() []interface{} { - keys := make([]interface{}, tree.Size()) - index := 0 + var ( + keys = make([]interface{}, tree.Size()) + index = 0 + ) tree.IteratorAsc(func(key, value interface{}) bool { keys[index] = key index++ @@ -349,8 +351,10 @@ func (tree *RedBlackTree) Keys() []interface{} { // Values returns all values in asc order based on the key. func (tree *RedBlackTree) Values() []interface{} { - values := make([]interface{}, tree.Size()) - index := 0 + var ( + values = make([]interface{}, tree.Size()) + index = 0 + ) tree.IteratorAsc(func(key, value interface{}) bool { values[index] = value index++ @@ -440,7 +444,7 @@ func (tree *RedBlackTree) Floor(key interface{}) (floor *RedBlackTreeNode, found defer tree.mu.RUnlock() n := tree.root for n != nil { - compare := tree.comparator(key, n.Key) + compare := tree.getComparator()(key, n.Key) switch { case compare == 0: return n, true @@ -468,7 +472,7 @@ func (tree *RedBlackTree) Ceiling(key interface{}) (ceiling *RedBlackTreeNode, f defer tree.mu.RUnlock() n := tree.root for n != nil { - compare := tree.comparator(key, n.Key) + compare := tree.getComparator()(key, n.Key) switch { case compare == 0: return n, true @@ -495,7 +499,7 @@ func (tree *RedBlackTree) IteratorFrom(key interface{}, match bool, f func(key, tree.IteratorAscFrom(key, match, f) } -// IteratorAsc iterates the tree in ascending order with given callback function . +// IteratorAsc iterates the tree readonly in ascending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (tree *RedBlackTree) IteratorAsc(f func(key, value interface{}) bool) { tree.mu.RLock() @@ -503,7 +507,7 @@ func (tree *RedBlackTree) IteratorAsc(f func(key, value interface{}) bool) { tree.doIteratorAsc(tree.leftNode(), f) } -// IteratorAscFrom iterates the tree in ascending order with given callback function . +// IteratorAscFrom iterates the tree readonly in ascending order with given callback function . // The parameter specifies the start entry for iterating. The specifies whether // starting iterating if the is fully matched, or else using index searching iterating. // If returns true, then it continues iterating; or false to stop. @@ -539,14 +543,14 @@ loop: old := node for node.parent != nil { node = node.parent - if tree.comparator(old.Key, node.Key) <= 0 { + if tree.getComparator()(old.Key, node.Key) <= 0 { goto loop } } } } -// IteratorDesc iterates the tree in descending order with given callback function . +// IteratorDesc iterates the tree readonly in descending order with given callback function . // If returns true, then it continues iterating; or false to stop. func (tree *RedBlackTree) IteratorDesc(f func(key, value interface{}) bool) { tree.mu.RLock() @@ -554,7 +558,7 @@ func (tree *RedBlackTree) IteratorDesc(f func(key, value interface{}) bool) { tree.doIteratorDesc(tree.rightNode(), f) } -// IteratorDescFrom iterates the tree in descending order with given callback function . +// IteratorDescFrom iterates the tree readonly in descending order with given callback function . // The parameter specifies the start entry for iterating. The specifies whether // starting iterating if the is fully matched, or else using index searching iterating. // If returns true, then it continues iterating; or false to stop. @@ -590,7 +594,7 @@ loop: old := node for node.parent != nil { node = node.parent - if tree.comparator(old.Key, node.Key) >= 0 { + if tree.getComparator()(old.Key, node.Key) >= 0 { goto loop } } @@ -699,7 +703,7 @@ func (tree *RedBlackTree) output(node *RedBlackTreeNode, prefix string, isTail b func (tree *RedBlackTree) doSearch(key interface{}) (node *RedBlackTreeNode, found bool) { node = tree.root for node != nil { - compare := tree.comparator(key, node.Key) + compare := tree.getComparator()(key, node.Key) switch { case compare == 0: return node, true @@ -927,12 +931,11 @@ func (tree *RedBlackTree) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the interface UnmarshalJSON for json.Unmarshal. func (tree *RedBlackTree) UnmarshalJSON(b []byte) error { - if tree.mu == nil { - tree.mu = rwmutex.New() - tree.comparator = gutil.ComparatorString - } tree.mu.Lock() defer tree.mu.Unlock() + if tree.comparator == nil { + tree.comparator = gutil.ComparatorString + } var data map[string]interface{} if err := json.Unmarshal(b, &data); err != nil { return err @@ -945,14 +948,22 @@ func (tree *RedBlackTree) UnmarshalJSON(b []byte) error { // UnmarshalValue is an interface implement which sets any type of value for map. func (tree *RedBlackTree) UnmarshalValue(value interface{}) (err error) { - if tree.mu == nil { - tree.mu = rwmutex.New() - tree.comparator = gutil.ComparatorString - } tree.mu.Lock() defer tree.mu.Unlock() + if tree.comparator == nil { + tree.comparator = gutil.ComparatorString + } for k, v := range gconv.Map(value) { tree.doSet(k, v) } return } + +// getComparator returns the comparator if it's previously set, +// or else it panics. +func (tree *RedBlackTree) getComparator() func(a, b interface{}) int { + if tree.comparator == nil { + panic("comparator is missing for tree") + } + return tree.comparator +} diff --git a/container/gvar/gvar.go b/container/gvar/gvar.go index fb8349d29..a3e310658 100644 --- a/container/gvar/gvar.go +++ b/container/gvar/gvar.go @@ -54,11 +54,13 @@ func (v *Var) Clone() *Var { // Set sets to , and returns the old value. func (v *Var) Set(value interface{}) (old interface{}) { if v.safe { - old = v.value.(*gtype.Interface).Set(value) - } else { - old = v.value - v.value = value + if t, ok := v.value.(*gtype.Interface); ok { + old = t.Set(value) + return + } } + old = v.value + v.value = value return } @@ -68,7 +70,9 @@ func (v *Var) Val() interface{} { return nil } if v.safe { - return v.value.(*gtype.Interface).Val() + if t, ok := v.value.(*gtype.Interface); ok { + return t.Val() + } } return v.value } diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 209857895..efa029572 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -11,6 +11,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/gogf/gf/container/gvar" "github.com/gogf/gf/internal/intlog" "time" @@ -18,7 +19,6 @@ import ( "github.com/gogf/gf/container/gmap" "github.com/gogf/gf/container/gtype" - "github.com/gogf/gf/container/gvar" "github.com/gogf/gf/os/gcache" "github.com/gogf/gf/util/grand" ) @@ -78,8 +78,8 @@ type DB interface { Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) // Model creation. - From(tables string) *Model - Table(tables string) *Model + Table(table ...string) *Model + Model(table ...string) *Model Schema(schema string) *Schema // Configuration methods. @@ -89,6 +89,7 @@ type DB interface { SetSchema(schema string) GetSchema() string GetPrefix() string + GetGroup() string SetDryRun(dryrun bool) GetDryRun() bool SetLogger(logger *glog.Logger) @@ -169,21 +170,23 @@ type Link interface { Prepare(sql string) (*sql.Stmt, error) } -// Value is the field value type. -type Value = *gvar.Var +type ( + // Value is the field value type. + Value = *gvar.Var -// Record is the row record of the table. -type Record map[string]Value + // Record is the row record of the table. + Record map[string]Value -// Result is the row record array. -type Result []Record + // Result is the row record array. + Result []Record -// Map is alias of map[string]interface{}, -// which is the most common usage map type. -type Map = map[string]interface{} + // Map is alias of map[string]interface{}, + // which is the most common usage map type. + Map = map[string]interface{} -// List is type of map array. -type List = []Map + // List is type of map array. + List = []Map +) const ( gINSERT_OPTION_DEFAULT = 0 diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 31e064230..891754557 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -11,6 +11,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/gogf/gf/internal/utils" "reflect" "regexp" "strings" @@ -157,7 +158,7 @@ func (c *Core) GetAll(sql string, args ...interface{}) (Result, error) { return c.DB.DoGetAll(nil, sql, args...) } -// doGetAll queries and returns data records from database. +// DoGetAll queries and returns data records from database. func (c *Core) DoGetAll(link Link, sql string, args ...interface{}) (result Result, err error) { if link == nil { link, err = c.DB.Slave() @@ -379,13 +380,15 @@ func (c *Core) Save(table string, data interface{}, batch ...int) (sql.Result, e // 2: save: if there's unique/primary key in the data, it updates it or else inserts a new one; // 3: ignore: if there's unique/primary key in the data, it ignores the inserting; func (c *Core) DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) { - var fields []string - var values []string - var params []interface{} - var dataMap Map table = c.DB.QuotePrefixTableName(table) - reflectValue := reflect.ValueOf(data) - reflectKind := reflectValue.Kind() + var ( + fields []string + values []string + params []interface{} + dataMap Map + reflectValue = reflect.ValueOf(data) + reflectKind = reflectValue.Kind() + ) if reflectKind == reflect.Ptr { reflectValue = reflectValue.Elem() reflectKind = reflectValue.Kind() @@ -401,16 +404,23 @@ func (c *Core) DoInsert(link Link, table string, data interface{}, option int, b if len(dataMap) == 0 { return nil, errors.New("data cannot be empty") } - charL, charR := c.DB.GetChars() + var ( + charL, charR = c.DB.GetChars() + operation = GetInsertOperationByOption(option) + updateStr = "" + ) for k, v := range dataMap { fields = append(fields, charL+k+charR) values = append(values, "?") params = append(params, v) } - operation := GetInsertOperationByOption(option) - updateStr := "" if option == gINSERT_OPTION_SAVE { for k, _ := range dataMap { + // If it's SAVE operation, + // do not automatically update the creating time. + if utils.EqualFoldWithoutChars(k, gSOFT_FIELD_NAME_CREATE) { + continue + } if len(updateStr) > 0 { updateStr += "," } @@ -462,12 +472,15 @@ func (c *Core) BatchSave(table string, list interface{}, batch ...int) (sql.Resu return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_SAVE, batch...) } -// doBatchInsert batch inserts/replaces/saves data. +// DoBatchInsert batch inserts/replaces/saves data. func (c *Core) DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) { - var keys, values []string - var params []interface{} table = c.DB.QuotePrefixTableName(table) - listMap := (List)(nil) + var ( + keys []string + values []string + params []interface{} + listMap List + ) switch v := list.(type) { case Result: listMap = v.List() @@ -478,8 +491,10 @@ func (c *Core) DoBatchInsert(link Link, table string, list interface{}, option i case Map: listMap = List{v} default: - rv := reflect.ValueOf(list) - kind := rv.Kind() + var ( + rv = reflect.ValueOf(list) + kind = rv.Kind() + ) if kind == reflect.Ptr { rv = rv.Elem() kind = rv.Kind() @@ -492,7 +507,7 @@ func (c *Core) DoBatchInsert(link Link, table string, list interface{}, option i listMap[i] = DataToMapDeep(rv.Index(i).Interface()) } case reflect.Map, reflect.Struct: - listMap = List{DataToMapDeep(list)} + listMap = List{DataToMapDeep(v)} default: return result, errors.New(fmt.Sprint("unsupported list type:", kind)) } @@ -512,15 +527,21 @@ func (c *Core) DoBatchInsert(link Link, table string, list interface{}, option i holders = append(holders, "?") } // Prepare the batch result pointer. - batchResult := new(SqlResult) - charL, charR := c.DB.GetChars() - keysStr := charL + strings.Join(keys, charR+","+charL) + charR - valueHolderStr := "(" + strings.Join(holders, ",") + ")" - - operation := GetInsertOperationByOption(option) - updateStr := "" + var ( + charL, charR = c.DB.GetChars() + batchResult = new(SqlResult) + keysStr = charL + strings.Join(keys, charR+","+charL) + charR + valueHolderStr = "(" + strings.Join(holders, ",") + ")" + operation = GetInsertOperationByOption(option) + updateStr = "" + ) if option == gINSERT_OPTION_SAVE { for _, k := range keys { + // If it's SAVE operation, + // do not automatically update the creating time. + if utils.EqualFoldWithoutChars(k, gSOFT_FIELD_NAME_CREATE) { + continue + } if len(updateStr) > 0 { updateStr += "," } @@ -599,18 +620,25 @@ func (c *Core) Update(table string, data interface{}, condition interface{}, arg // Also see Update. func (c *Core) DoUpdate(link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) { table = c.DB.QuotePrefixTableName(table) - updates := "" - rv := reflect.ValueOf(data) - kind := rv.Kind() + var ( + rv = reflect.ValueOf(data) + kind = rv.Kind() + ) if kind == reflect.Ptr { rv = rv.Elem() kind = rv.Kind() } - params := []interface{}(nil) + var ( + params []interface{} + updates = "" + ) switch kind { case reflect.Map, reflect.Struct: - var fields []string - for k, v := range DataToMapDeep(data) { + var ( + fields []string + dataMap = DataToMapDeep(data) + ) + for k, v := range dataMap { fields = append(fields, c.DB.QuoteWord(k)+"=?") params = append(params, v) } @@ -656,7 +684,7 @@ func (c *Core) Delete(table string, condition interface{}, args ...interface{}) return c.DB.DoDelete(nil, table, newWhere, newArgs...) } -// doDelete does "DELETE FROM ... " statement for the table. +// DoDelete does "DELETE FROM ... " statement for the table. // Also see Delete. func (c *Core) DoDelete(link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) { if link == nil { diff --git a/database/gdb/gdb_core_config.go b/database/gdb/gdb_core_config.go index 3c78665a6..9f3b2e52d 100644 --- a/database/gdb/gdb_core_config.go +++ b/database/gdb/gdb_core_config.go @@ -174,6 +174,11 @@ func (c *Core) GetPrefix() string { return c.prefix } +// GetGroup returns the group string configured. +func (c *Core) GetGroup() string { + return c.group +} + // SetDryRun enables/disables the DryRun feature. func (c *Core) SetDryRun(dryrun bool) { c.dryrun.Set(dryrun) diff --git a/database/gdb/gdb_driver_mssql.go b/database/gdb/gdb_driver_mssql.go index ead5aaca1..26aca8b98 100644 --- a/database/gdb/gdb_driver_mssql.go +++ b/database/gdb/gdb_driver_mssql.go @@ -13,6 +13,7 @@ package gdb import ( "database/sql" + "errors" "fmt" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/text/gstr" @@ -187,9 +188,10 @@ func (d *DriverMssql) Tables(schema ...string) (tables []string, err error) { // TableFields retrieves and returns the fields information of specified table of current schema. func (d *DriverMssql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { - table = gstr.Trim(table) + charL, charR := d.GetChars() + table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { - panic("function TableFields supports only single table operations") + return nil, errors.New("function TableFields supports only single table operations") } checkSchema := d.DB.GetSchema() if len(schema) > 0 && schema[0] != "" { diff --git a/database/gdb/gdb_driver_mysql.go b/database/gdb/gdb_driver_mysql.go index 02f81660b..2d0dea44a 100644 --- a/database/gdb/gdb_driver_mysql.go +++ b/database/gdb/gdb_driver_mysql.go @@ -8,6 +8,7 @@ package gdb import ( "database/sql" + "errors" "fmt" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/text/gregex" @@ -89,9 +90,10 @@ func (d *DriverMysql) Tables(schema ...string) (tables []string, err error) { // // It's using cache feature to enhance the performance, which is never expired util the process restarts. func (d *DriverMysql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { - table = gstr.Trim(table) + charL, charR := d.GetChars() + table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { - panic("function TableFields supports only single table operations") + return nil, errors.New("function TableFields supports only single table operations") } checkSchema := d.schema.Val() if len(schema) > 0 && schema[0] != "" { diff --git a/database/gdb/gdb_driver_oracle.go b/database/gdb/gdb_driver_oracle.go index d6ccd40fe..7adca6dfc 100644 --- a/database/gdb/gdb_driver_oracle.go +++ b/database/gdb/gdb_driver_oracle.go @@ -148,9 +148,10 @@ func (d *DriverOracle) Tables(schema ...string) (tables []string, err error) { // TableFields retrieves and returns the fields information of specified table of current schema. func (d *DriverOracle) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { - table = gstr.Trim(table) + charL, charR := d.GetChars() + table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { - panic("function TableFields supports only single table operations") + return nil, errors.New("function TableFields supports only single table operations") } checkSchema := d.DB.GetSchema() if len(schema) > 0 && schema[0] != "" { diff --git a/database/gdb/gdb_driver_pgsql.go b/database/gdb/gdb_driver_pgsql.go index 1e0bc5a88..9bff38467 100644 --- a/database/gdb/gdb_driver_pgsql.go +++ b/database/gdb/gdb_driver_pgsql.go @@ -13,6 +13,7 @@ package gdb import ( "database/sql" + "errors" "fmt" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/text/gstr" @@ -78,7 +79,6 @@ func (d *DriverPgsql) Tables(schema ...string) (tables []string, err error) { if err != nil { return nil, err } - query := "SELECT TABLENAME FROM PG_TABLES WHERE SCHEMANAME = 'public' ORDER BY TABLENAME" if len(schema) > 0 && schema[0] != "" { query = fmt.Sprintf("SELECT TABLENAME FROM PG_TABLES WHERE SCHEMANAME = '%s' ORDER BY TABLENAME", schema[0]) @@ -97,9 +97,10 @@ func (d *DriverPgsql) Tables(schema ...string) (tables []string, err error) { // TableFields retrieves and returns the fields information of specified table of current schema. func (d *DriverPgsql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { - table = gstr.Trim(table) + charL, charR := d.GetChars() + table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { - panic("function TableFields supports only single table operations") + return nil, errors.New("function TableFields supports only single table operations") } table, _ = gregex.ReplaceString("\"", "", table) checkSchema := d.DB.GetSchema() diff --git a/database/gdb/gdb_driver_sqlite.go b/database/gdb/gdb_driver_sqlite.go index c433fcba3..3f9eea3b8 100644 --- a/database/gdb/gdb_driver_sqlite.go +++ b/database/gdb/gdb_driver_sqlite.go @@ -12,6 +12,7 @@ package gdb import ( "database/sql" + "errors" "fmt" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/os/gfile" @@ -88,11 +89,11 @@ func (d *DriverSqlite) Tables(schema ...string) (tables []string, err error) { // TableFields retrieves and returns the fields information of specified table of current schema. func (d *DriverSqlite) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { - table = gstr.Trim(table) + charL, charR := d.GetChars() + table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { - panic("function TableFields supports only single table operations") + return nil, errors.New("function TableFields supports only single table operations") } - checkSchema := d.DB.GetSchema() if len(schema) > 0 && schema[0] != "" { checkSchema = schema[0] diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index b00fefd95..f18f3951f 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -147,8 +147,13 @@ func doQuoteWord(s, charLeft, charRight string) string { } // doQuoteString quotes string with quote chars. It handles strings like: -// "user", "user u", "user,user_detail", "user u, user_detail ut", -// "user.user u, user.user_detail ut", "u.id asc". +// "user", +// "user u", +// "user,user_detail", +// "user u, user_detail ut", +// "user.user u, user.user_detail ut", +// "u.id, u.name, u.age", +// "u.id asc". func doQuoteString(s, charLeft, charRight string) string { array1 := gstr.SplitAndTrim(s, ",") for k1, v1 := range array1 { @@ -201,8 +206,13 @@ func GetPrimaryKey(pointer interface{}) string { // GetPrimaryKeyCondition returns a new where condition by primary field name. // The optional parameter is like follows: -// 123, []int{1, 2, 3}, "john", []string{"john", "smith"} -// g.Map{"id": g.Slice{1,2,3}}, g.Map{"id": 1, "name": "john"}, etc. +// 123 +// []int{1, 2, 3} +// "john" +// []string{"john", "smith"} +// g.Map{"id": g.Slice{1,2,3}} +// g.Map{"id": 1, "name": "john"} +// etc. // // Note that it returns the given parameter directly if there's the is empty. func GetPrimaryKeyCondition(primary string, where ...interface{}) (newWhereCondition []interface{}) { @@ -443,11 +453,23 @@ func handleArguments(sql string, args []interface{}) (newSql string, newArgs []i newArgs = append(newArgs, arg) continue } - // It converts the struct to string in default - // if it implements the String interface. - if v, ok := arg.(apiString); ok { + switch v := arg.(type) { + case time.Time, *time.Time: + newArgs = append(newArgs, arg) + continue + + // Special handling for gtime.Time. + case gtime.Time: newArgs = append(newArgs, v.String()) continue + + default: + // It converts the struct to string in default + // if it implements the String interface. + if v, ok := arg.(apiString); ok { + newArgs = append(newArgs, v.String()) + continue + } } newArgs = append(newArgs, arg) diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index df9597b0e..e212ca68b 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -7,6 +7,8 @@ package gdb import ( + "fmt" + "github.com/gogf/gf/text/gregex" "time" "github.com/gogf/gf/text/gstr" @@ -26,6 +28,7 @@ type Model struct { whereHolder []*whereHolder // Condition strings for where operation. groupBy string // Used for "group by" statement. orderBy string // Used for "order by" statement. + having []interface{} // Used for "having..." statement. start int // Used for "select ... start, limit ..." statement. limit int // Used for "select ... start, limit ..." statement. option int // Option for extra operation features. @@ -37,6 +40,7 @@ type Model struct { cacheEnabled bool // Enable sql result cache feature. cacheDuration time.Duration // Cache TTL duration. cacheName string // Cache name for custom operation. + unscoped bool // Disables soft deleting features when select/delete operations. safe bool // If true, it clones and returns a new model object whenever operation done; or else it changes the attribute of current model. } @@ -58,70 +62,70 @@ const ( ) // Table creates and returns a new ORM model from given schema. -// The parameter can be more than one table names, like : -// "user", "user u", "user, user_detail", "user u, user_detail ud" -func (c *Core) Table(table string) *Model { - table = c.DB.QuotePrefixTableName(table) +// The parameter can be more than one table names, and also alias name, like: +// 1. Table names: +// Table("user") +// Table("user u") +// Table("user, user_detail") +// Table("user u, user_detail ud") +// 2. Table name with alias: Table("user", "u") +func (c *Core) Table(table ...string) *Model { + tables := "" + if len(table) > 1 { + tables = fmt.Sprintf( + `%s AS %s`, c.DB.QuotePrefixTableName(table[0]), c.DB.QuoteWord(table[1]), + ) + } else if len(table) == 1 { + tables = c.DB.QuotePrefixTableName(table[0]) + } else { + panic("table cannot be empty") + } return &Model{ db: c.DB, - tablesInit: table, - tables: table, + tablesInit: tables, + tables: tables, fields: "*", start: -1, offset: -1, - safe: true, option: OPTION_ALLOWEMPTY, } } // Model is alias of Core.Table. // See Core.Table. -func (c *Core) Model(table string) *Model { - return c.DB.Table(table) -} - -// From is alias of Core.Table. -// See Core.Table. -// Deprecated. -func (c *Core) From(table string) *Model { - return c.DB.Table(table) +func (c *Core) Model(table ...string) *Model { + return c.DB.Table(table...) } // Table acts like Core.Table except it operates on transaction. // See Core.Table. -func (tx *TX) Table(table string) *Model { - table = tx.db.QuotePrefixTableName(table) - return &Model{ - db: tx.db, - tx: tx, - tablesInit: table, - tables: table, - fields: "*", - start: -1, - offset: -1, - safe: true, - option: OPTION_ALLOWEMPTY, - } +func (tx *TX) Table(table ...string) *Model { + model := tx.db.Table(table...) + model.db = tx.db + model.tx = tx + return model } // Model is alias of tx.Table. // See tx.Table. -func (tx *TX) Model(table string) *Model { - return tx.Table(table) -} - -// From is alias of tx.Table. -// See tx.Table. -// Deprecated. -func (tx *TX) From(table string) *Model { - return tx.Table(table) +func (tx *TX) Model(table ...string) *Model { + return tx.Table(table...) } // As sets an alias name for current table. func (m *Model) As(as string) *Model { if m.tables != "" { model := m.getModel() - model.tables = gstr.TrimRight(model.tables) + " AS " + as + split := " JOIN " + if gstr.Contains(model.tables, split) { + // For join table. + array := gstr.Split(model.tables, split) + array[len(array)-1], _ = gregex.ReplaceString(`(.+) ON`, fmt.Sprintf(`$1 AS %s ON`, as), array[len(array)-1]) + model.tables = gstr.Join(array, split) + } else { + // For base table. + model.tables = gstr.TrimRight(model.tables) + " AS " + as + } return model } return m diff --git a/database/gdb/gdb_model_condition.go b/database/gdb/gdb_model_condition.go index cee015c08..09e2679a9 100644 --- a/database/gdb/gdb_model_condition.go +++ b/database/gdb/gdb_model_condition.go @@ -6,7 +6,10 @@ package gdb -import "github.com/gogf/gf/util/gconv" +import ( + "github.com/gogf/gf/util/gconv" + "strings" +) // Where sets the condition statement for the model. The parameter can be type of // string/map/gmap/slice/struct/*struct, etc. Note that, if it's called more than one times, @@ -32,6 +35,17 @@ func (m *Model) Where(where interface{}, args ...interface{}) *Model { return model } +// Having sets the having statement for the model. +// The parameters of this function usage are as the same as function Where. +// See Where. +func (m *Model) Having(having interface{}, args ...interface{}) *Model { + model := m.getModel() + model.having = []interface{}{ + having, args, + } + return model +} + // WherePri does the same logic as Model.Where except that if the parameter // is a single condition like int/string/float/slice, it treats the condition as the primary // key value. That is, if primary key is "id" and given parameter as "123", the @@ -87,9 +101,9 @@ func (m *Model) GroupBy(groupBy string) *Model { } // Order sets the "ORDER BY" statement for the model. -func (m *Model) Order(orderBy string) *Model { +func (m *Model) Order(orderBy ...string) *Model { model := m.getModel() - model.orderBy = m.db.QuoteString(orderBy) + model.orderBy = m.db.QuoteString(strings.Join(orderBy, " ")) return model } diff --git a/database/gdb/gdb_model_delete.go b/database/gdb/gdb_model_delete.go index e0638a8f7..3ae07bf41 100644 --- a/database/gdb/gdb_model_delete.go +++ b/database/gdb/gdb_model_delete.go @@ -8,8 +8,21 @@ package gdb import ( "database/sql" + "fmt" + "github.com/gogf/gf/os/gtime" ) +// Unscoped enables/disables the soft deleting feature. +func (m *Model) Unscoped(unscoped ...bool) *Model { + model := m.getModel() + if len(unscoped) > 0 { + model.unscoped = unscoped[0] + } else { + model.unscoped = true + } + return model +} + // Delete does "DELETE FROM ... " statement for the model. // The optional parameter is the same as the parameter of Model.Where function, // see Model.Where. @@ -22,6 +35,19 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) { m.checkAndRemoveCache() } }() - condition, conditionArgs := m.formatCondition(false) - return m.db.DoDelete(m.getLink(true), m.tables, condition, conditionArgs...) + var ( + fieldNameDelete = m.getSoftFieldNameDelete() + conditionWhere, conditionExtra, conditionArgs = m.formatCondition(false) + ) + // Soft deleting. + if !m.unscoped && fieldNameDelete != "" { + return m.db.DoUpdate( + m.getLink(true), + m.tables, + fmt.Sprintf(`%s='%s'`, m.db.QuoteString(fieldNameDelete), gtime.Now().String()), + conditionWhere+conditionExtra, + conditionArgs..., + ) + } + return m.db.DoDelete(m.getLink(true), m.tables, conditionWhere+conditionExtra, conditionArgs...) } diff --git a/database/gdb/gdb_model_insert.go b/database/gdb/gdb_model_insert.go index 255e4529a..67fc378b0 100644 --- a/database/gdb/gdb_model_insert.go +++ b/database/gdb/gdb_model_insert.go @@ -9,8 +9,10 @@ package gdb import ( "database/sql" "errors" + "github.com/gogf/gf/os/gtime" "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/util/gconv" + "github.com/gogf/gf/util/gutil" "reflect" ) @@ -94,6 +96,9 @@ func (m *Model) Data(data ...interface{}) *Model { // The optional parameter is the same as the parameter of Model.Data function, // see Model.Data. func (m *Model) Insert(data ...interface{}) (result sql.Result, err error) { + if len(data) > 0 { + return m.Data(data...).Insert() + } return m.doInsertWithOption(gINSERT_OPTION_DEFAULT, data...) } @@ -101,45 +106,10 @@ func (m *Model) Insert(data ...interface{}) (result sql.Result, err error) { // The optional parameter is the same as the parameter of Model.Data function, // see Model.Data. func (m *Model) InsertIgnore(data ...interface{}) (result sql.Result, err error) { - return m.doInsertWithOption(gINSERT_OPTION_IGNORE, data...) -} - -// doInsertWithOption inserts data with option parameter. -func (m *Model) doInsertWithOption(option int, data ...interface{}) (result sql.Result, err error) { if len(data) > 0 { return m.Data(data...).Insert() } - defer func() { - if err == nil { - m.checkAndRemoveCache() - } - }() - if m.data == nil { - return nil, errors.New("inserting into table with empty data") - } - if list, ok := m.data.(List); ok { - // Batch insert. - batch := 10 - if m.batch > 0 { - batch = m.batch - } - return m.db.DoBatchInsert( - m.getLink(true), - m.tables, - m.filterDataForInsertOrUpdate(list), - option, - batch, - ) - } else if data, ok := m.data.(Map); ok { - // Single insert. - return m.db.DoInsert( - m.getLink(true), - m.tables, - m.filterDataForInsertOrUpdate(data), - option, - ) - } - return nil, errors.New("inserting into table with invalid data type") + return m.doInsertWithOption(gINSERT_OPTION_IGNORE, data...) } // Replace does "REPLACE INTO ..." statement for the model. @@ -149,37 +119,7 @@ func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) { if len(data) > 0 { return m.Data(data...).Replace() } - defer func() { - if err == nil { - m.checkAndRemoveCache() - } - }() - if m.data == nil { - return nil, errors.New("replacing into table with empty data") - } - if list, ok := m.data.(List); ok { - // Batch replace. - batch := 10 - if m.batch > 0 { - batch = m.batch - } - return m.db.DoBatchInsert( - m.getLink(true), - m.tables, - m.filterDataForInsertOrUpdate(list), - gINSERT_OPTION_REPLACE, - batch, - ) - } else if data, ok := m.data.(Map); ok { - // Single insert. - return m.db.DoInsert( - m.getLink(true), - m.tables, - m.filterDataForInsertOrUpdate(data), - gINSERT_OPTION_REPLACE, - ) - } - return nil, errors.New("replacing into table with invalid data type") + return m.doInsertWithOption(gINSERT_OPTION_REPLACE, data...) } // Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the model. @@ -192,35 +132,70 @@ func (m *Model) Save(data ...interface{}) (result sql.Result, err error) { if len(data) > 0 { return m.Data(data...).Save() } + return m.doInsertWithOption(gINSERT_OPTION_SAVE, data...) +} + +// doInsertWithOption inserts data with option parameter. +func (m *Model) doInsertWithOption(option int, data ...interface{}) (result sql.Result, err error) { defer func() { if err == nil { m.checkAndRemoveCache() } }() if m.data == nil { - return nil, errors.New("saving into table with empty data") + return nil, errors.New("inserting into table with empty data") } + var ( + nowString = gtime.Now().String() + fieldNameCreate = m.getSoftFieldNameCreate() + fieldNameUpdate = m.getSoftFieldNameUpdate() + fieldNameDelete = m.getSoftFieldNameDelete() + ) + // Batch operation. if list, ok := m.data.(List); ok { - // Batch save. batch := gDEFAULT_BATCH_NUM if m.batch > 0 { batch = m.batch } + // Automatic handling for creating/updating time. + if !m.unscoped && (fieldNameCreate != "" || fieldNameUpdate != "") { + for k, v := range list { + gutil.MapDelete(v, fieldNameCreate, fieldNameUpdate, fieldNameDelete) + if fieldNameCreate != "" { + v[fieldNameCreate] = nowString + } + if fieldNameUpdate != "" { + v[fieldNameUpdate] = nowString + } + list[k] = v + } + } return m.db.DoBatchInsert( m.getLink(true), m.tables, m.filterDataForInsertOrUpdate(list), - gINSERT_OPTION_SAVE, + option, batch, ) - } else if data, ok := m.data.(Map); ok { - // Single save. + } + // Single operation. + if data, ok := m.data.(Map); ok { + // Automatic handling for creating/updating time. + if !m.unscoped && (fieldNameCreate != "" || fieldNameUpdate != "") { + gutil.MapDelete(data, fieldNameCreate, fieldNameUpdate, fieldNameDelete) + if fieldNameCreate != "" { + data[fieldNameCreate] = nowString + } + if fieldNameUpdate != "" { + data[fieldNameUpdate] = nowString + } + } return m.db.DoInsert( m.getLink(true), m.tables, m.filterDataForInsertOrUpdate(data), - gINSERT_OPTION_SAVE, + option, ) } - return nil, errors.New("saving into table with invalid data type") + return nil, errors.New("inserting into table with invalid data type") } diff --git a/database/gdb/gdb_model_join.go b/database/gdb/gdb_model_join.go index d3d156c93..e6d818387 100644 --- a/database/gdb/gdb_model_join.go +++ b/database/gdb/gdb_model_join.go @@ -9,22 +9,70 @@ package gdb import "fmt" // LeftJoin does "LEFT JOIN ... ON ..." statement on the model. -func (m *Model) LeftJoin(table string, on string) *Model { +// The parameter
can be joined table and its joined condition, +// and also with its alias name, like: +// Table("user").LeftJoin("user_detail", "user_detail.uid=user.uid") +// Table("user", "u").LeftJoin("user_detail", "ud", "ud.uid=u.uid") +func (m *Model) LeftJoin(table ...string) *Model { model := m.getModel() - model.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", m.db.QuotePrefixTableName(table), on) + if len(table) > 2 { + model.tables += fmt.Sprintf( + " LEFT JOIN %s AS %s ON (%s)", + m.db.QuotePrefixTableName(table[0]), m.db.QuoteWord(table[1]), table[2], + ) + } else if len(table) == 2 { + model.tables += fmt.Sprintf( + " LEFT JOIN %s ON (%s)", + m.db.QuotePrefixTableName(table[0]), table[1], + ) + } else { + panic("invalid join table parameter") + } return model } // RightJoin does "RIGHT JOIN ... ON ..." statement on the model. -func (m *Model) RightJoin(table string, on string) *Model { +// The parameter
can be joined table and its joined condition, +// and also with its alias name, like: +// Table("user").RightJoin("user_detail", "user_detail.uid=user.uid") +// Table("user", "u").RightJoin("user_detail", "ud", "ud.uid=u.uid") +func (m *Model) RightJoin(table ...string) *Model { model := m.getModel() - model.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", m.db.QuotePrefixTableName(table), on) + if len(table) > 2 { + model.tables += fmt.Sprintf( + " RIGHT JOIN %s AS %s ON (%s)", + m.db.QuotePrefixTableName(table[0]), m.db.QuoteWord(table[1]), table[2], + ) + } else if len(table) == 2 { + model.tables += fmt.Sprintf( + " RIGHT JOIN %s ON (%s)", + m.db.QuotePrefixTableName(table[0]), table[1], + ) + } else { + panic("invalid join table parameter") + } return model } // InnerJoin does "INNER JOIN ... ON ..." statement on the model. -func (m *Model) InnerJoin(table string, on string) *Model { +// The parameter
can be joined table and its joined condition, +// and also with its alias name, like: +// Table("user").InnerJoin("user_detail", "user_detail.uid=user.uid") +// Table("user", "u").InnerJoin("user_detail", "ud", "ud.uid=u.uid") +func (m *Model) InnerJoin(table ...string) *Model { model := m.getModel() - model.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", m.db.QuotePrefixTableName(table), on) + if len(table) > 2 { + model.tables += fmt.Sprintf( + " INNER JOIN %s AS %s ON (%s)", + m.db.QuotePrefixTableName(table[0]), m.db.QuoteWord(table[1]), table[2], + ) + } else if len(table) == 2 { + model.tables += fmt.Sprintf( + " INNER JOIN %s ON (%s)", + m.db.QuotePrefixTableName(table[0]), table[1], + ) + } else { + panic("invalid join table parameter") + } return model } diff --git a/database/gdb/gdb_model_select.go b/database/gdb/gdb_model_select.go index c3eefcb01..3f0f626d8 100644 --- a/database/gdb/gdb_model_select.go +++ b/database/gdb/gdb_model_select.go @@ -30,9 +30,25 @@ func (m *Model) All(where ...interface{}) (Result, error) { if len(where) > 0 { return m.Where(where[0], where[1:]...).All() } - condition, conditionArgs := m.formatCondition(false) + var ( + softDeletingCondition = m.getConditionForSoftDeleting() + conditionWhere, conditionExtra, conditionArgs = m.formatCondition(false) + ) + if !m.unscoped && softDeletingCondition != "" { + if conditionWhere == "" { + conditionWhere = " WHERE " + } else { + conditionWhere += " AND " + } + conditionWhere += softDeletingCondition + } return m.getAll( - fmt.Sprintf("SELECT %s FROM %s%s", m.fields, m.tables, condition), + fmt.Sprintf( + "SELECT %s FROM %s%s", + m.db.QuoteString(m.fields), + m.tables, + conditionWhere+conditionExtra, + ), conditionArgs..., ) } @@ -73,8 +89,7 @@ func (m *Model) One(where ...interface{}) (Record, error) { if len(where) > 0 { return m.Where(where[0], where[1:]...).One() } - condition, conditionArgs := m.formatCondition(true) - all, err := m.getAll(fmt.Sprintf("SELECT %s FROM %s%s", m.fields, m.tables, condition), conditionArgs...) + all, err := m.All() if err != nil { return nil, err } @@ -234,10 +249,22 @@ func (m *Model) Count(where ...interface{}) (int, error) { } countFields := "COUNT(1)" if m.fields != "" && m.fields != "*" { - countFields = fmt.Sprintf(`COUNT(%s)`, m.fields) + countFields = fmt.Sprintf(`COUNT(%s)`, m.db.QuoteString(m.fields)) } - condition, conditionArgs := m.formatCondition(false) - s := fmt.Sprintf("SELECT %s FROM %s %s", countFields, m.tables, condition) + var ( + softDeletingCondition = m.getConditionForSoftDeleting() + conditionWhere, conditionExtra, conditionArgs = m.formatCondition(false) + ) + if !m.unscoped && softDeletingCondition != "" { + if conditionWhere == "" { + conditionWhere = " WHERE " + } else { + conditionWhere += " AND " + } + conditionWhere += softDeletingCondition + } + + s := fmt.Sprintf("SELECT %s FROM %s%s", countFields, m.tables, conditionWhere+conditionExtra) if len(m.groupBy) > 0 { s = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", s) } diff --git a/database/gdb/gdb_model_time.go b/database/gdb/gdb_model_time.go new file mode 100644 index 000000000..8953b312c --- /dev/null +++ b/database/gdb/gdb_model_time.go @@ -0,0 +1,147 @@ +// Copyright 2020 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gdb + +import ( + "fmt" + "github.com/gogf/gf/container/garray" + "github.com/gogf/gf/text/gregex" + "github.com/gogf/gf/text/gstr" + "github.com/gogf/gf/util/gconv" + "github.com/gogf/gf/util/gutil" +) + +const ( + gSOFT_FIELD_NAME_CREATE = "create_at" + gSOFT_FIELD_NAME_UPDATE = "update_at" + gSOFT_FIELD_NAME_DELETE = "delete_at" +) + +// getSoftFieldNameCreate checks and returns the field name for record creating time. +// If there's no field name for storing creating time, it returns an empty string. +// It checks the key with or without cases or chars '-'/'_'/'.'/' '. +func (m *Model) getSoftFieldNameCreate(table ...string) string { + name := "" + if len(table) > 0 { + name = table[0] + } else { + name = m.getPrimaryTableName() + } + return m.getSoftFieldName(name, gSOFT_FIELD_NAME_CREATE) +} + +// getSoftFieldNameUpdate checks and returns the field name for record updating time. +// If there's no field name for storing updating time, it returns an empty string. +// It checks the key with or without cases or chars '-'/'_'/'.'/' '. +func (m *Model) getSoftFieldNameUpdate(table ...string) (field string) { + name := "" + if len(table) > 0 { + name = table[0] + } else { + name = m.getPrimaryTableName() + } + return m.getSoftFieldName(name, gSOFT_FIELD_NAME_UPDATE) +} + +// getSoftFieldNameDelete checks and returns the field name for record deleting time. +// If there's no field name for storing deleting time, it returns an empty string. +// It checks the key with or without cases or chars '-'/'_'/'.'/' '. +func (m *Model) getSoftFieldNameDelete(table ...string) (field string) { + name := "" + if len(table) > 0 { + name = table[0] + } else { + name = m.getPrimaryTableName() + } + return m.getSoftFieldName(name, gSOFT_FIELD_NAME_DELETE) +} + +// getSoftFieldName retrieves and returns the field name of the table for possible key. +func (m *Model) getSoftFieldName(table string, key string) (field string) { + fieldsMap, _ := m.db.TableFields(table) + if len(fieldsMap) > 0 { + field, _ = gutil.MapPossibleItemByKey( + gconv.Map(fieldsMap), key, + ) + } + return +} + +// getConditionForSoftDeleting retrieves and returns the condition string for soft deleting. +// It supports multiple tables string like: +// "user u, user_detail ud" +// "user u LEFT JOIN user_detail ud ON(ud.uid=u.uid)" +// "user LEFT JOIN user_detail ON(user_detail.uid=user.uid)" +// "user u LEFT JOIN user_detail ud ON(ud.uid=u.uid) LEFT JOIN user_stats us ON(us.uid=u.uid)" +func (m *Model) getConditionForSoftDeleting() string { + if m.unscoped { + return "" + } + conditionArray := garray.NewStrArray() + if gstr.Contains(m.tables, " JOIN ") { + // Base table. + match, _ := gregex.MatchString(`(.+?) [A-Z]+ JOIN`, m.tables) + conditionArray.Append(m.getConditionOfTableStringForSoftDeleting(match[1])) + // Multiple joined tables. + matches, _ := gregex.MatchAllString(`JOIN (.+?) ON`, m.tables) + for _, match := range matches { + conditionArray.Append(m.getConditionOfTableStringForSoftDeleting(match[1])) + } + } + if conditionArray.Len() == 0 && gstr.Contains(m.tables, ",") { + // Multiple base tables. + for _, s := range gstr.SplitAndTrim(m.tables, ",") { + conditionArray.Append(m.getConditionOfTableStringForSoftDeleting(s)) + } + } + conditionArray.FilterEmpty() + if conditionArray.Len() > 0 { + return conditionArray.Join(" AND ") + } + // Only one table. + if fieldName := m.getSoftFieldNameDelete(); fieldName != "" { + return fmt.Sprintf(`%s IS NULL`, m.db.QuoteWord(fieldName)) + } + return "" +} + +// getConditionOfTableStringForSoftDeleting does something as its name describes. +func (m *Model) getConditionOfTableStringForSoftDeleting(s string) string { + var ( + field = "" + table = "" + array1 = gstr.SplitAndTrim(s, " ") + array2 = gstr.SplitAndTrim(array1[0], ".") + ) + if len(array2) >= 2 { + table = array2[1] + } else { + table = array2[0] + } + field = m.getSoftFieldNameDelete(table) + if field == "" { + return "" + } + if len(array1) >= 3 { + return fmt.Sprintf(`%s.%s IS NULL`, m.db.QuoteWord(array1[2]), m.db.QuoteWord(field)) + } + if len(array1) >= 2 { + return fmt.Sprintf(`%s.%s IS NULL`, m.db.QuoteWord(array1[1]), m.db.QuoteWord(field)) + } + return fmt.Sprintf(`%s.%s IS NULL`, m.db.QuoteWord(table), m.db.QuoteWord(field)) +} + +// getPrimaryTableName parses and returns the primary table name. +func (m *Model) getPrimaryTableName() string { + array1 := gstr.SplitAndTrim(m.tables, ",") + array2 := gstr.SplitAndTrim(array1[0], " ") + array3 := gstr.SplitAndTrim(array2[0], ".") + if len(array3) >= 2 { + return array3[1] + } + return array3[0] +} diff --git a/database/gdb/gdb_model_update.go b/database/gdb/gdb_model_update.go index a8cdf7019..e80995021 100644 --- a/database/gdb/gdb_model_update.go +++ b/database/gdb/gdb_model_update.go @@ -9,6 +9,12 @@ package gdb import ( "database/sql" "errors" + "fmt" + "github.com/gogf/gf/os/gtime" + "github.com/gogf/gf/text/gstr" + "github.com/gogf/gf/util/gconv" + "github.com/gogf/gf/util/gutil" + "reflect" ) // Update does "UPDATE ... " statement for the model. @@ -34,12 +40,44 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro if m.data == nil { return nil, errors.New("updating table with empty data") } - condition, conditionArgs := m.formatCondition(false) + var ( + updateData = m.data + fieldNameCreate = m.getSoftFieldNameCreate() + fieldNameUpdate = m.getSoftFieldNameUpdate() + fieldNameDelete = m.getSoftFieldNameDelete() + conditionWhere, conditionExtra, conditionArgs = m.formatCondition(false) + ) + // Automatically update the record updating time. + if !m.unscoped && fieldNameUpdate != "" { + var ( + refValue = reflect.ValueOf(m.data) + refKind = refValue.Kind() + ) + if refKind == reflect.Ptr { + refValue = refValue.Elem() + refKind = refValue.Kind() + } + switch refKind { + case reflect.Map, reflect.Struct: + dataMap := DataToMapDeep(m.data) + gutil.MapDelete(dataMap, fieldNameCreate, fieldNameUpdate, fieldNameDelete) + if fieldNameUpdate != "" { + dataMap[fieldNameUpdate] = gtime.Now().String() + } + updateData = dataMap + default: + updates := gconv.String(m.data) + if fieldNameUpdate != "" && !gstr.Contains(updates, fieldNameUpdate) { + updates += fmt.Sprintf(`,%s='%s'`, fieldNameUpdate, gtime.Now().String()) + } + updateData = updates + } + } return m.db.DoUpdate( m.getLink(true), m.tables, - m.filterDataForInsertOrUpdate(m.data), - condition, + m.filterDataForInsertOrUpdate(updateData), + conditionWhere+conditionExtra, m.mergeArguments(conditionArgs)..., ) } diff --git a/database/gdb/gdb_model_utility.go b/database/gdb/gdb_model_utility.go index 7acc7063b..cf41a44dd 100644 --- a/database/gdb/gdb_model_utility.go +++ b/database/gdb/gdb_model_utility.go @@ -12,6 +12,7 @@ import ( "github.com/gogf/gf/internal/empty" "github.com/gogf/gf/os/gtime" "github.com/gogf/gf/text/gstr" + "github.com/gogf/gf/util/gconv" "time" ) @@ -28,15 +29,19 @@ func (m *Model) getModel() *Model { // filterDataForInsertOrUpdate does filter feature with data for inserting/updating operations. // Note that, it does not filter list item, which is also type of map, for "omit empty" feature. func (m *Model) filterDataForInsertOrUpdate(data interface{}) interface{} { - if list, ok := m.data.(List); ok { - for k, item := range list { - list[k] = m.doFilterDataMapForInsertOrUpdate(item, false) + switch value := data.(type) { + case List: + for k, item := range value { + value[k] = m.doFilterDataMapForInsertOrUpdate(item, false) } - return list - } else if item, ok := m.data.(Map); ok { - return m.doFilterDataMapForInsertOrUpdate(item, true) + return value + + case Map: + return m.doFilterDataMapForInsertOrUpdate(value, true) + + default: + return data } - return data } // doFilterDataMapForInsertOrUpdate does the filter features for map. @@ -70,8 +75,13 @@ func (m *Model) doFilterDataMapForInsertOrUpdate(data Map, allowOmitEmpty bool) if len(m.fields) > 0 && m.fields != "*" { // Keep specified fields. - set := gset.NewStrSetFrom(gstr.SplitAndTrim(m.fields, ",")) + var ( + set = gset.NewStrSetFrom(gstr.SplitAndTrim(m.fields, ",")) + charL, charR = m.db.GetChars() + chars = charL + charR + ) for k := range data { + k = gstr.Trim(k, chars) if !set.Contains(k) { delete(data, k) } @@ -144,16 +154,17 @@ func (m *Model) checkAndRemoveCache() { // Note that this function does not change any attribute value of the . // // The parameter specifies whether limits querying only one record if m.limit is not set. -func (m *Model) formatCondition(limit bool) (condition string, conditionArgs []interface{}) { - var where string +func (m *Model) formatCondition(limit bool) (conditionWhere string, conditionExtra string, conditionArgs []interface{}) { if len(m.whereHolder) > 0 { for _, v := range m.whereHolder { switch v.operator { case gWHERE_HOLDER_WHERE: - if where == "" { - newWhere, newArgs := formatWhere(m.db, v.where, v.args, m.option&OPTION_OMITEMPTY > 0) + if conditionWhere == "" { + newWhere, newArgs := formatWhere( + m.db, v.where, v.args, m.option&OPTION_OMITEMPTY > 0, + ) if len(newWhere) > 0 { - where = newWhere + conditionWhere = newWhere conditionArgs = newArgs } continue @@ -161,52 +172,69 @@ func (m *Model) formatCondition(limit bool) (condition string, conditionArgs []i fallthrough case gWHERE_HOLDER_AND: - newWhere, newArgs := formatWhere(m.db, v.where, v.args, m.option&OPTION_OMITEMPTY > 0) + newWhere, newArgs := formatWhere( + m.db, v.where, v.args, m.option&OPTION_OMITEMPTY > 0, + ) if len(newWhere) > 0 { - if where[0] == '(' { - where = fmt.Sprintf(`%s AND (%s)`, where, newWhere) + if len(conditionWhere) == 0 { + conditionWhere = newWhere + } else if conditionWhere[0] == '(' { + conditionWhere = fmt.Sprintf(`%s AND (%s)`, conditionWhere, newWhere) } else { - where = fmt.Sprintf(`(%s) AND (%s)`, where, newWhere) + conditionWhere = fmt.Sprintf(`(%s) AND (%s)`, conditionWhere, newWhere) } conditionArgs = append(conditionArgs, newArgs...) } case gWHERE_HOLDER_OR: - newWhere, newArgs := formatWhere(m.db, v.where, v.args, m.option&OPTION_OMITEMPTY > 0) + newWhere, newArgs := formatWhere( + m.db, v.where, v.args, m.option&OPTION_OMITEMPTY > 0, + ) if len(newWhere) > 0 { - if where[0] == '(' { - where = fmt.Sprintf(`%s OR (%s)`, where, newWhere) + if len(conditionWhere) == 0 { + conditionWhere = newWhere + } else if conditionWhere[0] == '(' { + conditionWhere = fmt.Sprintf(`%s OR (%s)`, conditionWhere, newWhere) } else { - where = fmt.Sprintf(`(%s) OR (%s)`, where, newWhere) + conditionWhere = fmt.Sprintf(`(%s) OR (%s)`, conditionWhere, newWhere) } conditionArgs = append(conditionArgs, newArgs...) } } } } - if where != "" { - condition += " WHERE " + where + if conditionWhere != "" { + conditionWhere = " WHERE " + conditionWhere } if m.groupBy != "" { - condition += " GROUP BY " + m.groupBy + conditionExtra += " GROUP BY " + m.groupBy } if m.orderBy != "" { - condition += " ORDER BY " + m.orderBy + conditionExtra += " ORDER BY " + m.orderBy + } + if len(m.having) > 0 { + havingStr, havingArgs := formatWhere( + m.db, m.having[0], gconv.Interfaces(m.having[1]), m.option&OPTION_OMITEMPTY > 0, + ) + if len(havingStr) > 0 { + conditionExtra += " HAVING " + havingStr + conditionArgs = append(conditionArgs, havingArgs...) + } } if m.limit != 0 { if m.start >= 0 { - condition += fmt.Sprintf(" LIMIT %d,%d", m.start, m.limit) + conditionExtra += fmt.Sprintf(" LIMIT %d,%d", m.start, m.limit) } else { - condition += fmt.Sprintf(" LIMIT %d", m.limit) + conditionExtra += fmt.Sprintf(" LIMIT %d", m.limit) } } else if limit { - condition += " LIMIT 1" + conditionExtra += " LIMIT 1" } if m.offset >= 0 { - condition += fmt.Sprintf(" OFFSET %d", m.offset) + conditionExtra += fmt.Sprintf(" OFFSET %d", m.offset) } if m.lockInfo != "" { - condition += " " + m.lockInfo + conditionExtra += " " + m.lockInfo } return } diff --git a/database/gdb/gdb_schema.go b/database/gdb/gdb_schema.go index 875869a55..6671fcca8 100644 --- a/database/gdb/gdb_schema.go +++ b/database/gdb/gdb_schema.go @@ -40,6 +40,14 @@ func (s *Schema) Table(table string) *Model { } else { m = s.db.Table(table) } + // Do not change the schema of the original db, + // it here creates a new db and changes its schema. + db, err := New(m.db.GetGroup()) + if err != nil { + panic(err) + } + db.SetSchema(s.schema) + m.db = db m.schema = s.schema return m } diff --git a/database/gdb/gdb_unit_z_func_test.go b/database/gdb/gdb_unit_z_func_test.go deleted file mode 100644 index 3e6941c97..000000000 --- a/database/gdb/gdb_unit_z_func_test.go +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved. -// -// This Source Code Form is subject to the terms of the MIT License. -// If a copy of the MIT was not distributed with this file, -// You can obtain one at https://github.com/gogf/gf. - -package gdb - -import ( - "github.com/gogf/gf/test/gtest" - "testing" -) - -func Test_Func_FormatSqlWithArgs(t *testing.T) { - // mysql - gtest.C(t, func(t *gtest.T) { - var s string - s = FormatSqlWithArgs("select * from table where id>=? and sex=?", []interface{}{100, 1}) - t.Assert(s, "select * from table where id>=100 and sex=1") - }) - // mssql - gtest.C(t, func(t *gtest.T) { - var s string - s = FormatSqlWithArgs("select * from table where id>=@p1 and sex=@p2", []interface{}{100, 1}) - t.Assert(s, "select * from table where id>=100 and sex=1") - }) - // pgsql - gtest.C(t, func(t *gtest.T) { - var s string - s = FormatSqlWithArgs("select * from table where id>=$1 and sex=$2", []interface{}{100, 1}) - t.Assert(s, "select * from table where id>=100 and sex=1") - }) - // oracle - gtest.C(t, func(t *gtest.T) { - var s string - s = FormatSqlWithArgs("select * from table where id>=:1 and sex=:2", []interface{}{100, 1}) - t.Assert(s, "select * from table where id>=100 and sex=1") - }) -} - -func Test_Func_doQuoteWord(t *testing.T) { - gtest.C(t, func(t *gtest.T) { - array := map[string]string{ - "user": "`user`", - "user u": "user u", - "user_detail": "`user_detail`", - "user,user_detail": "user,user_detail", - "user u, user_detail ut": "user u, user_detail ut", - "u.id asc": "u.id asc", - "u.id asc, ut.uid desc": "u.id asc, ut.uid desc", - } - for k, v := range array { - t.Assert(doQuoteWord(k, "`", "`"), v) - } - }) -} - -func Test_Func_doQuoteString(t *testing.T) { - gtest.C(t, func(t *gtest.T) { - // "user", "user u", "user,user_detail", "user u, user_detail ut", "u.id asc". - array := map[string]string{ - "user": "`user`", - "user u": "`user` u", - "user,user_detail": "`user`,`user_detail`", - "user u, user_detail ut": "`user` u,`user_detail` ut", - "u.id asc": "`u`.`id` asc", - "u.id asc, ut.uid desc": "`u`.`id` asc,`ut`.`uid` desc", - "user.user u, user.user_detail ut": "`user`.`user` u,`user`.`user_detail` ut", - // mssql global schema access with double dots. - "user..user u, user.user_detail ut": "`user`..`user` u,`user`.`user_detail` ut", - } - for k, v := range array { - t.Assert(doQuoteString(k, "`", "`"), v) - } - }) -} - -func Test_Func_addTablePrefix(t *testing.T) { - gtest.C(t, func(t *gtest.T) { - prefix := "" - array := map[string]string{ - "user": "`user`", - "user u": "`user` u", - "user as u": "`user` as u", - "user,user_detail": "`user`,`user_detail`", - "user u, user_detail ut": "`user` u,`user_detail` ut", - "`user`.user_detail": "`user`.`user_detail`", - "`user`.`user_detail`": "`user`.`user_detail`", - "user as u, user_detail as ut": "`user` as u,`user_detail` as ut", - "UserCenter.user as u, UserCenter.user_detail as ut": "`UserCenter`.`user` as u,`UserCenter`.`user_detail` as ut", - // mssql global schema access with double dots. - "UserCenter..user as u, user_detail as ut": "`UserCenter`..`user` as u,`user_detail` as ut", - } - for k, v := range array { - t.Assert(doHandleTableName(k, prefix, "`", "`"), v) - } - }) - gtest.C(t, func(t *gtest.T) { - prefix := "gf_" - array := map[string]string{ - "user": "`gf_user`", - "user u": "`gf_user` u", - "user as u": "`gf_user` as u", - "user,user_detail": "`gf_user`,`gf_user_detail`", - "user u, user_detail ut": "`gf_user` u,`gf_user_detail` ut", - "`user`.user_detail": "`user`.`gf_user_detail`", - "`user`.`user_detail`": "`user`.`gf_user_detail`", - "user as u, user_detail as ut": "`gf_user` as u,`gf_user_detail` as ut", - "UserCenter.user as u, UserCenter.user_detail as ut": "`UserCenter`.`gf_user` as u,`UserCenter`.`gf_user_detail` as ut", - // mssql global schema access with double dots. - "UserCenter..user as u, user_detail as ut": "`UserCenter`..`gf_user` as u,`gf_user_detail` as ut", - } - for k, v := range array { - t.Assert(doHandleTableName(k, prefix, "`", "`"), v) - } - }) -} diff --git a/database/gdb/gdb_unit_z_mysql_internal_test.go b/database/gdb/gdb_unit_z_mysql_internal_test.go new file mode 100644 index 000000000..a25d5391c --- /dev/null +++ b/database/gdb/gdb_unit_z_mysql_internal_test.go @@ -0,0 +1,289 @@ +// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gdb + +import ( + "fmt" + "github.com/gogf/gf/os/gcmd" + "github.com/gogf/gf/os/gtime" + "github.com/gogf/gf/test/gtest" + "testing" +) + +const ( + SCHEMA = "test_internal" +) + +var ( + db DB + configNode ConfigNode +) + +func init() { + parser, err := gcmd.Parse(map[string]bool{ + "name": true, + "type": true, + }, false) + gtest.Assert(err, nil) + configNode = ConfigNode{ + Host: "127.0.0.1", + Port: "3306", + User: "root", + Pass: "12345678", + Name: parser.GetOpt("name", ""), + Type: parser.GetOpt("type", "mysql"), + Role: "master", + Charset: "utf8", + Weight: 1, + MaxIdleConnCount: 10, + MaxOpenConnCount: 10, + MaxConnLifetime: 600, + } + AddConfigNode(DEFAULT_GROUP_NAME, configNode) + // Default db. + if r, err := New(); err != nil { + gtest.Error(err) + } else { + db = r + } + schemaTemplate := "CREATE DATABASE IF NOT EXISTS `%s` CHARACTER SET UTF8" + if _, err := db.Exec(fmt.Sprintf(schemaTemplate, SCHEMA)); err != nil { + gtest.Error(err) + } + db.SetSchema(SCHEMA) +} + +func dropTable(table string) { + if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil { + gtest.Error(err) + } +} + +func Test_Func_FormatSqlWithArgs(t *testing.T) { + // mysql + gtest.C(t, func(t *gtest.T) { + var s string + s = FormatSqlWithArgs("select * from table where id>=? and sex=?", []interface{}{100, 1}) + t.Assert(s, "select * from table where id>=100 and sex=1") + }) + // mssql + gtest.C(t, func(t *gtest.T) { + var s string + s = FormatSqlWithArgs("select * from table where id>=@p1 and sex=@p2", []interface{}{100, 1}) + t.Assert(s, "select * from table where id>=100 and sex=1") + }) + // pgsql + gtest.C(t, func(t *gtest.T) { + var s string + s = FormatSqlWithArgs("select * from table where id>=$1 and sex=$2", []interface{}{100, 1}) + t.Assert(s, "select * from table where id>=100 and sex=1") + }) + // oracle + gtest.C(t, func(t *gtest.T) { + var s string + s = FormatSqlWithArgs("select * from table where id>=:1 and sex=:2", []interface{}{100, 1}) + t.Assert(s, "select * from table where id>=100 and sex=1") + }) +} + +func Test_Func_doQuoteWord(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + array := map[string]string{ + "user": "`user`", + "user u": "user u", + "user_detail": "`user_detail`", + "user,user_detail": "user,user_detail", + "user u, user_detail ut": "user u, user_detail ut", + "u.id asc": "u.id asc", + "u.id asc, ut.uid desc": "u.id asc, ut.uid desc", + } + for k, v := range array { + t.Assert(doQuoteWord(k, "`", "`"), v) + } + }) +} + +func Test_Func_doQuoteString(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + array := map[string]string{ + "user": "`user`", + "user u": "`user` u", + "user,user_detail": "`user`,`user_detail`", + "user u, user_detail ut": "`user` u,`user_detail` ut", + "u.id, u.name, u.age": "`u`.`id`,`u`.`name`,`u`.`age`", + "u.id asc": "`u`.`id` asc", + "u.id asc, ut.uid desc": "`u`.`id` asc,`ut`.`uid` desc", + "user.user u, user.user_detail ut": "`user`.`user` u,`user`.`user_detail` ut", + // mssql global schema access with double dots. + "user..user u, user.user_detail ut": "`user`..`user` u,`user`.`user_detail` ut", + } + for k, v := range array { + t.Assert(doQuoteString(k, "`", "`"), v) + } + }) +} + +func Test_Func_addTablePrefix(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + prefix := "" + array := map[string]string{ + "user": "`user`", + "user u": "`user` u", + "user as u": "`user` as u", + "user,user_detail": "`user`,`user_detail`", + "user u, user_detail ut": "`user` u,`user_detail` ut", + "`user`.user_detail": "`user`.`user_detail`", + "`user`.`user_detail`": "`user`.`user_detail`", + "user as u, user_detail as ut": "`user` as u,`user_detail` as ut", + "UserCenter.user as u, UserCenter.user_detail as ut": "`UserCenter`.`user` as u,`UserCenter`.`user_detail` as ut", + // mssql global schema access with double dots. + "UserCenter..user as u, user_detail as ut": "`UserCenter`..`user` as u,`user_detail` as ut", + } + for k, v := range array { + t.Assert(doHandleTableName(k, prefix, "`", "`"), v) + } + }) + gtest.C(t, func(t *gtest.T) { + prefix := "gf_" + array := map[string]string{ + "user": "`gf_user`", + "user u": "`gf_user` u", + "user as u": "`gf_user` as u", + "user,user_detail": "`gf_user`,`gf_user_detail`", + "user u, user_detail ut": "`gf_user` u,`gf_user_detail` ut", + "`user`.user_detail": "`user`.`gf_user_detail`", + "`user`.`user_detail`": "`user`.`gf_user_detail`", + "user as u, user_detail as ut": "`gf_user` as u,`gf_user_detail` as ut", + "UserCenter.user as u, UserCenter.user_detail as ut": "`UserCenter`.`gf_user` as u,`UserCenter`.`gf_user_detail` as ut", + // mssql global schema access with double dots. + "UserCenter..user as u, user_detail as ut": "`UserCenter`..`gf_user` as u,`gf_user_detail` as ut", + } + for k, v := range array { + t.Assert(doHandleTableName(k, prefix, "`", "`"), v) + } + }) +} + +func Test_Model_getSoftFieldName(t *testing.T) { + table1 := "soft_deleting_table_" + gtime.TimestampNanoStr() + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE %s ( + id int(11) NOT NULL, + name varchar(45) DEFAULT NULL, + create_at datetime DEFAULT NULL, + update_at datetime DEFAULT NULL, + delete_at datetime DEFAULT NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table1)); err != nil { + gtest.Error(err) + } + defer dropTable(table1) + + table2 := "soft_deleting_table_" + gtime.TimestampNanoStr() + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE %s ( + id int(11) NOT NULL, + name varchar(45) DEFAULT NULL, + createat datetime DEFAULT NULL, + updateat datetime DEFAULT NULL, + deleteat datetime DEFAULT NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table2)); err != nil { + gtest.Error(err) + } + defer dropTable(table2) + + gtest.C(t, func(t *gtest.T) { + model := db.Table(table1) + gtest.Assert(model.getSoftFieldNameCreate(table2), "createat") + gtest.Assert(model.getSoftFieldNameUpdate(table2), "updateat") + gtest.Assert(model.getSoftFieldNameDelete(table2), "deleteat") + }) +} + +func Test_Model_getConditionForSoftDeleting(t *testing.T) { + table1 := "soft_deleting_table_" + gtime.TimestampNanoStr() + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE %s ( + id1 int(11) NOT NULL, + name1 varchar(45) DEFAULT NULL, + create_at datetime DEFAULT NULL, + update_at datetime DEFAULT NULL, + delete_at datetime DEFAULT NULL, + PRIMARY KEY (id1) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table1)); err != nil { + gtest.Error(err) + } + defer dropTable(table1) + + table2 := "soft_deleting_table_" + gtime.TimestampNanoStr() + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE %s ( + id2 int(11) NOT NULL, + name2 varchar(45) DEFAULT NULL, + createat datetime DEFAULT NULL, + updateat datetime DEFAULT NULL, + deleteat datetime DEFAULT NULL, + PRIMARY KEY (id2) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table2)); err != nil { + gtest.Error(err) + } + defer dropTable(table2) + + gtest.C(t, func(t *gtest.T) { + model := db.Table(table1) + t.Assert(model.getConditionForSoftDeleting(), "`delete_at` IS NULL") + }) + gtest.C(t, func(t *gtest.T) { + model := db.Table(fmt.Sprintf(`%s as t`, table1)) + t.Assert(model.getConditionForSoftDeleting(), "`delete_at` IS NULL") + }) + gtest.C(t, func(t *gtest.T) { + model := db.Table(fmt.Sprintf(`%s, %s`, table1, table2)) + t.Assert(model.getConditionForSoftDeleting(), fmt.Sprintf( + "`%s`.`delete_at` IS NULL AND `%s`.`deleteat` IS NULL", + table1, table2, + )) + }) + gtest.C(t, func(t *gtest.T) { + model := db.Table(fmt.Sprintf(`%s t1, %s as t2`, table1, table2)) + t.Assert(model.getConditionForSoftDeleting(), "`t1`.`delete_at` IS NULL AND `t2`.`deleteat` IS NULL") + }) + gtest.C(t, func(t *gtest.T) { + model := db.Table(fmt.Sprintf(`%s as t1, %s as t2`, table1, table2)) + t.Assert(model.getConditionForSoftDeleting(), "`t1`.`delete_at` IS NULL AND `t2`.`deleteat` IS NULL") + }) + gtest.C(t, func(t *gtest.T) { + model := db.Table(fmt.Sprintf(`%s as t1`, table1)).LeftJoin(table2+" t2", "t2.id2=t1.id1") + t.Assert(model.getConditionForSoftDeleting(), "`t1`.`delete_at` IS NULL AND `t2`.`deleteat` IS NULL") + }) + gtest.C(t, func(t *gtest.T) { + model := db.Table(fmt.Sprintf(`%s`, table1)).LeftJoin(table2, "t2.id2=t1.id1") + t.Assert(model.getConditionForSoftDeleting(), fmt.Sprintf( + "`%s`.`delete_at` IS NULL AND `%s`.`deleteat` IS NULL", + table1, table2, + )) + }) + gtest.C(t, func(t *gtest.T) { + model := db.Table(fmt.Sprintf(`%s`, table1)).LeftJoin(table2, "t2.id2=t1.id1").RightJoin(table2, "t2.id2=t1.id1") + t.Assert(model.getConditionForSoftDeleting(), fmt.Sprintf( + "`%s`.`delete_at` IS NULL AND `%s`.`deleteat` IS NULL AND `%s`.`deleteat` IS NULL", + table1, table2, table2, + )) + }) + gtest.C(t, func(t *gtest.T) { + model := db.Table(table1+" as t1").LeftJoin(table2+" as t2", "t2.id2=t1.id1").RightJoin(table2+" as t3 ", "t2.id2=t1.id1") + t.Assert( + model.getConditionForSoftDeleting(), + "`t1`.`delete_at` IS NULL AND `t2`.`deleteat` IS NULL AND `t3`.`deleteat` IS NULL", + ) + }) +} diff --git a/database/gdb/gdb_unit_z_mysql_model_test.go b/database/gdb/gdb_unit_z_mysql_model_test.go index 7f2032acf..3e8f3ac7d 100644 --- a/database/gdb/gdb_unit_z_mysql_model_test.go +++ b/database/gdb/gdb_unit_z_mysql_model_test.go @@ -26,7 +26,7 @@ func Test_Model_Insert(t *testing.T) { table := createTable() defer dropTable(table) gtest.C(t, func(t *gtest.T) { - user := db.From(table) + user := db.Table(table) result, err := user.Filter().Data(g.Map{ "id": 1, "uid": 1, @@ -338,7 +338,7 @@ func Test_Model_Safe(t *testing.T) { t.Assert(err, nil) t.Assert(len(all), 2) - all, err = md2.ForPage(1, 10).All() + all, err = md2.Page(1, 10).All() t.Assert(err, nil) t.Assert(len(all), 2) }) @@ -362,7 +362,7 @@ func Test_Model_Safe(t *testing.T) { t.Assert(all[0]["id"].Int(), 1) t.Assert(all[1]["id"].Int(), 3) - all, err = md2.ForPage(1, 10).All() + all, err = md2.Page(1, 10).All() t.Assert(err, nil) t.Assert(len(all), 2) @@ -378,7 +378,7 @@ func Test_Model_Safe(t *testing.T) { t.Assert(all[1]["id"].Int(), 5) t.Assert(all[2]["id"].Int(), 6) - all, err = md3.ForPage(1, 10).All() + all, err = md3.Page(1, 10).All() t.Assert(err, nil) t.Assert(len(all), 3) }) @@ -1274,6 +1274,11 @@ func Test_Model_Where_GTime(t *testing.T) { t.Assert(err, nil) t.Assert(len(result), 10) }) + gtest.C(t, func(t *gtest.T) { + result, err := db.Table(table).Where("create_time>?", *gtime.NewFromStr("2010-09-01")).All() + t.Assert(err, nil) + t.Assert(len(result), 10) + }) } func Test_Model_WherePri(t *testing.T) { @@ -1566,7 +1571,7 @@ func Test_Model_Offset(t *testing.T) { }) } -func Test_Model_ForPage(t *testing.T) { +func Test_Model_Page(t *testing.T) { table := createInitTable() defer dropTable(table) gtest.C(t, func(t *gtest.T) { @@ -1576,6 +1581,15 @@ func Test_Model_ForPage(t *testing.T) { t.Assert(result[0]["id"], 7) t.Assert(result[1]["id"], 8) }) + gtest.C(t, func(t *gtest.T) { + model := db.Table(table).Safe().Order("id") + all, err := model.Page(3, 3).All() + count, err := model.Count() + t.Assert(err, nil) + t.Assert(len(all), 3) + t.Assert(all[0]["id"], "7") + t.Assert(count, SIZE) + }) } func Test_Model_Option_Map(t *testing.T) { @@ -2158,3 +2172,71 @@ func Test_Model_DryRun(t *testing.T) { t.Assert(n, 0) }) } + +func Test_Model_Cache(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + one, err := db.Table(table).Cache(time.Second, "test1").FindOne(1) + t.Assert(err, nil) + t.Assert(one["passport"], "user_1") + + r, err := db.Table(table).Data("passport", "user_100").WherePri(1).Update() + t.Assert(err, nil) + n, err := r.RowsAffected() + t.Assert(err, nil) + t.Assert(n, 1) + + one, err = db.Table(table).Cache(time.Second, "test1").FindOne(1) + t.Assert(err, nil) + t.Assert(one["passport"], "user_1") + + time.Sleep(time.Second * 2) + + one, err = db.Table(table).Cache(time.Second, "test1").FindOne(1) + t.Assert(err, nil) + t.Assert(one["passport"], "user_100") + }) + gtest.C(t, func(t *gtest.T) { + one, err := db.Table(table).Cache(time.Second, "test2").FindOne(2) + t.Assert(err, nil) + t.Assert(one["passport"], "user_2") + + r, err := db.Table(table).Data("passport", "user_200").Cache(-1, "test2").WherePri(2).Update() + t.Assert(err, nil) + n, err := r.RowsAffected() + t.Assert(err, nil) + t.Assert(n, 1) + + one, err = db.Table(table).Cache(time.Second, "test2").FindOne(2) + t.Assert(err, nil) + t.Assert(one["passport"], "user_200") + }) +} + +func Test_Model_Having(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + all, err := db.Table(table).Where("id > 1").Having("id > 8").All() + t.Assert(err, nil) + t.Assert(len(all), 2) + }) + gtest.C(t, func(t *gtest.T) { + all, err := db.Table(table).Where("id > 1").Having("id > ?", 8).All() + t.Assert(err, nil) + t.Assert(len(all), 2) + }) + gtest.C(t, func(t *gtest.T) { + all, err := db.Table(table).Where("id > ?", 1).Having("id > ?", 8).All() + t.Assert(err, nil) + t.Assert(len(all), 2) + }) + gtest.C(t, func(t *gtest.T) { + all, err := db.Table(table).Where("id > ?", 1).Having("id", 8).All() + t.Assert(err, nil) + t.Assert(len(all), 1) + }) +} diff --git a/database/gdb/gdb_unit_z_mysql_time_test.go b/database/gdb/gdb_unit_z_mysql_time_test.go new file mode 100644 index 000000000..8475f6fd7 --- /dev/null +++ b/database/gdb/gdb_unit_z_mysql_time_test.go @@ -0,0 +1,370 @@ +// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gdb_test + +import ( + "fmt" + "github.com/gogf/gf/os/gtime" + "testing" + "time" + + "github.com/gogf/gf/frame/g" + + "github.com/gogf/gf/test/gtest" +) + +func Test_CreateUpdateDeleteTime(t *testing.T) { + table := "time_test_table" + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE %s ( + id int(11) NOT NULL, + name varchar(45) DEFAULT NULL, + create_at datetime DEFAULT NULL, + update_at datetime DEFAULT NULL, + delete_at datetime DEFAULT NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table)); err != nil { + gtest.Error(err) + } + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + // Insert + dataInsert := g.Map{ + "id": 1, + "name": "name_1", + } + r, err := db.Table(table).Data(dataInsert).Insert() + t.Assert(err, nil) + n, _ := r.RowsAffected() + t.Assert(n, 1) + + oneInsert, err := db.Table(table).FindOne(1) + t.Assert(err, nil) + t.Assert(oneInsert["id"].Int(), 1) + t.Assert(oneInsert["name"].String(), "name_1") + t.Assert(oneInsert["delete_at"].String(), "") + t.AssertGE(oneInsert["create_at"].GTime().Timestamp(), gtime.Timestamp()-2) + t.AssertGE(oneInsert["update_at"].GTime().Timestamp(), gtime.Timestamp()) + + time.Sleep(2 * time.Second) + + // Save + dataSave := g.Map{ + "id": 1, + "name": "name_10", + } + r, err = db.Table(table).Data(dataSave).Save() + t.Assert(err, nil) + n, _ = r.RowsAffected() + t.Assert(n, 2) + + oneSave, err := db.Table(table).FindOne(1) + t.Assert(err, nil) + t.Assert(oneSave["id"].Int(), 1) + t.Assert(oneSave["name"].String(), "name_10") + t.Assert(oneSave["delete_at"].String(), "") + t.Assert(oneSave["create_at"].GTime().Timestamp(), oneInsert["create_at"].GTime().Timestamp()) + t.AssertNE(oneSave["update_at"].GTime().Timestamp(), oneInsert["update_at"].GTime().Timestamp()) + t.AssertGE(oneSave["update_at"].GTime().Timestamp(), gtime.Now().Timestamp()-2) + + time.Sleep(2 * time.Second) + + // Update + dataUpdate := g.Map{ + "id": 1, + "name": "name_1000", + } + r, err = db.Table(table).Data(dataUpdate).WherePri(1).Update() + t.Assert(err, nil) + n, _ = r.RowsAffected() + t.Assert(n, 1) + + oneUpdate, err := db.Table(table).FindOne(1) + t.Assert(err, nil) + t.Assert(oneUpdate["id"].Int(), 1) + t.Assert(oneUpdate["name"].String(), "name_1000") + t.Assert(oneUpdate["delete_at"].String(), "") + t.Assert(oneUpdate["create_at"].GTime().Timestamp(), oneInsert["create_at"].GTime().Timestamp()) + t.AssertGE(oneUpdate["update_at"].GTime().Timestamp(), gtime.Now().Timestamp()-2) + + // Replace + dataReplace := g.Map{ + "id": 1, + "name": "name_100", + } + r, err = db.Table(table).Data(dataReplace).Replace() + t.Assert(err, nil) + n, _ = r.RowsAffected() + t.Assert(n, 2) + + oneReplace, err := db.Table(table).FindOne(1) + t.Assert(err, nil) + t.Assert(oneReplace["id"].Int(), 1) + t.Assert(oneReplace["name"].String(), "name_100") + t.Assert(oneReplace["delete_at"].String(), "") + t.AssertGE(oneReplace["create_at"].GTime().Timestamp(), oneInsert["create_at"].GTime().Timestamp()) + t.AssertGE(oneReplace["update_at"].GTime().Timestamp(), oneInsert["update_at"].GTime().Timestamp()) + + time.Sleep(2 * time.Second) + + // Delete + r, err = db.Table(table).Delete("id", 1) + t.Assert(err, nil) + n, _ = r.RowsAffected() + t.Assert(n, 1) + // Delete Select + one4, err := db.Table(table).FindOne(1) + t.Assert(err, nil) + t.Assert(len(one4), 0) + one5, err := db.Table(table).Unscoped().FindOne(1) + t.Assert(err, nil) + t.Assert(one5["id"].Int(), 1) + t.AssertGE(one5["delete_at"].GTime().Timestamp(), gtime.Now().Timestamp()-2) + // Delete Count + i, err := db.Table(table).FindCount() + t.Assert(err, nil) + t.Assert(i, 0) + i, err = db.Table(table).Unscoped().FindCount() + t.Assert(err, nil) + t.Assert(i, 1) + + // Delete Unscoped + r, err = db.Table(table).Unscoped().Delete("id", 1) + t.Assert(err, nil) + n, _ = r.RowsAffected() + t.Assert(n, 1) + one6, err := db.Table(table).Unscoped().FindOne(1) + t.Assert(err, nil) + t.Assert(len(one6), 0) + i, err = db.Table(table).Unscoped().FindCount() + t.Assert(err, nil) + t.Assert(i, 0) + }) +} + +func Test_SoftDelete_Join(t *testing.T) { + table1 := "time_test_table1" + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE %s ( + id int(11) NOT NULL, + name varchar(45) DEFAULT NULL, + create_at datetime DEFAULT NULL, + update_at datetime DEFAULT NULL, + delete_at datetime DEFAULT NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table1)); err != nil { + gtest.Error(err) + } + defer dropTable(table1) + + table2 := "time_test_table2" + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE %s ( + id int(11) NOT NULL, + name varchar(45) DEFAULT NULL, + createat datetime DEFAULT NULL, + updateat datetime DEFAULT NULL, + deleteat datetime DEFAULT NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table2)); err != nil { + gtest.Error(err) + } + defer dropTable(table2) + + gtest.C(t, func(t *gtest.T) { + //db.SetDebug(true) + dataInsert1 := g.Map{ + "id": 1, + "name": "name_1", + } + r, err := db.Table(table1).Data(dataInsert1).Insert() + t.Assert(err, nil) + n, _ := r.RowsAffected() + t.Assert(n, 1) + + dataInsert2 := g.Map{ + "id": 1, + "name": "name_2", + } + r, err = db.Table(table2).Data(dataInsert2).Insert() + t.Assert(err, nil) + n, _ = r.RowsAffected() + t.Assert(n, 1) + + one, err := db.Table(table1, "t1").LeftJoin(table2, "t2", "t2.id=t1.id").Fields("t1.name").FindOne() + t.Assert(err, nil) + t.Assert(one["name"], "name_1") + + // Soft deleting. + r, err = db.Table(table1).Delete() + t.Assert(err, nil) + n, _ = r.RowsAffected() + t.Assert(n, 1) + + one, err = db.Table(table1, "t1").LeftJoin(table2, "t2", "t2.id=t1.id").Fields("t1.name").FindOne() + t.Assert(err, nil) + t.Assert(one.IsEmpty(), true) + + one, err = db.Table(table2, "t2").LeftJoin(table1, "t1", "t2.id=t1.id").Fields("t2.name").FindOne() + t.Assert(err, nil) + t.Assert(one.IsEmpty(), true) + }) +} + +func Test_CreateUpdateTime_Struct(t *testing.T) { + table := "time_test_table" + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE %s ( + id int(11) NOT NULL, + name varchar(45) DEFAULT NULL, + create_at datetime DEFAULT NULL, + update_at datetime DEFAULT NULL, + delete_at datetime DEFAULT NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table)); err != nil { + gtest.Error(err) + } + defer dropTable(table) + + type Entity struct { + Id uint64 `orm:"id,primary" json:"id"` + Name string `orm:"name" json:"name"` + CreateAt *gtime.Time `orm:"create_at" json:"create_at"` + UpdateAt *gtime.Time `orm:"update_at" json:"update_at"` + DeleteAt *gtime.Time `orm:"delete_at" json:"delete_at"` + } + gtest.C(t, func(t *gtest.T) { + // Insert + dataInsert := &Entity{ + Id: 1, + Name: "name_1", + CreateAt: nil, + UpdateAt: nil, + DeleteAt: nil, + } + r, err := db.Table(table).Data(dataInsert).Insert() + t.Assert(err, nil) + n, _ := r.RowsAffected() + t.Assert(n, 1) + + oneInsert, err := db.Table(table).FindOne(1) + t.Assert(err, nil) + t.Assert(oneInsert["id"].Int(), 1) + t.Assert(oneInsert["name"].String(), "name_1") + t.Assert(oneInsert["delete_at"].String(), "") + t.AssertGE(oneInsert["create_at"].GTime().Timestamp(), gtime.Timestamp()-2) + t.AssertGE(oneInsert["update_at"].GTime().Timestamp(), gtime.Timestamp()) + + time.Sleep(2 * time.Second) + + // Save + dataSave := &Entity{ + Id: 1, + Name: "name_10", + CreateAt: nil, + UpdateAt: nil, + DeleteAt: nil, + } + r, err = db.Table(table).Data(dataSave).Save() + t.Assert(err, nil) + n, _ = r.RowsAffected() + t.Assert(n, 2) + + oneSave, err := db.Table(table).FindOne(1) + t.Assert(err, nil) + t.Assert(oneSave["id"].Int(), 1) + t.Assert(oneSave["name"].String(), "name_10") + t.Assert(oneSave["delete_at"].String(), "") + t.Assert(oneSave["create_at"].GTime().Timestamp(), oneInsert["create_at"].GTime().Timestamp()) + t.AssertNE(oneSave["update_at"].GTime().Timestamp(), oneInsert["update_at"].GTime().Timestamp()) + t.AssertGE(oneSave["update_at"].GTime().Timestamp(), gtime.Now().Timestamp()-2) + + time.Sleep(2 * time.Second) + + // Update + dataUpdate := &Entity{ + Id: 1, + Name: "name_1000", + CreateAt: nil, + UpdateAt: nil, + DeleteAt: nil, + } + r, err = db.Table(table).Data(dataUpdate).WherePri(1).Update() + t.Assert(err, nil) + n, _ = r.RowsAffected() + t.Assert(n, 1) + + oneUpdate, err := db.Table(table).FindOne(1) + t.Assert(err, nil) + t.Assert(oneUpdate["id"].Int(), 1) + t.Assert(oneUpdate["name"].String(), "name_1000") + t.Assert(oneUpdate["delete_at"].String(), "") + t.Assert(oneUpdate["create_at"].GTime().Timestamp(), oneInsert["create_at"].GTime().Timestamp()) + t.AssertGE(oneUpdate["update_at"].GTime().Timestamp(), gtime.Now().Timestamp()-2) + + // Replace + dataReplace := &Entity{ + Id: 1, + Name: "name_100", + CreateAt: nil, + UpdateAt: nil, + DeleteAt: nil, + } + r, err = db.Table(table).Data(dataReplace).Replace() + t.Assert(err, nil) + n, _ = r.RowsAffected() + t.Assert(n, 2) + + oneReplace, err := db.Table(table).FindOne(1) + t.Assert(err, nil) + t.Assert(oneReplace["id"].Int(), 1) + t.Assert(oneReplace["name"].String(), "name_100") + t.Assert(oneReplace["delete_at"].String(), "") + t.AssertGE(oneReplace["create_at"].GTime().Timestamp(), oneInsert["create_at"].GTime().Timestamp()) + t.AssertGE(oneReplace["update_at"].GTime().Timestamp(), oneInsert["update_at"].GTime().Timestamp()) + + time.Sleep(2 * time.Second) + + // Delete + r, err = db.Table(table).Delete("id", 1) + t.Assert(err, nil) + n, _ = r.RowsAffected() + t.Assert(n, 1) + // Delete Select + one4, err := db.Table(table).FindOne(1) + t.Assert(err, nil) + t.Assert(len(one4), 0) + one5, err := db.Table(table).Unscoped().FindOne(1) + t.Assert(err, nil) + t.Assert(one5["id"].Int(), 1) + t.AssertGE(one5["delete_at"].GTime().Timestamp(), gtime.Now().Timestamp()-2) + // Delete Count + i, err := db.Table(table).FindCount() + t.Assert(err, nil) + t.Assert(i, 0) + i, err = db.Table(table).Unscoped().FindCount() + t.Assert(err, nil) + t.Assert(i, 1) + + // Delete Unscoped + r, err = db.Table(table).Unscoped().Delete("id", 1) + t.Assert(err, nil) + n, _ = r.RowsAffected() + t.Assert(n, 1) + one6, err := db.Table(table).Unscoped().FindOne(1) + t.Assert(err, nil) + t.Assert(len(one6), 0) + i, err = db.Table(table).Unscoped().FindCount() + t.Assert(err, nil) + t.Assert(i, 0) + }) +} diff --git a/database/gdb/gdb_unit_z_mysql_transaction_test.go b/database/gdb/gdb_unit_z_mysql_transaction_test.go index 3dd776fbf..28d817fb2 100644 --- a/database/gdb/gdb_unit_z_mysql_transaction_test.go +++ b/database/gdb/gdb_unit_z_mysql_transaction_test.go @@ -120,7 +120,7 @@ func Test_TX_Insert(t *testing.T) { if err != nil { gtest.Error(err) } - user := tx.From(table) + user := tx.Table(table) if _, err := user.Data(g.Map{ "id": 1, "passport": "t1", diff --git a/debug/gdebug/gdebug_caller.go b/debug/gdebug/gdebug_caller.go index 13f7391cf..d5a55594e 100644 --- a/debug/gdebug/gdebug_caller.go +++ b/debug/gdebug/gdebug_caller.go @@ -42,12 +42,14 @@ func init() { } } -// CallerPath returns the function name and the absolute file path along with its line number of the caller. +// CallerPath returns the function name and the absolute file path along with its line +// number of the caller. func Caller(skip ...int) (function string, path string, line int) { return CallerWithFilter("", skip...) } -// CallerPathWithFilter returns the function name and the absolute file path along with its line number of the caller. +// CallerPathWithFilter returns the function name and the absolute file path along with +// its line number of the caller. // // The parameter is used to filter the path of the caller. func CallerWithFilter(filter string, skip ...int) (function string, path string, line int) { @@ -84,7 +86,10 @@ func CallerWithFilter(filter string, skip ...int) (function string, path string, return "", "", -1 } -// callerFromIndex returns the caller position and according information exclusive of the debug package. +// callerFromIndex returns the caller position and according information exclusive of the +// debug package. +// +// VERY NOTE THAT, the returned index value should be as the caller's start point. func callerFromIndex(filters []string) (pc uintptr, file string, line int, index int) { var filtered, ok bool for index = 0; index < gMAX_DEPTH; index++ { @@ -102,6 +107,9 @@ func callerFromIndex(filters []string) (pc uintptr, file string, line int, index if strings.Contains(file, gFILTER_KEY) { continue } + if index > 0 { + index-- + } return } } diff --git a/debug/gdebug/gdebug_stack.go b/debug/gdebug/gdebug_stack.go index 498f03bf2..dbfb23d22 100644 --- a/debug/gdebug/gdebug_stack.go +++ b/debug/gdebug/gdebug_stack.go @@ -35,25 +35,33 @@ func StackWithFilter(filter string, skip ...int) string { // StackWithFilters returns a formatted stack trace of the goroutine that calls it. // It calls runtime.Stack with a large enough buffer to capture the entire trace. // -// The parameter is a slice of strings, which are used to filter the path of the caller. +// The parameter is a slice of strings, which are used to filter the path of the +// caller. +// +// TODO Improve the performance using debug.Stack. func StackWithFilters(filters []string, skip ...int) string { number := 0 if len(skip) > 0 { number = skip[0] } - name := "" - space := " " - index := 1 - buffer := bytes.NewBuffer(nil) - filtered := false - ok := true - pc, file, line, start := callerFromIndex(filters) + var ( + name = "" + space = " " + index = 1 + buffer = bytes.NewBuffer(nil) + filtered = false + ok = true + pc, file, line, start = callerFromIndex(filters) + ) for i := start + number; i < gMAX_DEPTH; i++ { if i != start { pc, file, line, ok = runtime.Caller(i) } if ok { - if goRootForFilter != "" && len(file) >= len(goRootForFilter) && file[0:len(goRootForFilter)] == goRootForFilter { + // GOROOT filter. + if goRootForFilter != "" && + len(file) >= len(goRootForFilter) && + file[0:len(goRootForFilter)] == goRootForFilter { continue } filtered = false diff --git a/debug/gdebug/gdebug_z_bench_test.go b/debug/gdebug/gdebug_z_bench_test.go index 8c6dbda69..7e26cdd60 100644 --- a/debug/gdebug/gdebug_z_bench_test.go +++ b/debug/gdebug/gdebug_z_bench_test.go @@ -10,6 +10,7 @@ package gdebug import ( "runtime" + "runtime/debug" "testing" ) @@ -49,6 +50,12 @@ func Benchmark_Stack(b *testing.B) { } } +func Benchmark_StackOfStdlib(b *testing.B) { + for i := 0; i < b.N; i++ { + debug.Stack() + } +} + func Benchmark_StackWithFilter(b *testing.B) { for i := 0; i < b.N; i++ { StackWithFilter("test") diff --git a/encoding/gjson/gjson_api_new_load.go b/encoding/gjson/gjson_api_new_load.go index 44bc82fe3..680792235 100644 --- a/encoding/gjson/gjson_api_new_load.go +++ b/encoding/gjson/gjson_api_new_load.go @@ -74,7 +74,7 @@ func NewWithTag(data interface{}, tags string, safe ...bool) *Json { i := interface{}(nil) // Note that it uses Map function implementing the converting. // Note that it here should not use MapDeep function if you really know what it means. - i = gconv.Map(data, tags) + i = gconv.MapDeep(data, tags) j = &Json{ p: &i, c: byte(gDEFAULT_SPLIT_CHAR), diff --git a/encoding/gurl/url.go b/encoding/gurl/url.go index 1d43434b2..c1a480519 100644 --- a/encoding/gurl/url.go +++ b/encoding/gurl/url.go @@ -12,12 +12,17 @@ import ( "strings" ) -// url encode string, is + not %20 +// Encode escapes the string so it can be safely placed +// inside a URL query. func Encode(str string) string { return url.QueryEscape(str) } -// url decode string +// Decode does the inverse transformation of Encode, +// converting each 3-byte encoded substring of the form "%AB" into the +// hex-decoded byte 0xAB. +// It returns an error if any % is not followed by two hexadecimal +// digits. func Decode(str string) (string, error) { return url.QueryUnescape(str) } diff --git a/frame/g/g_object.go b/frame/g/g_object.go index da846e2ef..404fd7abc 100644 --- a/frame/g/g_object.go +++ b/frame/g/g_object.go @@ -91,6 +91,13 @@ func DB(name ...string) gdb.DB { return gins.Database(name...) } +// Table creates and returns a model from specified database or default database configuration. +// The optional parameter specifies the configuration group name of the database, +// which is "default" in default. +func Table(tables string, db ...string) *gdb.Model { + return DB(db...).Table(tables) +} + // Redis returns an instance of redis client with specified configuration group name. func Redis(name ...string) *gredis.Redis { return gins.Redis(name...) diff --git a/frame/gmvc/view.go b/frame/gmvc/view.go index b4e35bc71..73caebdea 100644 --- a/frame/gmvc/view.go +++ b/frame/gmvc/view.go @@ -99,7 +99,7 @@ func (view *View) BindFuncMap(funcMap gview.FuncMap) { // Display parses and writes the parsed template file content to http response. func (view *View) Display(file ...string) error { - name := "index.tpl" + name := view.view.GetDefaultFile() if len(file) > 0 { name = file[0] } diff --git a/go.mod b/go.mod index f6db30b1e..a9953e5df 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/gogf/gf -go 1.13 +go 1.11 require ( github.com/BurntSushi/toml v0.3.1 diff --git a/internal/empty/empty.go b/internal/empty/empty.go index 675317138..52a4bf86f 100644 --- a/internal/empty/empty.go +++ b/internal/empty/empty.go @@ -12,8 +12,8 @@ import ( ) // IsEmpty checks whether given empty. -// It returns true if is in: 0, nil, false, "", len(slice/map/chan) == 0. -// Or else it returns true. +// It returns true if is in: 0, nil, false, "", len(slice/map/chan) == 0, +// or else it returns false. func IsEmpty(value interface{}) bool { if value == nil { return true @@ -49,6 +49,8 @@ func IsEmpty(value interface{}) bool { return value == "" case []byte: return len(value) == 0 + case []rune: + return len(value) == 0 default: // Finally using reflect. var rv reflect.Value diff --git a/internal/rwmutex/rwmutex.go b/internal/rwmutex/rwmutex.go index f92c16147..d1fed095f 100644 --- a/internal/rwmutex/rwmutex.go +++ b/internal/rwmutex/rwmutex.go @@ -20,7 +20,15 @@ type RWMutex struct { // The parameter is used to specify whether using this mutex in concurrent safety, // which is false in default. func New(safe ...bool) *RWMutex { - mu := new(RWMutex) + mu := Create(safe...) + return &mu +} + +// Create creates and returns a new RWMutex object. +// The parameter is used to specify whether using this mutex in concurrent safety, +// which is false in default. +func Create(safe ...bool) RWMutex { + mu := RWMutex{} if len(safe) > 0 && safe[0] { mu.RWMutex = new(sync.RWMutex) } diff --git a/internal/structs/structs_map.go b/internal/structs/structs_map.go index 766625166..012855c6e 100644 --- a/internal/structs/structs_map.go +++ b/internal/structs/structs_map.go @@ -14,6 +14,8 @@ import ( // MapField retrieves struct field as map[name/tag]*Field from , and returns the map. // +// The parameter should be type of struct/*struct. +// // The parameter specifies the priority tag array for retrieving from high to low. // // The parameter specifies whether retrieving the struct field recursively. diff --git a/internal/structs/structs_tag.go b/internal/structs/structs_tag.go index 1fbfc2a54..fdce13fde 100644 --- a/internal/structs/structs_tag.go +++ b/internal/structs/structs_tag.go @@ -14,6 +14,8 @@ import ( // TagFields retrieves struct tags as []*Field from , and returns it. // +// The parameter should be type of struct/*struct. +// // The parameter specifies whether retrieving the struct field recursively. // // Note that it only retrieves the exported attributes with first letter up-case from struct. @@ -23,6 +25,7 @@ func TagFields(pointer interface{}, priority []string, recursive bool) []*Field // doTagFields retrieves the tag and corresponding attribute name from . It also filters repeated // tag internally. +// The parameter should be type of struct/*struct. func doTagFields(pointer interface{}, priority []string, recursive bool, tagMap map[string]struct{}) []*Field { var fields []*structs.Field if v, ok := pointer.(reflect.Value); ok { @@ -85,6 +88,8 @@ func doTagFields(pointer interface{}, priority []string, recursive bool, tagMap // TagMapName retrieves struct tags as map[tag]attribute from , and returns it. // +// The parameter should be type of struct/*struct. +// // The parameter specifies whether retrieving the struct field recursively. // // Note that it only retrieves the exported attributes with first letter up-case from struct. @@ -99,6 +104,8 @@ func TagMapName(pointer interface{}, priority []string, recursive bool) map[stri // TagMapField retrieves struct tags as map[tag]*Field from , and returns it. // +// The parameter should be type of struct/*struct. +// // The parameter specifies whether retrieving the struct field recursively. // // Note that it only retrieves the exported attributes with first letter up-case from struct. diff --git a/internal/utils/utils_str.go b/internal/utils/utils_str.go index 5f6a878a9..b14e13afb 100644 --- a/internal/utils/utils_str.go +++ b/internal/utils/utils_str.go @@ -6,7 +6,16 @@ package utils -import "strings" +import ( + "regexp" + "strings" +) + +var ( + // replaceCharReg is the regular expression object for replacing chars in key. + // It is used for function EqualFoldWithoutChars. + replaceCharReg, _ = regexp.Compile(`[\-\.\_\s]+`) +) // IsLetterUpper checks whether the given byte b is in upper case. func IsLetterUpper(b byte) bool { @@ -73,3 +82,12 @@ func ReplaceByMap(origin string, replaces map[string]string) string { } return origin } + +// EqualFoldWithoutChars checks string and equal case-insensitively, +// with/without chars '-'/'_'/'.'/' '. +func EqualFoldWithoutChars(s1, s2 string) bool { + return strings.EqualFold( + replaceCharReg.ReplaceAllString(s1, ""), + replaceCharReg.ReplaceAllString(s2, ""), + ) +} diff --git a/net/ghttp/ghttp_client_chain.go b/net/ghttp/ghttp_client_chain.go index b5d6e8699..9191d67bf 100644 --- a/net/ghttp/ghttp_client_chain.go +++ b/net/ghttp/ghttp_client_chain.go @@ -11,6 +11,17 @@ import ( "time" ) +// Prefix is a chaining function, +// which sets the URL prefix for next request of this client. +func (c *Client) Prefix(prefix string) *Client { + newClient := c + if c.parent == nil { + newClient = c.Clone() + } + newClient.SetPrefix(prefix) + return newClient +} + // Header is a chaining function, // which sets custom HTTP headers with map for next request. func (c *Client) Header(m map[string]string) *Client { @@ -113,3 +124,14 @@ func (c *Client) Ctx(ctx context.Context) *Client { newClient.SetCtx(ctx) return newClient } + +// Retry is a chaining function, +// which sets retry count and interval when failure for next request. +func (c *Client) Retry(retryCount int, retryInterval time.Duration) *Client { + newClient := c + if c.parent == nil { + newClient = c.Clone() + } + newClient.SetRetry(retryCount, retryInterval) + return c +} diff --git a/net/ghttp/ghttp_client_config.go b/net/ghttp/ghttp_client_config.go index d650b7f19..49c00ee46 100644 --- a/net/ghttp/ghttp_client_config.go +++ b/net/ghttp/ghttp_client_config.go @@ -28,7 +28,7 @@ type Client struct { authPass string // HTTP basic authentication: pass. browserMode bool // Whether auto saving and sending cookie content. retryCount int // Retry count when request fails. - retryInterval int // Retry interval when request fails. + retryInterval time.Duration // Retry interval when request fails. } // NewClient creates and returns a new HTTP client object. @@ -43,7 +43,6 @@ func NewClient() *Client { DisableKeepAlives: true, }, }, - ctx: context.Background(), header: make(map[string]string), cookies: make(map[string]string), } @@ -143,7 +142,7 @@ func (c *Client) SetCtx(ctx context.Context) *Client { } // SetRetry sets retry count and interval. -func (c *Client) SetRetry(retryCount int, retryInterval int) *Client { +func (c *Client) SetRetry(retryCount int, retryInterval time.Duration) *Client { c.retryCount = retryCount c.retryInterval = retryInterval return c diff --git a/net/ghttp/ghttp_client_request.go b/net/ghttp/ghttp_client_request.go index 450321f9f..4ed9970c5 100644 --- a/net/ghttp/ghttp_client_request.go +++ b/net/ghttp/ghttp_client_request.go @@ -157,7 +157,8 @@ func (c *Client) DoRequest(method, url string, data ...interface{}) (resp *Clien if err = writer.Close(); err != nil { return nil, err } - if req, err = http.NewRequestWithContext(c.ctx, method, url, buffer); err != nil { + + if req, err = http.NewRequest(method, url, buffer); err != nil { return nil, err } else { req.Header.Set("Content-Type", writer.FormDataContentType()) @@ -165,9 +166,7 @@ func (c *Client) DoRequest(method, url string, data ...interface{}) (resp *Clien } else { // Normal request. paramBytes := []byte(param) - if req, err = http.NewRequestWithContext( - c.ctx, method, url, bytes.NewReader(paramBytes), - ); err != nil { + if req, err = http.NewRequest(method, url, bytes.NewReader(paramBytes)); err != nil { return nil, err } else { if v, ok := c.header["Content-Type"]; ok { @@ -184,6 +183,10 @@ func (c *Client) DoRequest(method, url string, data ...interface{}) (resp *Clien } } } + // Context. + if c.ctx != nil { + req = req.WithContext(c.ctx) + } // Custom header. if len(c.header) > 0 { for k, v := range c.header { @@ -191,7 +194,7 @@ func (c *Client) DoRequest(method, url string, data ...interface{}) (resp *Clien } } // It's necessary set the req.Host if you want to custom the host value of the request. - // It uses the "Host" value of the header. + // It uses the "Host" value from header if it's not set in the request. if host := req.Header.Get("Host"); host != "" && req.Host == "" { req.Host = host } @@ -220,6 +223,7 @@ func (c *Client) DoRequest(method, url string, data ...interface{}) (resp *Clien if r, err = c.Do(req); err != nil { if c.retryCount > 0 { c.retryCount-- + time.Sleep(c.retryInterval) } else { // we need a copy of the request when the request fails. resp.req = req diff --git a/net/ghttp/ghttp_client_response.go b/net/ghttp/ghttp_client_response.go index f13981dae..132405f6a 100644 --- a/net/ghttp/ghttp_client_response.go +++ b/net/ghttp/ghttp_client_response.go @@ -24,12 +24,8 @@ type ClientResponse struct { // initCookie initializes the cookie map attribute of ClientResponse. func (r *ClientResponse) initCookie() { if r.cookies == nil { - now := time.Now() r.cookies = make(map[string]string) for _, v := range r.Cookies() { - if v.Expires.UnixNano() < now.UnixNano() { - continue - } r.cookies[v.Name] = v.Value } } diff --git a/net/ghttp/ghttp_func.go b/net/ghttp/ghttp_func.go index e2d46a8ef..3a55296fb 100644 --- a/net/ghttp/ghttp_func.go +++ b/net/ghttp/ghttp_func.go @@ -19,9 +19,15 @@ import ( // The optional parameter specifies whether ignore the url encoding for the data. func BuildParams(params interface{}, noUrlEncode ...bool) (encodedParamStr string) { // If given string/[]byte, converts and returns it directly as string. - switch params.(type) { + switch v := params.(type) { case string, []byte: return gconv.String(params) + case []interface{}: + if len(v) > 0 { + params = v[0] + } else { + params = nil + } } // Else converts it to map and does the url encoding. m, urlEncode := gconv.Map(params), true diff --git a/net/ghttp/ghttp_request_auth.go b/net/ghttp/ghttp_request_auth.go index 0f35ac65a..82ab1c666 100644 --- a/net/ghttp/ghttp_request_auth.go +++ b/net/ghttp/ghttp_request_auth.go @@ -14,18 +14,6 @@ import ( "github.com/gogf/gf/encoding/gbase64" ) -// setBasicAuth sets the http basic authentication tips. -func (r *Request) setBasicAuth(tips ...string) { - realm := "" - if len(tips) > 0 && tips[0] != "" { - realm = tips[0] - } else { - realm = "Need Login" - } - r.Response.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm)) - r.Response.WriteHeader(http.StatusUnauthorized) -} - // BasicAuth enables the http basic authentication feature with given passport and password // and asks client for authentication. It returns true if authentication success, else returns // false if failure. @@ -62,5 +50,16 @@ func (r *Request) BasicAuth(user, pass string, tips ...string) bool { r.Response.WriteStatus(http.StatusForbidden) return false } - return false +} + +// setBasicAuth sets the http basic authentication tips. +func (r *Request) setBasicAuth(tips ...string) { + realm := "" + if len(tips) > 0 && tips[0] != "" { + realm = tips[0] + } else { + realm = "Need Login" + } + r.Response.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm)) + r.Response.WriteHeader(http.StatusUnauthorized) } diff --git a/net/ghttp/ghttp_request_hook.go b/net/ghttp/ghttp_request_hook.go deleted file mode 100644 index 138a802a5..000000000 --- a/net/ghttp/ghttp_request_hook.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved. -// -// This Source Code Form is subject to the terms of the MIT License. -// If a copy of the MIT was not distributed with this file, -// You can obtain one at https://github.com/gogf/gf. - -package ghttp - -// 获得当前请求,指定类型的的钩子函数列表 -func (r *Request) getHookHandlers(hook string) []*handlerParsedItem { - if !r.hasHookHandler { - return nil - } - parsedItems := make([]*handlerParsedItem, 0, 4) - for _, v := range r.handlers { - if v.handler.hookName != hook { - continue - } - item := v - parsedItems = append(parsedItems, item) - } - return parsedItems -} diff --git a/net/ghttp/ghttp_request_middleware.go b/net/ghttp/ghttp_request_middleware.go index 050c16f0e..495a6431f 100644 --- a/net/ghttp/ghttp_request_middleware.go +++ b/net/ghttp/ghttp_request_middleware.go @@ -24,6 +24,7 @@ type Middleware struct { } // Next calls the next workflow handler. +// It's an important function controlling the workflow of the server request execution. func (m *Middleware) Next() { var item *handlerParsedItem var loop = true diff --git a/net/ghttp/ghttp_request_param.go b/net/ghttp/ghttp_request_param.go index 6b66a820b..2a0872038 100644 --- a/net/ghttp/ghttp_request_param.go +++ b/net/ghttp/ghttp_request_param.go @@ -32,7 +32,7 @@ var ( // to given struct, and then calls gvalid.CheckStruct validating the struct according // to the validation tag of the struct. // -// See GetStruct, gvalid.CheckStruct. +// See r.GetStruct, gvalid.CheckStruct. func (r *Request) Parse(pointer interface{}) error { if err := r.GetStruct(pointer); err != nil { return err @@ -45,7 +45,7 @@ func (r *Request) Parse(pointer interface{}) error { // Get is alias of GetRequest, which is one of the most commonly used functions for // retrieving parameter. -// See GetRequest. +// See r.GetRequest. func (r *Request) Get(key string, def ...interface{}) interface{} { return r.GetRequest(key, def...) } @@ -58,17 +58,20 @@ func (r *Request) GetVar(key string, def ...interface{}) *gvar.Var { // GetRaw is alias of GetBody. // See GetBody. +// Deprecated. func (r *Request) GetRaw() []byte { return r.GetBody() } // GetRawString is alias of GetBodyString. // See GetBodyString. +// Deprecated. func (r *Request) GetRawString() string { return r.GetBodyString() } -// GetRaw retrieves and returns request body content as bytes. +// GetBody retrieves and returns request body content as bytes. +// It can be called multiple times retrieving the same body content. func (r *Request) GetBody() []byte { if r.bodyContent == nil { r.bodyContent, _ = ioutil.ReadAll(r.Body) @@ -77,92 +80,127 @@ func (r *Request) GetBody() []byte { return r.bodyContent } -// GetRawString retrieves and returns request body content as string. +// GetBodyString retrieves and returns request body content as string. +// It can be called multiple times retrieving the same body content. func (r *Request) GetBodyString() string { - return gconv.UnsafeBytesToStr(r.GetRaw()) + return gconv.UnsafeBytesToStr(r.GetBody()) } // GetJson parses current request content as JSON format, and returns the JSON object. // Note that the request content is read from request BODY, not from any field of FORM. func (r *Request) GetJson() (*gjson.Json, error) { - return gjson.LoadJson(r.GetRaw()) + return gjson.LoadJson(r.GetBody()) } +// GetString is an alias and convenient function for GetRequestString. +// See GetRequestString. func (r *Request) GetString(key string, def ...interface{}) string { return r.GetRequestString(key, def...) } +// GetBool is an alias and convenient function for GetRequestBool. +// See GetRequestBool. func (r *Request) GetBool(key string, def ...interface{}) bool { return r.GetRequestBool(key, def...) } +// GetInt is an alias and convenient function for GetRequestInt. +// See GetRequestInt. func (r *Request) GetInt(key string, def ...interface{}) int { return r.GetRequestInt(key, def...) } +// GetInt32 is an alias and convenient function for GetRequestInt32. +// See GetRequestInt32. func (r *Request) GetInt32(key string, def ...interface{}) int32 { return r.GetRequestInt32(key, def...) } +// GetInt64 is an alias and convenient function for GetRequestInt64. +// See GetRequestInt64. func (r *Request) GetInt64(key string, def ...interface{}) int64 { return r.GetRequestInt64(key, def...) } +// GetInts is an alias and convenient function for GetRequestInts. +// See GetRequestInts. func (r *Request) GetInts(key string, def ...interface{}) []int { return r.GetRequestInts(key, def...) } +// GetUint is an alias and convenient function for GetRequestUint. +// See GetRequestUint. func (r *Request) GetUint(key string, def ...interface{}) uint { return r.GetRequestUint(key, def...) } +// GetUint32 is an alias and convenient function for GetRequestUint32. +// See GetRequestUint32. func (r *Request) GetUint32(key string, def ...interface{}) uint32 { return r.GetRequestUint32(key, def...) } +// GetUint64 is an alias and convenient function for GetRequestUint64. +// See GetRequestUint64. func (r *Request) GetUint64(key string, def ...interface{}) uint64 { return r.GetRequestUint64(key, def...) } +// GetFloat32 is an alias and convenient function for GetRequestFloat32. +// See GetRequestFloat32. func (r *Request) GetFloat32(key string, def ...interface{}) float32 { return r.GetRequestFloat32(key, def...) } +// GetFloat64 is an alias and convenient function for GetRequestFloat64. +// See GetRequestFloat64. func (r *Request) GetFloat64(key string, def ...interface{}) float64 { return r.GetRequestFloat64(key, def...) } +// GetFloats is an alias and convenient function for GetRequestFloats. +// See GetRequestFloats. func (r *Request) GetFloats(key string, def ...interface{}) []float64 { return r.GetRequestFloats(key, def...) } +// GetArray is an alias and convenient function for GetRequestArray. +// See GetRequestArray. func (r *Request) GetArray(key string, def ...interface{}) []string { return r.GetRequestArray(key, def...) } +// GetStrings is an alias and convenient function for GetRequestStrings. +// See GetRequestStrings. func (r *Request) GetStrings(key string, def ...interface{}) []string { return r.GetRequestStrings(key, def...) } +// GetInterfaces is an alias and convenient function for GetRequestInterfaces. +// See GetRequestInterfaces. func (r *Request) GetInterfaces(key string, def ...interface{}) []interface{} { return r.GetRequestInterfaces(key, def...) } +// GetMap is an alias and convenient function for GetRequestMap. +// See GetRequestMap. func (r *Request) GetMap(def ...map[string]interface{}) map[string]interface{} { return r.GetRequestMap(def...) } +// GetMapStrStr is an alias and convenient function for GetRequestMapStrStr. +// See GetRequestMapStrStr. func (r *Request) GetMapStrStr(def ...map[string]interface{}) map[string]string { return r.GetRequestMapStrStr(def...) } -// GetStruct is alias of GetRequestToStruct. -// See GetRequestToStruct. +// GetStruct is an alias and convenient function for GetRequestStruct. +// See GetRequestStruct. func (r *Request) GetStruct(pointer interface{}, mapping ...map[string]string) error { return r.GetRequestStruct(pointer, mapping...) } -// GetToStruct is alias of GetRequestToStruct. +// GetToStruct is an alias and convenient function for GetRequestStruct. // See GetRequestToStruct. // Deprecated. func (r *Request) GetToStruct(pointer interface{}, mapping ...map[string]string) error { @@ -184,7 +222,7 @@ func (r *Request) parseQuery() { } } -// ParseRaw parses the request raw data into r.rawMap. +// parseBody parses the request raw data into r.rawMap. // Note that it also supports JSON data from client request. func (r *Request) parseBody() { if r.parsedBody { @@ -289,7 +327,7 @@ func (r *Request) GetMultipartForm() *multipart.Form { return r.MultipartForm } -// GetMultipartFiles returns the post files array. +// GetMultipartFiles parses and returns the post files array. // Note that the request form should be type of multipart. func (r *Request) GetMultipartFiles(name string) []*multipart.FileHeader { form := r.GetMultipartForm() diff --git a/net/ghttp/ghttp_request_param_ctx.go b/net/ghttp/ghttp_request_param_ctx.go index cdcfe78b9..91910c45d 100644 --- a/net/ghttp/ghttp_request_param_ctx.go +++ b/net/ghttp/ghttp_request_param_ctx.go @@ -11,8 +11,9 @@ import ( "github.com/gogf/gf/container/gvar" ) -// Context retrieves and returns the request's context. +// Context is alias for function GetCtx. // This function overwrites the http.Request.Context function. +// See GetCtx. func (r *Request) Context() context.Context { if r.context == nil { r.context = r.Request.Context() @@ -20,13 +21,14 @@ func (r *Request) Context() context.Context { return r.context } -// GetCtx is alias for function Context. -// See Context. +// GetCtx retrieves and returns the request's context. func (r *Request) GetCtx() context.Context { return r.Context() } // GetCtxVar retrieves and returns a Var with given key name. +// The optional parameter specifies the default value of the Var if given +// does not exist in the context. func (r *Request) GetCtxVar(key interface{}, def ...interface{}) *gvar.Var { value := r.Context().Value(key) if value == nil && len(def) > 0 { diff --git a/net/ghttp/ghttp_request_param_file.go b/net/ghttp/ghttp_request_param_file.go index b808d6cf1..0086b9e3e 100644 --- a/net/ghttp/ghttp_request_param_file.go +++ b/net/ghttp/ghttp_request_param_file.go @@ -30,10 +30,7 @@ type UploadFiles []*UploadFile // // The parameter should be a directory path or it returns error. // -// The parameter specifies whether randomly renames the file name, which -// make sense if the is a directory. -// -// Note that it will overwrite the target file if there's already a same name file exist. +// Note that it will OVERWRITE the target file if there's already a same name file exist. func (f *UploadFile) Save(dirPath string, randomlyRename ...bool) (filename string, err error) { if f == nil { return "", errors.New("file is empty, maybe you retrieve it from invalid field name or form enctype") @@ -93,6 +90,8 @@ func (fs UploadFiles) Save(dirPath string, randomlyRename ...bool) (filenames [] // This function is used for retrieving single uploading file object, which is // uploaded using multipart form content type. // +// It returns nil if retrieving failed or no form file with given name posted. +// // Note that the is the file field name of the multipart form from client. func (r *Request) GetUploadFile(name string) *UploadFile { uploadFiles := r.GetUploadFiles(name) @@ -106,6 +105,8 @@ func (r *Request) GetUploadFile(name string) *UploadFile { // This function is used for retrieving multiple uploading file objects, which are // uploaded using multipart form content type. // +// It returns nil if retrieving failed or no form file with given name posted. +// // Note that the is the file field name of the multipart form from client. func (r *Request) GetUploadFiles(name string) UploadFiles { multipartFiles := r.GetMultipartFiles(name) diff --git a/net/ghttp/ghttp_request_param_form.go b/net/ghttp/ghttp_request_param_form.go index 7f9eefc7c..af5ac0d50 100644 --- a/net/ghttp/ghttp_request_param_form.go +++ b/net/ghttp/ghttp_request_param_form.go @@ -22,8 +22,7 @@ func (r *Request) SetForm(key string, value interface{}) { } // GetForm retrieves and returns parameter from form. -// It returns if does not exist in the form. -// It returns nil if is not passed. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetForm(key string, def ...interface{}) interface{} { r.parseForm() if len(r.formMap) > 0 { @@ -37,66 +36,98 @@ func (r *Request) GetForm(key string, def ...interface{}) interface{} { return nil } +// GetFormVar retrieves and returns parameter from form as Var. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormVar(key string, def ...interface{}) *gvar.Var { return gvar.New(r.GetForm(key, def...)) } +// GetFormString retrieves and returns parameter from form as string. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormString(key string, def ...interface{}) string { return r.GetFormVar(key, def...).String() } +// GetFormBool retrieves and returns parameter from form as bool. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormBool(key string, def ...interface{}) bool { return r.GetFormVar(key, def...).Bool() } +// GetFormInt retrieves and returns parameter from form as int. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormInt(key string, def ...interface{}) int { return r.GetFormVar(key, def...).Int() } +// GetFormInt32 retrieves and returns parameter from form as int32. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormInt32(key string, def ...interface{}) int32 { return r.GetFormVar(key, def...).Int32() } +// GetFormInt64 retrieves and returns parameter from form as int64. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormInt64(key string, def ...interface{}) int64 { return r.GetFormVar(key, def...).Int64() } +// GetFormInts retrieves and returns parameter from form as []int. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormInts(key string, def ...interface{}) []int { return r.GetFormVar(key, def...).Ints() } +// GetFormUint retrieves and returns parameter from form as uint. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormUint(key string, def ...interface{}) uint { return r.GetFormVar(key, def...).Uint() } +// GetFormUint32 retrieves and returns parameter from form as uint32. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormUint32(key string, def ...interface{}) uint32 { return r.GetFormVar(key, def...).Uint32() } +// GetFormUint64 retrieves and returns parameter from form as uint64. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormUint64(key string, def ...interface{}) uint64 { return r.GetFormVar(key, def...).Uint64() } +// GetFormFloat32 retrieves and returns parameter from form as float32. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormFloat32(key string, def ...interface{}) float32 { return r.GetFormVar(key, def...).Float32() } +// GetFormFloat64 retrieves and returns parameter from form as float64. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormFloat64(key string, def ...interface{}) float64 { return r.GetFormVar(key, def...).Float64() } +// GetFormFloats retrieves and returns parameter from form as []float64. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormFloats(key string, def ...interface{}) []float64 { return r.GetFormVar(key, def...).Floats() } +// GetFormArray retrieves and returns parameter from form as []string. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormArray(key string, def ...interface{}) []string { return r.GetFormVar(key, def...).Strings() } +// GetFormStrings retrieves and returns parameter from form as []string. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormStrings(key string, def ...interface{}) []string { return r.GetFormVar(key, def...).Strings() } +// GetFormInterfaces retrieves and returns parameter from form as []interface{}. +// It returns if does not exist in the form and is given, or else it returns nil. func (r *Request) GetFormInterfaces(key string, def ...interface{}) []interface{} { return r.GetFormVar(key, def...).Interfaces() } diff --git a/net/ghttp/ghttp_request_param_post.go b/net/ghttp/ghttp_request_param_post.go index 913b95454..2ed8aaa37 100644 --- a/net/ghttp/ghttp_request_param_post.go +++ b/net/ghttp/ghttp_request_param_post.go @@ -16,8 +16,8 @@ import ( // It returns if does not exist in neither form nor body. // It returns nil if is not passed. // -// Note that if there're multiple parameters with the same name, the parameters are retrieved and overwrote -// in order of priority: form > body. +// Note that if there're multiple parameters with the same name, the parameters are retrieved +// and overwrote in order of priority: form > body. // // Deprecated. func (r *Request) GetPost(key string, def ...interface{}) interface{} { diff --git a/net/ghttp/ghttp_request_param_query.go b/net/ghttp/ghttp_request_param_query.go index 26d80511f..6c8afae38 100644 --- a/net/ghttp/ghttp_request_param_query.go +++ b/net/ghttp/ghttp_request_param_query.go @@ -23,11 +23,11 @@ func (r *Request) SetQuery(key string, value interface{}) { } // GetQuery retrieves and returns parameter with given name from query string -// and request body. It returns if does not exist in the query. It returns nil -// if is not passed. +// and request body. It returns if does not exist in the query and is given, +// or else it returns nil. // -// Note that if there're multiple parameters with the same name, the parameters are retrieved and overwrote -// in order of priority: query > body. +// Note that if there're multiple parameters with the same name, the parameters are retrieved +// and overwrote in order of priority: query > body. func (r *Request) GetQuery(key string, def ...interface{}) interface{} { r.parseQuery() if len(r.queryMap) > 0 { diff --git a/net/ghttp/ghttp_response_view.go b/net/ghttp/ghttp_response_view.go index 2266e18ce..2bdde6473 100644 --- a/net/ghttp/ghttp_response_view.go +++ b/net/ghttp/ghttp_response_view.go @@ -11,6 +11,7 @@ import ( "github.com/gogf/gf/os/gcfg" "github.com/gogf/gf/os/gview" "github.com/gogf/gf/util/gmode" + "github.com/gogf/gf/util/gutil" ) // WriteTpl parses and responses given template file. @@ -74,27 +75,19 @@ func (r *Response) ParseTplContent(content string, params ...gview.Params) (stri // buildInVars merges build-in variables into and returns the new template variables. func (r *Response) buildInVars(params ...map[string]interface{}) map[string]interface{} { - var vars map[string]interface{} - if len(params) > 0 && params[0] != nil { - vars = params[0] - } else { - vars = make(map[string]interface{}) - } + m := gutil.MapMergeCopy(params...) // Retrieve custom template variables from request object. - if len(r.Request.viewParams) > 0 { - for k, v := range r.Request.viewParams { - vars[k] = v - } - } + gutil.MapMerge(m, r.Request.viewParams, map[string]interface{}{ + "Form": r.Request.GetFormMap(), + "Query": r.Request.GetQueryMap(), + "Request": r.Request.GetMap(), + "Cookie": r.Request.Cookie.Map(), + "Session": r.Request.Session.Map(), + }) // Note that it should assign no Config variable to template // if there's no configuration file. if c := gcfg.Instance(); c.Available() { - vars["Config"] = c.GetMap(".") + m["Config"] = c.GetMap(".") } - vars["Form"] = r.Request.GetFormMap() - vars["Query"] = r.Request.GetQueryMap() - vars["Request"] = r.Request.GetMap() - vars["Cookie"] = r.Request.Cookie.Map() - vars["Session"] = r.Request.Session.Map() - return vars + return m } diff --git a/net/ghttp/ghttp_server.go b/net/ghttp/ghttp_server.go index deb6a8d19..c989942e0 100644 --- a/net/ghttp/ghttp_server.go +++ b/net/ghttp/ghttp_server.go @@ -306,7 +306,7 @@ func (s *Server) Start() error { // Default HTTP handler. if s.config.Handler == nil { - s.config.Handler = http.HandlerFunc(s.defaultHandler) + s.config.Handler = s } // Install external plugins. diff --git a/net/ghttp/ghttp_server_admin_unix.go b/net/ghttp/ghttp_server_admin_unix.go index 748849193..7cca3a84e 100644 --- a/net/ghttp/ghttp_server_admin_unix.go +++ b/net/ghttp/ghttp_server_admin_unix.go @@ -9,6 +9,7 @@ package ghttp import ( + "github.com/gogf/gf/internal/intlog" "os" "os/signal" "syscall" @@ -26,14 +27,16 @@ func handleProcessSignal() { syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM, + syscall.SIGABRT, syscall.SIGUSR1, syscall.SIGUSR2, ) for { sig = <-procSignalChan + intlog.Printf(`signal received: %s`, sig.String()) switch sig { // 进程终止,停止所有子进程运行 - case syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM: + case syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM, syscall.SIGABRT: shutdownWebServers(sig.String()) return diff --git a/net/ghttp/ghttp_server_config.go b/net/ghttp/ghttp_server_config.go index 26e4660fc..773d93063 100644 --- a/net/ghttp/ghttp_server_config.go +++ b/net/ghttp/ghttp_server_config.go @@ -118,7 +118,8 @@ func Config() ServerConfig { } } -// ConfigFromMap creates and returns a ServerConfig object with given map. +// ConfigFromMap creates and returns a ServerConfig object with given map and +// default configuration object. func ConfigFromMap(m map[string]interface{}) (ServerConfig, error) { config := Config() if err := gconv.Struct(m, &config); err != nil { @@ -127,9 +128,13 @@ func ConfigFromMap(m map[string]interface{}) (ServerConfig, error) { return config, nil } -// Handler returns the request handler of the server. -func (s *Server) Handler() http.Handler { - return s.config.Handler +// SetConfigWithMap sets the configuration for the server using map. +func (s *Server) SetConfigWithMap(m map[string]interface{}) error { + // Update the current configuration object. + if err := gconv.Struct(m, &s.config); err != nil { + return err + } + return s.SetConfig(s.config) } // SetConfig sets the configuration for the server. @@ -156,15 +161,6 @@ func (s *Server) SetConfig(c ServerConfig) error { return nil } -// SetConfigWithMap sets the configuration for the server using map. -func (s *Server) SetConfigWithMap(m map[string]interface{}) error { - config, err := ConfigFromMap(m) - if err != nil { - return err - } - return s.SetConfig(config) -} - // SetAddr sets the listening address for the server. // The address is like ':80', '0.0.0.0:80', '127.0.0.1:80', '180.18.99.10:80', etc. func (s *Server) SetAddr(address string) { @@ -278,3 +274,11 @@ func (s *Server) SetView(view *gview.View) { func (s *Server) GetName() string { return s.name } + +// Handler returns the request handler of the server. +func (s *Server) Handler() http.Handler { + if s.config.Handler == nil { + return s + } + return s.config.Handler +} diff --git a/net/ghttp/ghttp_server_handler.go b/net/ghttp/ghttp_server_handler.go index 5974d0d8f..6a788fe41 100644 --- a/net/ghttp/ghttp_server_handler.go +++ b/net/ghttp/ghttp_server_handler.go @@ -25,15 +25,7 @@ import ( ) // 默认HTTP Server处理入口,http包底层默认使用了gorutine异步处理请求,所以这里不再异步执行 -func (s *Server) defaultHandler(w http.ResponseWriter, r *http.Request) { - s.handleRequest(w, r) -} - -// 执行处理HTTP请求, -// 首先,查找是否有对应域名的处理接口配置; -// 其次,如果没有对应的自定义处理接口配置,那么走默认的域名处理接口配置; -// 最后,如果以上都没有找到处理接口,那么进行文件处理; -func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 重写规则判断 if len(s.config.Rewrites) > 0 { if rewrite, ok := s.config.Rewrites[r.URL.Path]; ok { diff --git a/net/ghttp/ghttp_server_router_hook.go b/net/ghttp/ghttp_server_router_hook.go index e918c6e30..c073a556d 100644 --- a/net/ghttp/ghttp_server_router_hook.go +++ b/net/ghttp/ghttp_server_router_hook.go @@ -63,6 +63,22 @@ func (s *Server) callHookHandler(hook string, r *Request) { } } +// 获得当前请求,指定类型的的钩子函数列表 +func (r *Request) getHookHandlers(hook string) []*handlerParsedItem { + if !r.hasHookHandler { + return nil + } + parsedItems := make([]*handlerParsedItem, 0, 4) + for _, v := range r.handlers { + if v.handler.hookName != hook { + continue + } + item := v + parsedItems = append(parsedItems, item) + } + return parsedItems +} + // 友好地调用方法 func (s *Server) niceCallHookHandler(f HandlerFunc, r *Request) (err interface{}) { defer func() { diff --git a/net/ghttp/ghttp_server_service_controller.go b/net/ghttp/ghttp_server_service_controller.go index e86f14cad..66fbea7d7 100644 --- a/net/ghttp/ghttp_server_service_controller.go +++ b/net/ghttp/ghttp_server_service_controller.go @@ -87,7 +87,7 @@ func (s *Server) doBindController( pkgPath, ctlName, methodName, v.Method(i).Type().String()) } else { // 否则只是Debug提示 - s.Logger().Debugf(`ignore route method: %s.%s.%s defined as "%s", no match "func()"`, + s.Logger().Debugf(`ignore route method: %s.%s.%s defined as "%s", no match "func()" for controller registry`, pkgPath, ctlName, methodName, v.Method(i).Type().String()) } continue diff --git a/net/ghttp/ghttp_server_service_object.go b/net/ghttp/ghttp_server_service_object.go index d6a167cdb..ab373e128 100644 --- a/net/ghttp/ghttp_server_service_object.go +++ b/net/ghttp/ghttp_server_service_object.go @@ -96,7 +96,7 @@ func (s *Server) doBindObject( } else { // 否则只是Debug提示 s.Logger().Debugf( - `ignore route method: %s.%s.%s defined as "%s", no match "func(*ghttp.Request)"`, + `ignore route method: %s.%s.%s defined as "%s", no match "func(*ghttp.Request)" for object registry`, pkgPath, objName, methodName, v.Method(i).Type().String(), ) } diff --git a/net/ghttp/ghttp_unit_client_test.go b/net/ghttp/ghttp_unit_client_test.go index f66c62a54..98ce0b53d 100644 --- a/net/ghttp/ghttp_unit_client_test.go +++ b/net/ghttp/ghttp_unit_client_test.go @@ -87,6 +87,26 @@ func Test_Client_Cookie(t *testing.T) { }) } +func Test_Client_MapParam(t *testing.T) { + p, _ := ports.PopRand() + s := g.Server(p) + s.BindHandler("/map", func(r *ghttp.Request) { + r.Response.Write(r.Get("test")) + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + gtest.C(t, func(t *gtest.T) { + c := ghttp.NewClient() + c.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + + t.Assert(c.GetContent("/map", g.Map{"test": "1234567890"}), "1234567890") + }) +} + func Test_Client_Cookies(t *testing.T) { p, _ := ports.PopRand() s := g.Server(p) diff --git a/net/ghttp/ghttp_unit_https_test.go b/net/ghttp/ghttp_unit_https_test.go new file mode 100644 index 000000000..511c43646 --- /dev/null +++ b/net/ghttp/ghttp_unit_https_test.go @@ -0,0 +1,94 @@ +// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package ghttp_test + +import ( + "fmt" + "github.com/gogf/gf/debug/gdebug" + "github.com/gogf/gf/os/gtime" + "github.com/gogf/gf/text/gstr" + "testing" + "time" + + "github.com/gogf/gf/frame/g" + "github.com/gogf/gf/net/ghttp" + "github.com/gogf/gf/test/gtest" +) + +func Test_HTTPS_Basic(t *testing.T) { + p, _ := ports.PopRand() + s := g.Server(p) + s.Group("/", func(group *ghttp.RouterGroup) { + group.GET("/test", func(r *ghttp.Request) { + r.Response.Write("test") + }) + }) + s.EnableHTTPS( + gdebug.TestDataPath("https", "server.crt"), + gdebug.TestDataPath("https", "server.key"), + ) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + + // HTTP + gtest.C(t, func(t *gtest.T) { + c := g.Client() + c.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + t.AssertIN(gstr.Trim(c.GetContent("/")), g.Slice{"", "Client sent an HTTP request to an HTTPS server."}) + t.AssertIN(gstr.Trim(c.GetContent("/test")), g.Slice{"", "Client sent an HTTP request to an HTTPS server."}) + }) + // HTTPS + gtest.C(t, func(t *gtest.T) { + c := g.Client() + c.SetPrefix(fmt.Sprintf("https://127.0.0.1:%d", p)) + t.Assert(c.GetContent("/"), "Not Found") + t.Assert(c.GetContent("/test"), "test") + }) +} + +func Test_HTTPS_HTTP_Basic(t *testing.T) { + var ( + portHttp, _ = ports.PopRand() + portHttps, _ = ports.PopRand() + ) + s := g.Server(gtime.TimestampNanoStr()) + s.Group("/", func(group *ghttp.RouterGroup) { + group.GET("/test", func(r *ghttp.Request) { + r.Response.Write("test") + }) + }) + s.EnableHTTPS( + gdebug.TestDataPath("https", "server.crt"), + gdebug.TestDataPath("https", "server.key"), + ) + s.SetPort(portHttp) + s.SetHTTPSPort(portHttps) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + + // HTTP + gtest.C(t, func(t *gtest.T) { + c := g.Client() + c.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", portHttp)) + t.Assert(c.GetContent("/"), "Not Found") + t.Assert(c.GetContent("/test"), "test") + }) + // HTTPS + gtest.C(t, func(t *gtest.T) { + c := g.Client() + c.SetPrefix(fmt.Sprintf("https://127.0.0.1:%d", portHttps)) + t.Assert(c.GetContent("/"), "Not Found") + t.Assert(c.GetContent("/test"), "test") + }) +} diff --git a/net/ghttp/ghttp_unit_param_test.go b/net/ghttp/ghttp_unit_param_test.go index 8ffc3a2df..6e69e06ba 100644 --- a/net/ghttp/ghttp_unit_param_test.go +++ b/net/ghttp/ghttp_unit_param_test.go @@ -469,3 +469,30 @@ func Test_Params_Priority(t *testing.T) { t.Assert(client.GetContent("/request-map?a=1&b=2&c=3", "a=100&b=200&c=300"), `{"a":"100","b":"200"}`) }) } + +func Test_Params_GetRequestMap(t *testing.T) { + p, _ := ports.PopRand() + s := g.Server(p) + s.BindHandler("/map", func(r *ghttp.Request) { + r.Response.Write(r.GetRequestMap()) + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + gtest.C(t, func(t *gtest.T) { + prefix := fmt.Sprintf("http://127.0.0.1:%d", p) + client := ghttp.NewClient() + client.SetPrefix(prefix) + + t.Assert( + client.PostContent( + "/map", + "time_end2020-04-18 16:11:58&returnmsg=Success&attach=", + ), + `{"attach":"","returnmsg":"Success"}`, + ) + }) +} diff --git a/net/ghttp/testdata/https/server.crt b/net/ghttp/testdata/https/server.crt new file mode 100644 index 000000000..4d254ea21 --- /dev/null +++ b/net/ghttp/testdata/https/server.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIIDzzCCAregAwIBAgIJAJYpWLkC2lEXMA0GCSqGSIb3DQEBCwUAMH4xCzAJBgNV +BAYTAkNIMRAwDgYDVQQIDAdTaUNodWFuMRAwDgYDVQQHDAdDaGVuZ2R1MRAwDgYD +VQQKDAdKb2huLmNuMQwwCgYDVQQLDANEZXYxDTALBgNVBAMMBEpvaG4xHDAaBgkq +hkiG9w0BCQEWDWpvaG5Aam9obmcuY24wHhcNMTgwNDIzMTMyNjA4WhcNMTkwNDIz +MTMyNjA4WjB+MQswCQYDVQQGEwJDSDEQMA4GA1UECAwHU2lDaHVhbjEQMA4GA1UE +BwwHQ2hlbmdkdTEQMA4GA1UECgwHSm9obi5jbjEMMAoGA1UECwwDRGV2MQ0wCwYD +VQQDDARKb2huMRwwGgYJKoZIhvcNAQkBFg1qb2huQGpvaG5nLmNuMIIBIjANBgkq +hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA6cngPUrDgBhiNfn+7MMHPzOoO+oVavlS +F/tCPyKINhsePGqHkR4ILkHu9IuoBiPYR1JgrMz5goQ6mkrvq/LMfo4dCuA29ZRg ++Vps/RimBpiz+RU3FDGyqc7d+fk74dElGk6NhJJ6XO3qHqgIg1yc6d5DiZfEnlMz +CRKoZ2dQ+98o5LwES+XJBVWfZiC1pEfyppIh+ci7fXajxkRPJ+5qYWaS5cIHmJIN +DIp5Ypszg1cPs0gIr5EgPeGwZzOeqMMzsbLLE8kjSw59Pt1/+Jkdm1e0GhO18qIO +NcqaHeGaTUVjzX9XwRj8cw+q3kRoqD5aWMjUzAg9+IDrMqvo6VZQ5QIDAQABo1Aw +TjAdBgNVHQ4EFgQU1/tUQpOK0xEwLLlYDiNrckqPlDowHwYDVR0jBBgwFoAU1/tU +QpOK0xEwLLlYDiNrckqPlDowDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOC +AQEA5MbG2xU3s/GDU1MV4f0wKhWCNhXfrLaYSwNYGT/eb8ZG2iHSTO0dvl0+pjO2 +EK63PDMvMhUtL1Zlyvl+OqssYcDhVfDzdFoYX6TZNbYxFwSzcx78mO6boAADk9ro +GEQWN+VHsl984SzBRZRJbtNbiw5iVuPruofeKHrrk4dLMiCsStyUaz9lUZxjo2Fi +vVJOY+mRNOBqz1HgU2+RilFTl04zWadCWPJMugQSgJcUPgxRXQ96PkC8uYevEnmR +2DUReSRULIOYEjHw0DZ6yGlqUkJcUGge3XAQEx3LlCpJasOC8Xpsh5i6WBnDPbMh +kPBjRRTooSrJOQJC5v3QW+0Kgw== +-----END CERTIFICATE----- diff --git a/net/ghttp/testdata/https/server.key b/net/ghttp/testdata/https/server.key new file mode 100644 index 000000000..e0f909629 --- /dev/null +++ b/net/ghttp/testdata/https/server.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA6cngPUrDgBhiNfn+7MMHPzOoO+oVavlSF/tCPyKINhsePGqH +kR4ILkHu9IuoBiPYR1JgrMz5goQ6mkrvq/LMfo4dCuA29ZRg+Vps/RimBpiz+RU3 +FDGyqc7d+fk74dElGk6NhJJ6XO3qHqgIg1yc6d5DiZfEnlMzCRKoZ2dQ+98o5LwE +S+XJBVWfZiC1pEfyppIh+ci7fXajxkRPJ+5qYWaS5cIHmJINDIp5Ypszg1cPs0gI +r5EgPeGwZzOeqMMzsbLLE8kjSw59Pt1/+Jkdm1e0GhO18qIONcqaHeGaTUVjzX9X +wRj8cw+q3kRoqD5aWMjUzAg9+IDrMqvo6VZQ5QIDAQABAoIBAHF7cMHPvL49F88j +nr7GnIntRUhwBB19EIBbknibBotc9nxVKaEjds0dbCSAdfslAyL7tbmrdaIJFXk3 +zsckgGceDLLuyz7B26CuaCEjCdRB43qQ9b9zsEoFBHMGrC6dGul+H+uuPn9FbVOc +NSWumuxa22W6qdJAiJFq4RvwZrsbVnYs5V29Y4Y20IlVUj3siJpAny//UUHequW9 +A/U7RvVssDsEEbbKvCpfcS7STNJKU7GlgV5l5hMKN2xLs1bVG5OKiZN82Zh9r7e1 +m2irxu/ehu6rENxZN0gsfPE4vqoQpbRMNAJlCfq9a3k0PH0TOy5oOVJXPGTIDQab +E3PeAwECgYEA9wh4+bPgMuO04hsAqsoO0DJ9Cwa+BzoDPYOvENobDzmcMErSDLKb +ekl1ej+fBTHRHVaBkuOf/9neLjhjMLad1B+I5gLksqwoMh87odDRCCpkO/B20ln8 +IN6RFiMiNjOaZqjPCCUobgzjbaIz3I69lCQQnMNPwjllSgZs9Lh/PjUCgYEA8kZU +hhUN6ctHIo8ocnmqa4AUPbt2l4qOoBGHCMmhjthyft6g8y6cQlACVJzbco37MhjY +uCOhhOClyUS1tyfds3NXdzAxXPl8SwQJGvl3zqkDQG7/GhCh6AzvHhZR8u7UaweC +kVnAG87Ck6Qqo5ZNbjhMIUm0ujm2cdVd3vyV3fECgYEAmJSMHDck8GnCzLE+/T5m +XeQBZfEZKF+FptYSKId+lS3RMebUzHD5JVQAEqz/LHczoTpQOAkORzorSEMdyPXS +kDWWGfOJjG5XOXYfH/hZVADS/k6tJYnc9/RgitrSg8XlxSjZDz/cM/UT+CBqhf1I +TRrlg94DAoTu8gT8AT9/oE0CgYB5CSPO/JO/2jtGi6iUUC4QmKMEGDRuDt2kID2K +6ViaCY5hzY0xEHcmNdyEMvz7JO16oKkcjUhzHtwUSgxSXUtIDHaE6AGxRj6PJ4v4 ++uqcxxkFxq4Rcn/Acz2+lT4JlMFwWwci4Gi2O7w/kENxCHTUfLGj67OrWYvJIORN +s3iXsQKBgD1I+v+simBvKZKmozzv99EgGfxkRxmrUQsclg1V8a1VTNfE5X9oNaE5 +kjp+dTnwbtmFl3SHVdFUzX/L6FvQIQ9FIwWI2bsszPm4rw8FBeOvH+8lXwVhCwPs +y9him/PhdjBPX0zydDI+h+fmrxH/XbmryZcq1rNmEtFRHBsUs5jg +-----END RSA PRIVATE KEY----- diff --git a/net/gsmtp/gsmtp.go b/net/gsmtp/gsmtp.go index 91e28e386..0492b465f 100644 --- a/net/gsmtp/gsmtp.go +++ b/net/gsmtp/gsmtp.go @@ -4,7 +4,7 @@ // If a copy of the MIT was not distributed with this file, // You can obtain one at https://github.com/gogf/gf. -// Package gsmtp provides a SMTP client to access remote mail server. +// Package gsmtp provides a simple SMTP client to access remote mail server. // // Eg: // s := smtp.New("smtp.exmail.qq.com:25", "notify@a.com", "password") @@ -14,6 +14,7 @@ package gsmtp import ( "encoding/base64" "fmt" + "github.com/gogf/gf/util/gconv" "net/smtp" "strings" ) @@ -34,37 +35,45 @@ func New(address, username, password string) *SMTP { } } +var ( + // contentEncoding is the BASE64 encoding object for mail content. + contentEncoding = base64.NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") +) + // SendMail connects to the server at addr, switches to TLS if // possible, authenticates with the optional mechanism a if possible, -// and then sends an email from address from, to addresses to, with +// and then sends an email from address , to addresses , with // message msg. +// +// The parameter specifies the content type of the mail, eg: html. func (s *SMTP) SendMail(from, tos, subject, body string, contentType ...string) error { - server := "" - address := "" - - hp := strings.Split(s.Address, ":") - if (s.Address == "") || (len(hp) > 2) { - return fmt.Errorf("Server address is either empty or incorrect: %s", s.Address) + var ( + server = "" + address = "" + hp = strings.Split(s.Address, ":") + ) + if s.Address == "" || len(hp) > 2 { + return fmt.Errorf("server address is either empty or incorrect: %s", s.Address) } else if len(hp) == 1 { server = s.Address address = server + ":25" } else if len(hp) == 2 { if (hp[0] == "") || (hp[1] == "") { - return fmt.Errorf("Server address is either empty or incorrect: %s", s.Address) + return fmt.Errorf("server address is either empty or incorrect: %s", s.Address) } server = hp[0] address = s.Address } - - tosArr := []string{} - arr := strings.Split(tos, ";") + var ( + tosArr []string + arr = strings.Split(tos, ";") + ) for _, to := range arr { // TODO: replace with regex if strings.Contains(to, "@") { tosArr = append(tosArr, to) } } - if len(tosArr) == 0 { return fmt.Errorf("tos if invalid: %s", tos) } @@ -73,28 +82,27 @@ func (s *SMTP) SendMail(from, tos, subject, body string, contentType ...string) return fmt.Errorf("from is invalid: %s", from) } - b64 := base64.NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") - - header := make(map[string]string) - header["From"] = from - header["To"] = strings.Join(tosArr, ";") - header["Subject"] = fmt.Sprintf("=?UTF-8?B?%s?=", b64.EncodeToString([]byte(subject))) - header["MIME-Version"] = "1.0" - - ct := "text/plain; charset=UTF-8" - if len(contentType) > 0 && contentType[0] == "html" { - ct = "text/html; charset=UTF-8" + header := map[string]string{ + "From": from, + "To": strings.Join(tosArr, ";"), + "Subject": fmt.Sprintf("=?UTF-8?B?%s?=", contentEncoding.EncodeToString(gconv.UnsafeStrToBytes(subject))), + "MIME-Version": "1.0", + "Content-Type": "text/plain; charset=UTF-8", + "Content-Transfer-Encoding": "base64", + } + if len(contentType) > 0 && contentType[0] == "html" { + header["Content-Type"] = "text/html; charset=UTF-8" } - - header["Content-Type"] = ct - header["Content-Transfer-Encoding"] = "base64" - message := "" for k, v := range header { message += fmt.Sprintf("%s: %s\r\n", k, v) } - message += "\r\n" + b64.EncodeToString([]byte(body)) - - auth := smtp.PlainAuth("", s.Username, s.Password, server) - return smtp.SendMail(address, auth, from, tosArr, []byte(message)) + message += "\r\n" + contentEncoding.EncodeToString(gconv.UnsafeStrToBytes(body)) + return smtp.SendMail( + address, + smtp.PlainAuth("", s.Username, s.Password, server), + from, + tosArr, + gconv.UnsafeStrToBytes(message), + ) } diff --git a/net/gsmtp/gsmtp_test.go b/net/gsmtp/gsmtp_test.go index fcd15350c..edae876ea 100644 --- a/net/gsmtp/gsmtp_test.go +++ b/net/gsmtp/gsmtp_test.go @@ -3,6 +3,7 @@ // This Source Code Form is subject to the terms of the MIT License. // If a copy of the MIT was not distributed with this file, // You can obtain one at https://github.com/gogf/gf. + package gsmtp_test import ( diff --git a/net/gtcp/gtcp_unit_init_test.go b/net/gtcp/gtcp_unit_init_test.go index 47d002af4..974fa04db 100644 --- a/net/gtcp/gtcp_unit_init_test.go +++ b/net/gtcp/gtcp_unit_init_test.go @@ -15,7 +15,7 @@ var ( ) func init() { - for i := 9000; i <= 10000; i++ { + for i := 9000; i < 10000; i++ { ports.Append(i) } } diff --git a/os/gfile/gfile.go b/os/gfile/gfile.go index 0c0f2dcbd..bc5de9e89 100644 --- a/os/gfile/gfile.go +++ b/os/gfile/gfile.go @@ -42,7 +42,7 @@ var ( func init() { // Initialize internal package variable: tempDir. - if !Exists(tempDir) { + if Separator == "/" && !Exists(tempDir) { tempDir = os.TempDir() } // Initialize internal package variable: selfPath. diff --git a/os/gfsnotify/gfsnotify_watcher.go b/os/gfsnotify/gfsnotify_watcher.go index b9fc4aa66..8b3e07636 100644 --- a/os/gfsnotify/gfsnotify_watcher.go +++ b/os/gfsnotify/gfsnotify_watcher.go @@ -15,22 +15,25 @@ import ( ) // Add monitors with callback function to the watcher. -// The optional parameter specifies whether monitoring the recursively, which is true in default. +// The optional parameter specifies whether monitoring the recursively, +// which is true in default. func (w *Watcher) Add(path string, callbackFunc func(event *Event), recursive ...bool) (callback *Callback, err error) { return w.AddOnce("", path, callbackFunc, recursive...) } -// AddOnce monitors with callback function only once using unique name to the watcher. -// If AddOnce is called multiple times with the same parameter, is only added to monitor once. It returns error -// if it's called twice with the same . +// AddOnce monitors with callback function only once using unique name +// to the watcher. If AddOnce is called multiple times with the same parameter, +// is only added to monitor once. +// It returns error if it's called twice with the same . // -// The optional parameter specifies whether monitoring the recursively, which is true in default. +// The optional parameter specifies whether monitoring the recursively, +// which is true in default. func (w *Watcher) AddOnce(name, path string, callbackFunc func(event *Event), recursive ...bool) (callback *Callback, err error) { - w.nameSet.AddIfNotExistFuncLock(name, func() string { + w.nameSet.AddIfNotExistFuncLock(name, func() bool { // Firstly add the path to watcher. callback, err = w.addWithCallbackFunc(name, path, callbackFunc, recursive...) if err != nil { - return "" + return false } // If it's recursive adding, it then adds all sub-folders to the monitor. // NOTE: @@ -49,7 +52,10 @@ func (w *Watcher) AddOnce(name, path string, callbackFunc func(event *Event), re } } } - return name + if name == "" { + return false + } + return true }) return } diff --git a/os/gfsnotify/gfsnotify_z_unit_test.go b/os/gfsnotify/gfsnotify_z_unit_test.go index 46b8a39e0..341b02050 100644 --- a/os/gfsnotify/gfsnotify_z_unit_test.go +++ b/os/gfsnotify/gfsnotify_z_unit_test.go @@ -121,9 +121,9 @@ func TestWatcher_AddRemove(t *testing.T) { }) } -func TestWatcher_Callback(t *testing.T) { +func TestWatcher_Callback1(t *testing.T) { gtest.C(t, func(t *gtest.T) { - path1 := gfile.TempDir() + gfile.Separator + gconv.String(gtime.TimestampNano()) + path1 := gfile.TempDir(gtime.TimestampNanoStr()) gfile.PutContents(path1, "1") defer func() { gfile.Remove(path1) @@ -148,10 +148,13 @@ func TestWatcher_Callback(t *testing.T) { time.Sleep(100 * time.Millisecond) t.Assert(v.Val(), 3) }) +} + +func TestWatcher_Callback2(t *testing.T) { // multiple callbacks gtest.C(t, func(t *gtest.T) { - path1 := gfile.TempDir() + gfile.Separator + gconv.String(gtime.TimestampNano()) - gfile.PutContents(path1, "1") + path1 := gfile.TempDir(gtime.TimestampNanoStr()) + t.Assert(gfile.PutContents(path1, "1"), nil) defer func() { gfile.Remove(path1) }() @@ -174,7 +177,7 @@ func TestWatcher_Callback(t *testing.T) { t.AssertNE(callback1, nil) t.AssertNE(callback2, nil) - gfile.PutContents(path1, "2") + t.Assert(gfile.PutContents(path1, "2"), nil) time.Sleep(100 * time.Millisecond) t.Assert(v1.Val(), 2) t.Assert(v2.Val(), 2) @@ -182,7 +185,7 @@ func TestWatcher_Callback(t *testing.T) { v1.Set(3) v2.Set(3) gfsnotify.RemoveCallback(callback1.Id) - gfile.PutContents(path1, "3") + t.Assert(gfile.PutContents(path1, "3"), nil) time.Sleep(100 * time.Millisecond) t.Assert(v1.Val(), 3) t.Assert(v2.Val(), 2) diff --git a/os/glog/glog.go b/os/glog/glog.go index cb3a005d2..12c569fe0 100644 --- a/os/glog/glog.go +++ b/os/glog/glog.go @@ -8,18 +8,18 @@ package glog import ( - "io" - "github.com/gogf/gf/internal/cmdenv" "github.com/gogf/gf/os/grpool" ) var ( - // Default logger object, for package method usage + // Default logger object, for package method usage. logger = New() + // Goroutine pool for async logging output. // It uses only one asynchronize worker to ensure log sequence. asyncPool = grpool.New(1) + // defaultDebug enables debug level or not in default, // which can be configured using command option or system environment. defaultDebug = true @@ -41,99 +41,3 @@ func DefaultLogger() *Logger { func SetDefaultLogger(l *Logger) { logger = l } - -// SetPath sets the directory path for file logging. -func SetPath(path string) error { - return logger.SetPath(path) -} - -// GetPath returns the logging directory path for file logging. -// It returns empty string if no directory path set. -func GetPath() string { - return logger.GetPath() -} - -// SetFile sets the file name for file logging. -// Datetime pattern can be used in , eg: access-{Ymd}.log. -// The default file name pattern is: Y-m-d.log, eg: 2018-01-01.log -func SetFile(pattern string) { - logger.SetFile(pattern) -} - -// SetLevel sets the default logging level. -func SetLevel(level int) { - logger.SetLevel(level) -} - -// GetLevel returns the default logging level value. -func GetLevel() int { - return logger.GetLevel() -} - -// SetWriter sets the customized logging for logging. -// The object should implements the io.Writer interface. -// Developer can use customized logging to redirect logging output to another service, -// eg: kafka, mysql, mongodb, etc. -func SetWriter(writer io.Writer) { - logger.SetWriter(writer) -} - -// GetWriter returns the customized writer object, which implements the io.Writer interface. -// It returns nil if no customized writer set. -func GetWriter() io.Writer { - return logger.GetWriter() -} - -// SetDebug enables/disables the debug level for default logger. -// The debug level is enbaled in default. -func SetDebug(debug bool) { - logger.SetDebug(debug) -} - -// SetAsync enables/disables async logging output feature for default logger. -func SetAsync(enabled bool) { - logger.SetAsync(enabled) -} - -// SetStdoutPrint sets whether ouptput the logging contents to stdout, which is true in default. -func SetStdoutPrint(enabled bool) { - logger.SetStdoutPrint(enabled) -} - -// SetHeaderPrint sets whether output header of the logging contents, which is true in default. -func SetHeaderPrint(enabled bool) { - logger.SetHeaderPrint(enabled) -} - -// SetPrefix sets prefix string for every logging content. -// Prefix is part of header, which means if header output is shut, no prefix will be output. -func SetPrefix(prefix string) { - logger.SetPrefix(prefix) -} - -// SetFlags sets extra flags for logging output features. -func SetFlags(flags int) { - logger.SetFlags(flags) -} - -// GetFlags returns the flags of logger. -func GetFlags() int { - return logger.GetFlags() -} - -// PrintStack prints the caller stack, -// the optional parameter specify the skipped stack offset from the end point. -func PrintStack(skip ...int) { - logger.PrintStack(skip...) -} - -// GetStack returns the caller stack content, -// the optional parameter specify the skipped stack offset from the end point. -func GetStack(skip ...int) string { - return logger.GetStack(skip...) -} - -// SetStack enables/disables the stack feature in failure logging outputs. -func SetStack(enabled bool) { - logger.SetStack(enabled) -} diff --git a/os/glog/glog_config.go b/os/glog/glog_config.go index e41d0ce92..64c272c9b 100644 --- a/os/glog/glog_config.go +++ b/os/glog/glog_config.go @@ -6,6 +6,10 @@ package glog +import ( + "io" +) + // SetConfig set configurations for the logger. func SetConfig(config Config) error { return logger.SetConfig(config) @@ -15,3 +19,119 @@ func SetConfig(config Config) error { func SetConfigWithMap(m map[string]interface{}) error { return logger.SetConfigWithMap(m) } + +// SetPath sets the directory path for file logging. +func SetPath(path string) error { + return logger.SetPath(path) +} + +// GetPath returns the logging directory path for file logging. +// It returns empty string if no directory path set. +func GetPath() string { + return logger.GetPath() +} + +// SetFile sets the file name for file logging. +// Datetime pattern can be used in , eg: access-{Ymd}.log. +// The default file name pattern is: Y-m-d.log, eg: 2018-01-01.log +func SetFile(pattern string) { + logger.SetFile(pattern) +} + +// SetLevel sets the default logging level. +func SetLevel(level int) { + logger.SetLevel(level) +} + +// GetLevel returns the default logging level value. +func GetLevel() int { + return logger.GetLevel() +} + +// SetWriter sets the customized logging for logging. +// The object should implements the io.Writer interface. +// Developer can use customized logging to redirect logging output to another service, +// eg: kafka, mysql, mongodb, etc. +func SetWriter(writer io.Writer) { + logger.SetWriter(writer) +} + +// GetWriter returns the customized writer object, which implements the io.Writer interface. +// It returns nil if no customized writer set. +func GetWriter() io.Writer { + return logger.GetWriter() +} + +// SetDebug enables/disables the debug level for default logger. +// The debug level is enabled in default. +func SetDebug(debug bool) { + logger.SetDebug(debug) +} + +// SetAsync enables/disables async logging output feature for default logger. +func SetAsync(enabled bool) { + logger.SetAsync(enabled) +} + +// SetStdoutPrint sets whether ouptput the logging contents to stdout, which is true in default. +func SetStdoutPrint(enabled bool) { + logger.SetStdoutPrint(enabled) +} + +// SetHeaderPrint sets whether output header of the logging contents, which is true in default. +func SetHeaderPrint(enabled bool) { + logger.SetHeaderPrint(enabled) +} + +// SetPrefix sets prefix string for every logging content. +// Prefix is part of header, which means if header output is shut, no prefix will be output. +func SetPrefix(prefix string) { + logger.SetPrefix(prefix) +} + +// SetFlags sets extra flags for logging output features. +func SetFlags(flags int) { + logger.SetFlags(flags) +} + +// GetFlags returns the flags of logger. +func GetFlags() int { + return logger.GetFlags() +} + +// PrintStack prints the caller stack, +// the optional parameter specify the skipped stack offset from the end point. +func PrintStack(skip ...int) { + logger.PrintStack(skip...) +} + +// GetStack returns the caller stack content, +// the optional parameter specify the skipped stack offset from the end point. +func GetStack(skip ...int) string { + return logger.GetStack(skip...) +} + +// SetStack enables/disables the stack feature in failure logging outputs. +func SetStack(enabled bool) { + logger.SetStack(enabled) +} + +// SetLevelStr sets the logging level by level string. +func SetLevelStr(levelStr string) error { + return logger.SetLevelStr(levelStr) +} + +// SetLevelPrefix sets the prefix string for specified level. +func SetLevelPrefix(level int, prefix string) { + logger.SetLevelPrefix(level, prefix) +} + +// SetLevelPrefixes sets the level to prefix string mapping for the logger. +func SetLevelPrefixes(prefixes map[int]string) { + logger.SetLevelPrefixes(prefixes) +} + +// GetLevelPrefix returns the prefix string for specified level. +func GetLevelPrefix(level int) string { + return logger.GetLevelPrefix(level) +} diff --git a/os/glog/glog_logger_config.go b/os/glog/glog_logger_config.go index b7d1248cd..04d8e347b 100644 --- a/os/glog/glog_logger_config.go +++ b/os/glog/glog_logger_config.go @@ -82,7 +82,7 @@ func (l *Logger) SetConfigWithMap(m map[string]interface{}) error { } // The m now is a shallow copy of m. // A little tricky, isn't it? - m = gutil.CopyMap(m) + m = gutil.MapCopy(m) // Change string configuration to int value for level. levelKey, levelValue := gutil.MapPossibleItemByKey(m, "Level") if levelValue != nil { @@ -100,12 +100,11 @@ func (l *Logger) SetConfigWithMap(m map[string]interface{}) error { return errors.New(fmt.Sprintf(`invalid rotate size: %v`, rotateSizeValue)) } } - config := DefaultConfig() - err := gconv.Struct(m, &config) + err := gconv.Struct(m, &l.config) if err != nil { return err } - return l.SetConfig(config) + return l.SetConfig(l.config) } // SetDebug enables/disables the debug level for logger. diff --git a/os/gproc/gproc_process.go b/os/gproc/gproc_process.go index fd9a99bd4..b6062c8e3 100644 --- a/os/gproc/gproc_process.go +++ b/os/gproc/gproc_process.go @@ -9,6 +9,7 @@ package gproc import ( "errors" "fmt" + "github.com/gogf/gf/internal/intlog" "os" "os/exec" "runtime" @@ -116,9 +117,14 @@ func (p *Process) Kill() error { p.Manager.processes.Remove(p.Pid()) } if runtime.GOOS != "windows" { - p.Process.Release() + if err = p.Process.Release(); err != nil { + intlog.Error(err) + //return err + } } - p.Process.Wait() + _, err = p.Process.Wait() + intlog.Error(err) + //return err return nil } else { return err diff --git a/os/gview/gview.go b/os/gview/gview.go index a8fa0832f..88d948c4b 100644 --- a/os/gview/gview.go +++ b/os/gview/gview.go @@ -12,7 +12,6 @@ package gview import ( "github.com/gogf/gf/container/gmap" - "github.com/gogf/gf/i18n/gi18n" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf" @@ -28,21 +27,12 @@ type View struct { data map[string]interface{} // Global template variables. funcMap map[string]interface{} // Global template function map. fileCacheMap *gmap.StrAnyMap // File cache map. - defaultFile string // Default template file for parsing. - i18nManager *gi18n.Manager // I18n manager for this view. - delimiters []string // Custom template delimiters. config Config // Extra configuration for the view. } -// Params is type for template params. -type Params = map[string]interface{} - -// FuncMap is type for custom template functions. -type FuncMap = map[string]interface{} - -const ( - // Default template file for parsing. - defaultParsingFile = "index.html" +type ( + Params = map[string]interface{} // Params is type for template params. + FuncMap = map[string]interface{} // FuncMap is type for custom template functions. ) var ( @@ -60,9 +50,9 @@ func checkAndInitDefaultView() { // ParseContent parses the template content directly using the default view object // and returns the parsed content. -func ParseContent(content string, params Params) (string, error) { +func ParseContent(content string, params ...Params) (string, error) { checkAndInitDefaultView() - return defaultViewObj.ParseContent(content, params) + return defaultViewObj.ParseContent(content, params...) } // New returns a new view object. @@ -73,9 +63,7 @@ func New(path ...string) *View { data: make(map[string]interface{}), funcMap: make(map[string]interface{}), fileCacheMap: gmap.NewStrAnyMap(true), - defaultFile: defaultParsingFile, - i18nManager: gi18n.Instance(), - delimiters: make([]string, 2), + config: DefaultConfig(), } if len(path) > 0 && len(path[0]) > 0 { if err := view.SetPath(path[0]); err != nil { @@ -118,35 +106,36 @@ func New(path ...string) *View { "version": gf.VERSION, } // default build-in functions. - view.BindFunc("eq", view.funcEq) - view.BindFunc("ne", view.funcNe) - view.BindFunc("lt", view.funcLt) - view.BindFunc("le", view.funcLe) - view.BindFunc("gt", view.funcGt) - view.BindFunc("ge", view.funcGe) - view.BindFunc("text", view.funcText) + view.BindFuncMap(FuncMap{ + "eq": view.funcEq, + "ne": view.funcNe, + "lt": view.funcLt, + "le": view.funcLe, + "gt": view.funcGt, + "ge": view.funcGe, + "text": view.funcText, + "html": view.funcHtmlEncode, + "htmlencode": view.funcHtmlEncode, + "htmldecode": view.funcHtmlDecode, + "encode": view.funcHtmlEncode, + "decode": view.funcHtmlDecode, + "url": view.funcUrlEncode, + "urlencode": view.funcUrlEncode, + "urldecode": view.funcUrlDecode, + "date": view.funcDate, + "substr": view.funcSubStr, + "strlimit": view.funcStrLimit, + "concat": view.funcConcat, + "replace": view.funcReplace, + "compare": view.funcCompare, + "hidestr": view.funcHideStr, + "highlight": view.funcHighlight, + "toupper": view.funcToUpper, + "tolower": view.funcToLower, + "nl2br": view.funcNl2Br, + "include": view.funcInclude, + "dump": view.funcDump, + }) - view.BindFunc("html", view.funcHtmlEncode) - view.BindFunc("htmlencode", view.funcHtmlEncode) - view.BindFunc("htmldecode", view.funcHtmlDecode) - view.BindFunc("encode", view.funcHtmlEncode) - view.BindFunc("decode", view.funcHtmlDecode) - - view.BindFunc("url", view.funcUrlEncode) - view.BindFunc("urlencode", view.funcUrlEncode) - view.BindFunc("urldecode", view.funcUrlDecode) - view.BindFunc("date", view.funcDate) - view.BindFunc("substr", view.funcSubStr) - view.BindFunc("strlimit", view.funcStrLimit) - view.BindFunc("concat", view.funcConcat) - view.BindFunc("replace", view.funcReplace) - view.BindFunc("compare", view.funcCompare) - view.BindFunc("hidestr", view.funcHideStr) - view.BindFunc("highlight", view.funcHighlight) - view.BindFunc("toupper", view.funcToUpper) - view.BindFunc("tolower", view.funcToLower) - view.BindFunc("nl2br", view.funcNl2Br) - view.BindFunc("include", view.funcInclude) - view.BindFunc("dump", view.funcDump) return view } diff --git a/os/gview/gview_buildin.go b/os/gview/gview_buildin.go index 01896e35e..2fc455ba8 100644 --- a/os/gview/gview_buildin.go +++ b/os/gview/gview_buildin.go @@ -153,12 +153,12 @@ func (view *View) funcCompare(value1, value2 interface{}) int { // funcSubStr implements build-in template function: substr func (view *View) funcSubStr(start, end, str interface{}) string { - return gstr.SubStr(gconv.String(str), gconv.Int(start), gconv.Int(end)) + return gstr.SubStrRune(gconv.String(str), gconv.Int(start), gconv.Int(end)) } // funcStrLimit implements build-in template function: strlimit func (view *View) funcStrLimit(length, suffix, str interface{}) string { - return gstr.StrLimit(gconv.String(str), gconv.Int(length), gconv.String(suffix)) + return gstr.StrLimitRune(gconv.String(str), gconv.Int(length), gconv.String(suffix)) } // funcConcat implements build-in template function: concat diff --git a/os/gview/gview_config.go b/os/gview/gview_config.go index 417be07fa..6f1a08ac2 100644 --- a/os/gview/gview_config.go +++ b/os/gview/gview_config.go @@ -22,10 +22,25 @@ import ( // Config is the configuration object for template engine. type Config struct { Paths []string // Searching array for path, NOT concurrent-safe for performance purpose. - Data map[string]interface{} // Global template variables. + Data map[string]interface{} // Global template variables including configuration. DefaultFile string // Default template file for parsing. Delimiters []string // Custom template delimiters. AutoEncode bool // Automatically encodes and provides safe html output, which is good for avoiding XSS. + I18nManager *gi18n.Manager // I18n manager for the view. +} + +const ( + // Default template file for parsing. + defaultParsingFile = "index.html" +) + +// DefaultConfig creates and returns a configuration object with default configurations. +func DefaultConfig() Config { + return Config{ + DefaultFile: defaultParsingFile, + I18nManager: gi18n.Instance(), + Delimiters: make([]string, 2), + } } // SetConfig sets the configuration for view. @@ -63,19 +78,18 @@ func (view *View) SetConfigWithMap(m map[string]interface{}) error { } // The m now is a shallow copy of m. // A little tricky, isn't it? - m = gutil.CopyMap(m) + m = gutil.MapCopy(m) // Most common used configuration support for single view path. _, v1 := gutil.MapPossibleItemByKey(m, "paths") _, v2 := gutil.MapPossibleItemByKey(m, "path") if v1 == nil && v2 != nil { m["paths"] = []interface{}{v2} } - config := Config{} - err := gconv.Struct(m, &config) + err := gconv.Struct(m, &view.config) if err != nil { return err } - return view.SetConfig(config) + return view.SetConfig(view.config) } // SetPath sets the template directory path for template file search. @@ -199,13 +213,17 @@ func (view *View) Assign(key string, value interface{}) { // SetDefaultFile sets default template file for parsing. func (view *View) SetDefaultFile(file string) { - view.defaultFile = file + view.config.DefaultFile = file +} + +// GetDefaultFile returns default template file for parsing. +func (view *View) GetDefaultFile() string { + return view.config.DefaultFile } // SetDelimiters sets customized delimiters for template parsing. func (view *View) SetDelimiters(left, right string) { - view.delimiters[0] = left - view.delimiters[1] = right + view.config.Delimiters = []string{left, right} } // SetAutoEncode enables/disables automatically html encoding feature. @@ -237,5 +255,5 @@ func (view *View) BindFuncMap(funcMap FuncMap) { // SetI18n binds i18n manager to current view engine. func (view *View) SetI18n(manager *gi18n.Manager) { - view.i18nManager = manager + view.config.I18nManager = manager } diff --git a/os/gview/gview_i18n.go b/os/gview/gview_i18n.go index 0aa653053..0e5fa1116 100644 --- a/os/gview/gview_i18n.go +++ b/os/gview/gview_i18n.go @@ -10,14 +10,14 @@ import "github.com/gogf/gf/util/gconv" // i18nTranslate translate the content with i18n feature. func (view *View) i18nTranslate(content string, params Params) string { - if view.i18nManager != nil { + if view.config.I18nManager != nil { if v, ok := params["I18nLanguage"]; ok { language := gconv.String(v) if language != "" { - return view.i18nManager.T(content, language) + return view.config.I18nManager.T(content, language) } } - return view.i18nManager.T(content) + return view.config.I18nManager.T(content) } return content } diff --git a/os/gview/gview_parse.go b/os/gview/gview_parse.go index 19f7b1c9b..4184aef9d 100644 --- a/os/gview/gview_parse.go +++ b/os/gview/gview_parse.go @@ -17,6 +17,7 @@ import ( "github.com/gogf/gf/os/gmlock" "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/util/gconv" + "github.com/gogf/gf/util/gutil" htmltpl "html/template" "strconv" "strings" @@ -113,32 +114,9 @@ func (view *View) Parse(file string, params ...Params) (result string, err error // Note that the template variable assignment cannot change the value // of the existing or view.data because both variables are pointers. // It needs to merge the values of the two maps into a new map. - var variables map[string]interface{} - length := len(view.data) - if len(params) > 0 { - length += len(params[0]) - } - if length > 0 { - variables = make(map[string]interface{}, length) - } + variables := gutil.MapMergeCopy(params...) if len(view.data) > 0 { - if len(params) > 0 { - if variables == nil { - variables = make(map[string]interface{}) - } - for k, v := range params[0] { - variables[k] = v - } - for k, v := range view.data { - variables[k] = v - } - } else { - variables = view.data - } - } else { - if len(params) > 0 { - variables = params[0] - } + gutil.MapMerge(variables, view.data) } buffer := bytes.NewBuffer(nil) if view.config.AutoEncode { @@ -163,7 +141,7 @@ func (view *View) Parse(file string, params ...Params) (result string, err error // ParseDefault parses the default template file with params. func (view *View) ParseDefault(params ...Params) (result string, err error) { - return view.Parse(view.defaultFile, params...) + return view.Parse(view.config.DefaultFile, params...) } // ParseContent parses given template content with template variables @@ -174,12 +152,18 @@ func (view *View) ParseContent(content string, params ...Params) (string, error) return "", nil } err := (error)(nil) - key := fmt.Sprintf("%s_%v_%v", gCONTENT_TEMPLATE_NAME, view.delimiters, view.config.AutoEncode) + key := fmt.Sprintf("%s_%v_%v", gCONTENT_TEMPLATE_NAME, view.config.Delimiters, view.config.AutoEncode) tpl := templates.GetOrSetFuncLock(key, func() interface{} { if view.config.AutoEncode { - return htmltpl.New(gCONTENT_TEMPLATE_NAME).Delims(view.delimiters[0], view.delimiters[1]).Funcs(view.funcMap) + return htmltpl.New(gCONTENT_TEMPLATE_NAME).Delims( + view.config.Delimiters[0], + view.config.Delimiters[1], + ).Funcs(view.funcMap) } - return texttpl.New(gCONTENT_TEMPLATE_NAME).Delims(view.delimiters[0], view.delimiters[1]).Funcs(view.funcMap) + return texttpl.New(gCONTENT_TEMPLATE_NAME).Delims( + view.config.Delimiters[0], + view.config.Delimiters[1], + ).Funcs(view.funcMap) }) // Using memory lock to ensure concurrent safety for content parsing. hash := strconv.FormatUint(ghash.DJBHash64([]byte(content)), 10) @@ -196,32 +180,9 @@ func (view *View) ParseContent(content string, params ...Params) (string, error) // Note that the template variable assignment cannot change the value // of the existing or view.data because both variables are pointers. // It needs to merge the values of the two maps into a new map. - var variables map[string]interface{} - length := len(view.data) - if len(params) > 0 { - length += len(params[0]) - } - if length > 0 { - variables = make(map[string]interface{}, length) - } + variables := gutil.MapMergeCopy(params...) if len(view.data) > 0 { - if len(params) > 0 { - if variables == nil { - variables = make(map[string]interface{}) - } - for k, v := range params[0] { - variables[k] = v - } - for k, v := range view.data { - variables[k] = v - } - } else { - variables = view.data - } - } else { - if len(params) > 0 { - variables = params[0] - } + gutil.MapMerge(variables, view.data) } buffer := bytes.NewBuffer(nil) if view.config.AutoEncode { @@ -249,14 +210,20 @@ func (view *View) ParseContent(content string, params ...Params) (string, error) // if the template files under changes (recursively). func (view *View) getTemplate(filePath, folderPath, pattern string) (tpl interface{}, err error) { // Key for template cache. - key := fmt.Sprintf("%s_%v", filePath, view.delimiters) + key := fmt.Sprintf("%s_%v", filePath, view.config.Delimiters) result := templates.GetOrSetFuncLock(key, func() interface{} { // Do not use but the as the parameter for function New, // because when error occurs the will be printed out for error locating. if view.config.AutoEncode { - tpl = htmltpl.New(filePath).Delims(view.delimiters[0], view.delimiters[1]).Funcs(view.funcMap) + tpl = htmltpl.New(filePath).Delims( + view.config.Delimiters[0], + view.config.Delimiters[1], + ).Funcs(view.funcMap) } else { - tpl = texttpl.New(filePath).Delims(view.delimiters[0], view.delimiters[1]).Funcs(view.funcMap) + tpl = texttpl.New(filePath).Delims( + view.config.Delimiters[0], + view.config.Delimiters[1], + ).Funcs(view.funcMap) } // Firstly checking the resource manager. if !gres.IsEmpty() { diff --git a/test/gtest/gtest_t.go b/test/gtest/gtest_t.go index acdd8989c..06c677227 100644 --- a/test/gtest/gtest_t.go +++ b/test/gtest/gtest_t.go @@ -30,6 +30,11 @@ func (t *T) AssertNE(value, expect interface{}) { AssertNE(value, expect) } +// AssertNQ checks and NOT EQUAL, including their TYPES. +func (t *T) AssertNQ(value, expect interface{}) { + AssertNQ(value, expect) +} + // AssertGT checks is GREATER THAN . // Notice that, only string, integer and float types can be compared by AssertGT, // others are invalid. diff --git a/test/gtest/gtest_util.go b/test/gtest/gtest_util.go index 3216bd394..701d8250c 100644 --- a/test/gtest/gtest_util.go +++ b/test/gtest/gtest_util.go @@ -60,8 +60,10 @@ func Assert(value, expect interface{}) { } return } - strValue := gconv.String(value) - strExpect := gconv.String(expect) + var ( + strValue = gconv.String(value) + strExpect = gconv.String(expect) + ) if strValue != strExpect { panic(fmt.Sprintf(`[ASSERT] EXPECT %v == %v`, strValue, strExpect)) } @@ -105,13 +107,32 @@ func AssertNE(value, expect interface{}) { } return } - strValue := gconv.String(value) - strExpect := gconv.String(expect) + var ( + strValue = gconv.String(value) + strExpect = gconv.String(expect) + ) if strValue == strExpect { panic(fmt.Sprintf(`[ASSERT] EXPECT %v != %v`, strValue, strExpect)) } } +// AssertNQ checks and NOT EQUAL, including their TYPES. +func AssertNQ(value, expect interface{}) { + // Type assert. + t1 := reflect.TypeOf(value) + t2 := reflect.TypeOf(expect) + if t1 == t2 { + panic( + fmt.Sprintf( + `[ASSERT] EXPECT TYPE %v[%v] != %v[%v]`, + gconv.String(value), t1, gconv.String(expect), t2, + ), + ) + } + // Value assert. + AssertNE(value, expect) +} + // AssertGT checks is GREATER THAN . // Notice that, only string, integer and float types can be compared by AssertGT, // others are invalid. diff --git a/text/gstr/gstr.go b/text/gstr/gstr.go index e225f423a..21d6e207d 100644 --- a/text/gstr/gstr.go +++ b/text/gstr/gstr.go @@ -158,9 +158,35 @@ func IsNumeric(s string) bool { // SubStr returns a portion of string specified by the and parameters. func SubStr(str string, start int, length ...int) (substr string) { + lth := len(str) + + // Simple border checks. + if start < 0 { + start = 0 + } + if start >= lth { + start = lth + } + end := lth + if len(length) > 0 { + end = start + length[0] + if end < start { + end = lth + } + } + if end > lth { + end = lth + } + return str[start:end] +} + +// SubStrRune returns a portion of string specified by the and parameters. +// SubStrRune considers parameter as unicode string. +func SubStrRune(str string, start int, length ...int) (substr string) { // Converting to []rune to support unicode. rs := []rune(str) lth := len(rs) + // Simple border checks. if start < 0 { start = 0 @@ -181,10 +207,23 @@ func SubStr(str string, start int, length ...int) (substr string) { return string(rs[start:end]) } -// StrLimit returns a portion of string specified by parameters, -// if the length of is greater than , -// then the will be appended to the result string. +// StrLimit returns a portion of string specified by parameters, if the length +// of is greater than , then the will be appended to the result string. func StrLimit(str string, length int, suffix ...string) string { + if len(str) < length { + return str + } + addStr := "..." + if len(suffix) > 0 { + addStr = suffix[0] + } + return str[0:length] + addStr +} + +// StrLimitRune returns a portion of string specified by parameters, if the length +// of is greater than , then the will be appended to the result string. +// StrLimitRune considers parameter as unicode string. +func StrLimitRune(str string, length int, suffix ...string) string { rs := []rune(str) if len(rs) < length { return str @@ -255,6 +294,7 @@ func NumberFormat(number float64, decimals int, decPoint, thousandsSep string) s // Can be used to split a string into smaller chunks which is useful for // e.g. converting BASE64 string output to match RFC 2045 semantics. // It inserts end every chunkLen characters. +// It considers parameter and as unicode string. func ChunkSplit(body string, chunkLen int, end string) string { if end == "" { end = "\r\n" @@ -304,6 +344,7 @@ func HasSuffix(s, suffix string) bool { } // CountWords returns information about words' count used in a string. +// It considers parameter as unicode string. func CountWords(str string) map[string]int { m := make(map[string]int) buffer := bytes.NewBuffer(nil) @@ -324,6 +365,7 @@ func CountWords(str string) map[string]int { } // CountChars returns information about chars' count used in a string. +// It considers parameter as unicode string. func CountChars(str string, noSpace ...bool) map[string]int { m := make(map[string]int) countSpace := true @@ -340,16 +382,18 @@ func CountChars(str string, noSpace ...bool) map[string]int { } // WordWrap wraps a string to a given number of characters. -// TODO: Enable cut param, see http://php.net/manual/en/function.wordwrap.php. +// TODO: Enable cut parameter, see http://php.net/manual/en/function.wordwrap.php. func WordWrap(str string, width int, br string) string { if br == "" { br = "\n" } - init := make([]byte, 0, len(str)) - buf := bytes.NewBuffer(init) - var current int - var wordBuf, spaceBuf bytes.Buffer - for _, char := range str { + var ( + current int + wordBuf, spaceBuf bytes.Buffer + init = make([]byte, 0, len(str)) + buf = bytes.NewBuffer(init) + ) + for _, char := range []rune(str) { if char == '\n' { if wordBuf.Len() == 0 { if current+spaceBuf.Len() > width { @@ -399,7 +443,13 @@ func WordWrap(str string, width int, br string) string { } // RuneLen returns string length of unicode. +// Deprecated, use LenRune instead. func RuneLen(str string) int { + return LenRune(str) +} + +// LenRune returns string length of unicode. +func LenRune(str string) int { return utf8.RuneCountInString(str) } @@ -423,6 +473,7 @@ func Str(haystack string, needle string) string { } // Shuffle randomly shuffles a string. +// It considers parameter as unicode string. func Shuffle(str string) string { runes := []rune(str) s := make([]rune, len(runes)) @@ -502,19 +553,22 @@ func Ord(char string) int { } // HideStr replaces part of the the string to by from the . +// It considers parameter as unicode string. func HideStr(str string, percent int, hide string) string { array := strings.Split(str, "@") if len(array) > 1 { str = array[0] } - rs := []rune(str) - length := len(rs) - mid := math.Floor(float64(length / 2)) - hideLen := int(math.Floor(float64(length) * (float64(percent) / 100))) - start := int(mid - math.Floor(float64(hideLen)/2)) - hideStr := []rune("") - hideRune := []rune(hide) - for i := 0; i < int(hideLen); i++ { + var ( + rs = []rune(str) + length = len(rs) + mid = math.Floor(float64(length / 2)) + hideLen = int(math.Floor(float64(length) * (float64(percent) / 100))) + start = int(mid - math.Floor(float64(hideLen)/2)) + hideStr = []rune("") + hideRune = []rune(hide) + ) + for i := 0; i < hideLen; i++ { hideStr = append(hideStr, hideRune...) } buffer := bytes.NewBuffer(nil) @@ -529,6 +583,7 @@ func HideStr(str string, percent int, hide string) string { // Nl2Br inserts HTML line breaks(
|
) before all newlines in a string: // \n\r, \r\n, \r, \n. +// It considers parameter as unicode string. func Nl2Br(str string, isXhtml ...bool) string { r, n, runes := '\r', '\n', []rune(str) var br []byte diff --git a/text/gstr/gstr_pos.go b/text/gstr/gstr_pos.go index 8a55bc89f..bf76e756a 100644 --- a/text/gstr/gstr_pos.go +++ b/text/gstr/gstr_pos.go @@ -31,6 +31,15 @@ func Pos(haystack, needle string, startOffset ...int) int { return pos + offset } +// PosRune acts like function Pos but considers and as unicode string. +func PosRune(haystack, needle string, startOffset ...int) int { + pos := Pos(haystack, needle, startOffset...) + if pos < 3 { + return pos + } + return len([]rune(haystack[:pos])) +} + // PosI returns the position of the first occurrence of // in from , case-insensitively. // It returns -1, if not found. @@ -54,6 +63,15 @@ func PosI(haystack, needle string, startOffset ...int) int { return pos + offset } +// PosIRune acts like function PosI but considers and as unicode string. +func PosIRune(haystack, needle string, startOffset ...int) int { + pos := PosI(haystack, needle, startOffset...) + if pos < 3 { + return pos + } + return len([]rune(haystack[:pos])) +} + // PosR returns the position of the last occurrence of // in from , case-sensitively. // It returns -1, if not found. @@ -79,6 +97,15 @@ func PosR(haystack, needle string, startOffset ...int) int { return pos } +// PosRRune acts like function PosR but considers and as unicode string. +func PosRRune(haystack, needle string, startOffset ...int) int { + pos := PosR(haystack, needle, startOffset...) + if pos < 3 { + return pos + } + return len([]rune(haystack[:pos])) +} + // PosRI returns the position of the last occurrence of // in from , case-insensitively. // It returns -1, if not found. @@ -103,3 +130,12 @@ func PosRI(haystack, needle string, startOffset ...int) int { } return pos } + +// PosRIRune acts like function PosRI but considers and as unicode string. +func PosRIRune(haystack, needle string, startOffset ...int) int { + pos := PosRI(haystack, needle, startOffset...) + if pos < 3 { + return pos + } + return len([]rune(haystack[:pos])) +} diff --git a/text/gstr/gstr_z_unit_basic_test.go b/text/gstr/gstr_z_unit_basic_test.go index d283f3a5c..401011ff2 100644 --- a/text/gstr/gstr_z_unit_basic_test.go +++ b/text/gstr/gstr_z_unit_basic_test.go @@ -137,23 +137,43 @@ func Test_IsNumeric(t *testing.T) { func Test_SubStr(t *testing.T) { gtest.C(t, func(t *gtest.T) { t.Assert(gstr.SubStr("我爱GoFrame", 0), "我爱GoFrame") - t.Assert(gstr.SubStr("我爱GoFrame", 2), "GoFrame") - t.Assert(gstr.SubStr("我爱GoFrame", 2, 2), "Go") + t.Assert(gstr.SubStr("我爱GoFrame", 6), "GoFrame") + t.Assert(gstr.SubStr("我爱GoFrame", 6, 2), "Go") t.Assert(gstr.SubStr("我爱GoFrame", -1, 30), "我爱GoFrame") t.Assert(gstr.SubStr("我爱GoFrame", 30, 30), "") }) } +func Test_SubStrRune(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + t.Assert(gstr.SubStrRune("我爱GoFrame", 0), "我爱GoFrame") + t.Assert(gstr.SubStrRune("我爱GoFrame", 2), "GoFrame") + t.Assert(gstr.SubStrRune("我爱GoFrame", 2, 2), "Go") + t.Assert(gstr.SubStrRune("我爱GoFrame", -1, 30), "我爱GoFrame") + t.Assert(gstr.SubStrRune("我爱GoFrame", 30, 30), "") + }) +} + func Test_StrLimit(t *testing.T) { gtest.C(t, func(t *gtest.T) { - t.Assert(gstr.StrLimit("我爱GoFrame", 2), "我爱...") - t.Assert(gstr.StrLimit("我爱GoFrame", 2, ""), "我爱") - t.Assert(gstr.StrLimit("我爱GoFrame", 2, "**"), "我爱**") - t.Assert(gstr.StrLimit("我爱GoFrame", 4, ""), "我爱Go") + t.Assert(gstr.StrLimit("我爱GoFrame", 6), "我爱...") + t.Assert(gstr.StrLimit("我爱GoFrame", 6, ""), "我爱") + t.Assert(gstr.StrLimit("我爱GoFrame", 6, "**"), "我爱**") + t.Assert(gstr.StrLimit("我爱GoFrame", 8, ""), "我爱Go") t.Assert(gstr.StrLimit("*", 4, ""), "*") }) } +func Test_StrLimitRune(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + t.Assert(gstr.StrLimitRune("我爱GoFrame", 2), "我爱...") + t.Assert(gstr.StrLimitRune("我爱GoFrame", 2, ""), "我爱") + t.Assert(gstr.StrLimitRune("我爱GoFrame", 2, "**"), "我爱**") + t.Assert(gstr.StrLimitRune("我爱GoFrame", 4, ""), "我爱Go") + t.Assert(gstr.StrLimitRune("*", 4, ""), "*") + }) +} + func Test_HasPrefix(t *testing.T) { gtest.C(t, func(t *gtest.T) { t.Assert(gstr.HasPrefix("我爱GoFrame", "我爱"), true) @@ -247,6 +267,7 @@ func Test_WordWrap(t *testing.T) { gtest.C(t, func(t *gtest.T) { t.Assert(gstr.WordWrap("12 34", 2, "
"), "12
34") t.Assert(gstr.WordWrap("12 34", 2, "\n"), "12\n34") + t.Assert(gstr.WordWrap("我爱 GF", 2, "\n"), "我爱\nGF") t.Assert(gstr.WordWrap("A very long woooooooooooooooooord. and something", 7, "
"), "A very
long
woooooooooooooooooord.
and
something") }) diff --git a/text/gstr/gstr_z_unit_pos_test.go b/text/gstr/gstr_z_unit_pos_test.go index fdd4edecc..68288f3cf 100644 --- a/text/gstr/gstr_z_unit_pos_test.go +++ b/text/gstr/gstr_z_unit_pos_test.go @@ -23,6 +23,28 @@ func Test_Pos(t *testing.T) { t.Assert(gstr.Pos(s1, "abd", 0), -1) t.Assert(gstr.Pos(s1, "e", -4), 11) }) + gtest.C(t, func(t *gtest.T) { + s1 := "我爱China very much" + t.Assert(gstr.Pos(s1, "爱"), 3) + t.Assert(gstr.Pos(s1, "C"), 6) + t.Assert(gstr.Pos(s1, "China"), 6) + }) +} + +func Test_PosRune(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s1 := "abcdEFGabcdefg" + t.Assert(gstr.PosRune(s1, "ab"), 0) + t.Assert(gstr.PosRune(s1, "ab", 2), 7) + t.Assert(gstr.PosRune(s1, "abd", 0), -1) + t.Assert(gstr.PosRune(s1, "e", -4), 11) + }) + gtest.C(t, func(t *gtest.T) { + s1 := "我爱China very much" + t.Assert(gstr.PosRune(s1, "爱"), 1) + t.Assert(gstr.PosRune(s1, "C"), 2) + t.Assert(gstr.PosRune(s1, "China"), 2) + }) } func Test_PosI(t *testing.T) { @@ -34,6 +56,29 @@ func Test_PosI(t *testing.T) { t.Assert(gstr.PosI(s1, "abd", 0), -1) t.Assert(gstr.PosI(s1, "E", -4), 11) }) + gtest.C(t, func(t *gtest.T) { + s1 := "我爱China very much" + t.Assert(gstr.PosI(s1, "爱"), 3) + t.Assert(gstr.PosI(s1, "c"), 6) + t.Assert(gstr.PosI(s1, "china"), 6) + }) +} + +func Test_PosIRune(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s1 := "abcdEFGabcdefg" + t.Assert(gstr.PosIRune(s1, "zz"), -1) + t.Assert(gstr.PosIRune(s1, "ab"), 0) + t.Assert(gstr.PosIRune(s1, "ef", 2), 4) + t.Assert(gstr.PosIRune(s1, "abd", 0), -1) + t.Assert(gstr.PosIRune(s1, "E", -4), 11) + }) + gtest.C(t, func(t *gtest.T) { + s1 := "我爱China very much" + t.Assert(gstr.PosIRune(s1, "爱"), 1) + t.Assert(gstr.PosIRune(s1, "c"), 2) + t.Assert(gstr.PosIRune(s1, "china"), 2) + }) } func Test_PosR(t *testing.T) { @@ -47,6 +92,31 @@ func Test_PosR(t *testing.T) { t.Assert(gstr.PosR(s1, "abd", 0), -1) t.Assert(gstr.PosR(s1, "e", -4), -1) }) + gtest.C(t, func(t *gtest.T) { + s1 := "我爱China very much" + t.Assert(gstr.PosR(s1, "爱"), 3) + t.Assert(gstr.PosR(s1, "C"), 6) + t.Assert(gstr.PosR(s1, "China"), 6) + }) +} + +func Test_PosRRune(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s1 := "abcdEFGabcdefg" + s2 := "abcdEFGz1cdeab" + t.Assert(gstr.PosRRune(s1, "zz"), -1) + t.Assert(gstr.PosRRune(s1, "ab"), 7) + t.Assert(gstr.PosRRune(s2, "ab", -2), 0) + t.Assert(gstr.PosRRune(s1, "ef"), 11) + t.Assert(gstr.PosRRune(s1, "abd", 0), -1) + t.Assert(gstr.PosRRune(s1, "e", -4), -1) + }) + gtest.C(t, func(t *gtest.T) { + s1 := "我爱China very much" + t.Assert(gstr.PosRRune(s1, "爱"), 1) + t.Assert(gstr.PosRRune(s1, "C"), 2) + t.Assert(gstr.PosRRune(s1, "China"), 2) + }) } func Test_PosRI(t *testing.T) { @@ -60,4 +130,29 @@ func Test_PosRI(t *testing.T) { t.Assert(gstr.PosRI(s1, "abd", 0), -1) t.Assert(gstr.PosRI(s1, "e", -5), 4) }) + gtest.C(t, func(t *gtest.T) { + s1 := "我爱China very much" + t.Assert(gstr.PosRI(s1, "爱"), 3) + t.Assert(gstr.PosRI(s1, "C"), 19) + t.Assert(gstr.PosRI(s1, "China"), 6) + }) +} + +func Test_PosRIRune(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s1 := "abcdEFGabcdefg" + s2 := "abcdEFGz1cdeab" + t.Assert(gstr.PosRIRune(s1, "zz"), -1) + t.Assert(gstr.PosRIRune(s1, "AB"), 7) + t.Assert(gstr.PosRIRune(s2, "AB", -2), 0) + t.Assert(gstr.PosRIRune(s1, "EF"), 11) + t.Assert(gstr.PosRIRune(s1, "abd", 0), -1) + t.Assert(gstr.PosRIRune(s1, "e", -5), 4) + }) + gtest.C(t, func(t *gtest.T) { + s1 := "我爱China very much" + t.Assert(gstr.PosRIRune(s1, "爱"), 1) + t.Assert(gstr.PosRIRune(s1, "C"), 15) + t.Assert(gstr.PosRIRune(s1, "China"), 2) + }) } diff --git a/util/gconv/gconv.go b/util/gconv/gconv.go index e4449d49e..594d21b39 100644 --- a/util/gconv/gconv.go +++ b/util/gconv/gconv.go @@ -231,35 +231,35 @@ func String(i interface{}) string { // If the variable implements the String() interface, // then use that interface to perform the conversion return f.String() - } else if f, ok := value.(apiError); ok { + } + if f, ok := value.(apiError); ok { // If the variable implements the Error() interface, // then use that interface to perform the conversion return f.Error() + } + // Reflect checks. + rv := reflect.ValueOf(value) + kind := rv.Kind() + switch kind { + case reflect.Chan, + reflect.Map, + reflect.Slice, + reflect.Func, + reflect.Ptr, + reflect.Interface, + reflect.UnsafePointer: + if rv.IsNil() { + return "" + } + } + if kind == reflect.Ptr { + return String(rv.Elem().Interface()) + } + // Finally we use json.Marshal to convert. + if jsonContent, err := json.Marshal(value); err != nil { + return fmt.Sprint(value) } else { - // Reflect checks. - rv := reflect.ValueOf(value) - kind := rv.Kind() - switch kind { - case reflect.Chan, - reflect.Map, - reflect.Slice, - reflect.Func, - reflect.Ptr, - reflect.Interface, - reflect.UnsafePointer: - if rv.IsNil() { - return "" - } - } - if kind == reflect.Ptr { - return String(rv.Elem().Interface()) - } - // Finally we use json.Marshal to convert. - if jsonContent, err := json.Marshal(value); err != nil { - return fmt.Sprint(value) - } else { - return string(jsonContent) - } + return string(jsonContent) } } } diff --git a/util/gconv/gconv_map.go b/util/gconv/gconv_map.go index 8631a7d72..b7d6675ab 100644 --- a/util/gconv/gconv_map.go +++ b/util/gconv/gconv_map.go @@ -45,192 +45,211 @@ func doMapConvert(value interface{}, recursive bool, tags ...string) map[string] if value == nil { return nil } - if r, ok := value.(map[string]interface{}); ok { + + // Assert the common combination of types, and finally it uses reflection. + m := make(map[string]interface{}) + switch r := value.(type) { + case string: + if len(r) > 0 && r[0] == '{' && r[len(r)-1] == '}' { + if err := json.Unmarshal([]byte(r), &m); err != nil { + return nil + } + } else { + return nil + } + case []byte: + if len(r) > 0 && r[0] == '{' && r[len(r)-1] == '}' { + if err := json.Unmarshal(r, &m); err != nil { + return nil + } + } else { + return nil + } + case map[interface{}]interface{}: + for k, v := range r { + m[String(k)] = v + } + case map[interface{}]string: + for k, v := range r { + m[String(k)] = v + } + case map[interface{}]int: + for k, v := range r { + m[String(k)] = v + } + case map[interface{}]uint: + for k, v := range r { + m[String(k)] = v + } + case map[interface{}]float32: + for k, v := range r { + m[String(k)] = v + } + case map[interface{}]float64: + for k, v := range r { + m[String(k)] = v + } + case map[string]bool: + for k, v := range r { + m[k] = v + } + case map[string]int: + for k, v := range r { + m[k] = v + } + case map[string]uint: + for k, v := range r { + m[k] = v + } + case map[string]float32: + for k, v := range r { + m[k] = v + } + case map[string]float64: + for k, v := range r { + m[k] = v + } + case map[string]interface{}: return r - } else { - // Assert the common combination of types, and finally it uses reflection. - m := make(map[string]interface{}) - switch r := value.(type) { - case string: - if len(r) > 0 && r[0] == '{' && r[len(r)-1] == '}' { - if err := json.Unmarshal([]byte(r), &m); err != nil { - return nil + case map[int]interface{}: + for k, v := range r { + m[String(k)] = v + } + case map[int]string: + for k, v := range r { + m[String(k)] = v + } + case map[uint]string: + for k, v := range r { + m[String(k)] = v + } + // Not a common type, it then uses reflection for conversion. + default: + rv := reflect.ValueOf(value) + kind := rv.Kind() + // If it is a pointer, we should find its real data type. + if kind == reflect.Ptr { + rv = rv.Elem() + kind = rv.Kind() + } + switch kind { + // If is type of array, it converts the value of even number index as its key and + // the value of odd number index as its corresponding value. + // Eg: + // []string{"k1","v1","k2","v2"} => map[string]interface{}{"k1":"v1", "k2":"v2"} + // []string{"k1","v1","k2"} => map[string]interface{}{"k1":"v1", "k2":nil} + case reflect.Slice, reflect.Array: + length := rv.Len() + for i := 0; i < length; i += 2 { + if i+1 < length { + m[String(rv.Index(i).Interface())] = rv.Index(i + 1).Interface() + } else { + m[String(rv.Index(i).Interface())] = nil } - } else { - return nil } - case []byte: - if len(r) > 0 && r[0] == '{' && r[len(r)-1] == '}' { - if err := json.Unmarshal(r, &m); err != nil { - return nil + case reflect.Map: + ks := rv.MapKeys() + for _, k := range ks { + m[String(k.Interface())] = rv.MapIndex(k).Interface() + } + case reflect.Struct: + // Map converting interface check. + if v, ok := value.(apiMapStrAny); ok { + return v.MapStrAny() + } + // Using reflect for converting. + var ( + rtField reflect.StructField + rvField reflect.Value + rvKind reflect.Kind + rt = rv.Type() + name = "" + tagArray = structTagPriority + ) + switch len(tags) { + case 0: + // No need handle. + case 1: + tagArray = append(strings.Split(tags[0], ","), structTagPriority...) + default: + tagArray = append(tags, structTagPriority...) + } + for i := 0; i < rv.NumField(); i++ { + rtField = rt.Field(i) + rvField = rv.Field(i) + // Only convert the public attributes. + fieldName := rtField.Name + if !utils.IsLetterUpper(fieldName[0]) { + continue } - } else { - return nil - } - case map[interface{}]interface{}: - for k, v := range r { - m[String(k)] = v - } - case map[interface{}]string: - for k, v := range r { - m[String(k)] = v - } - case map[interface{}]int: - for k, v := range r { - m[String(k)] = v - } - case map[interface{}]uint: - for k, v := range r { - m[String(k)] = v - } - case map[interface{}]float32: - for k, v := range r { - m[String(k)] = v - } - case map[interface{}]float64: - for k, v := range r { - m[String(k)] = v - } - case map[string]bool: - for k, v := range r { - m[k] = v - } - case map[string]int: - for k, v := range r { - m[k] = v - } - case map[string]uint: - for k, v := range r { - m[k] = v - } - case map[string]float32: - for k, v := range r { - m[k] = v - } - case map[string]float64: - for k, v := range r { - m[k] = v - } - case map[int]interface{}: - for k, v := range r { - m[String(k)] = v - } - case map[int]string: - for k, v := range r { - m[String(k)] = v - } - case map[uint]string: - for k, v := range r { - m[String(k)] = v - } - // Not a common type, it then uses reflection for conversion. - default: - rv := reflect.ValueOf(value) - kind := rv.Kind() - // If it is a pointer, we should find its real data type. - if kind == reflect.Ptr { - rv = rv.Elem() - kind = rv.Kind() - } - switch kind { - // If is type of array, it converts the value of even number index as its key and - // the value of odd number index as its corresponding value. - // Eg: - // []string{"k1","v1","k2","v2"} => map[string]interface{}{"k1":"v1", "k2":"v2"} - // []string{"k1","v1","k2"} => map[string]interface{}{"k1":"v1", "k2":nil} - case reflect.Slice, reflect.Array: - length := rv.Len() - for i := 0; i < length; i += 2 { - if i+1 < length { - m[String(rv.Index(i).Interface())] = rv.Index(i + 1).Interface() - } else { - m[String(rv.Index(i).Interface())] = nil + name = "" + fieldTag := rtField.Tag + for _, tag := range tagArray { + if name = fieldTag.Get(tag); name != "" { + break } } - case reflect.Map: - ks := rv.MapKeys() - for _, k := range ks { - m[String(k.Interface())] = rv.MapIndex(k).Interface() - } - case reflect.Struct: - // Map converting interface check. - if v, ok := value.(apiMapStrAny); ok { - return v.MapStrAny() - } - rt := rv.Type() - name := "" - tagArray := structTagPriority - switch len(tags) { - case 0: - // No need handle. - case 1: - tagArray = append(strings.Split(tags[0], ","), structTagPriority...) - default: - tagArray = append(tags, structTagPriority...) - } - var rtField reflect.StructField - var rvField reflect.Value - var rvKind reflect.Kind - for i := 0; i < rv.NumField(); i++ { - rtField = rt.Field(i) - rvField = rv.Field(i) - // Only convert the public attributes. - fieldName := rtField.Name - if !utils.IsLetterUpper(fieldName[0]) { + if name == "" { + name = fieldName + } else { + // Support json tag feature: -, omitempty + name = strings.TrimSpace(name) + if name == "-" { continue } - name = "" - fieldTag := rtField.Tag - for _, tag := range tagArray { - if name = fieldTag.Get(tag); name != "" { - break - } - } - if name == "" { - name = strings.TrimSpace(fieldName) - } else { - // Support json tag feature: -, omitempty - name = strings.TrimSpace(name) - if name == "-" { - continue - } - array := strings.Split(name, ",") - if len(array) > 1 { - switch strings.TrimSpace(array[1]) { - case "omitempty": - if empty.IsEmpty(rvField.Interface()) { - continue - } else { - name = strings.TrimSpace(array[0]) - } - default: + array := strings.Split(name, ",") + if len(array) > 1 { + switch strings.TrimSpace(array[1]) { + case "omitempty": + if empty.IsEmpty(rvField.Interface()) { + continue + } else { name = strings.TrimSpace(array[0]) } + default: + name = strings.TrimSpace(array[0]) } } - if recursive { + } + if recursive { + rvKind = rvField.Kind() + if rvKind == reflect.Ptr { + rvField = rvField.Elem() rvKind = rvField.Kind() - if rvKind == reflect.Ptr { - rvField = rvField.Elem() - rvKind = rvField.Kind() - } - if rvKind == reflect.Struct { + } + if rvKind == reflect.Struct { + hasNoTag := name == fieldName + if hasNoTag && rtField.Anonymous { + // It means this attribute field has no tag. + // Overwrite the attribute with sub-struct attribute fields. for k, v := range doMapConvert(rvField.Interface(), recursive, tags...) { m[k] = v } } else { - m[name] = rvField.Interface() + // It means this attribute field has desired tag. + m[name] = doMapConvert(rvField.Interface(), recursive, tags...) } + } else { + if rvField.IsValid() { + m[name] = rvField.Interface() + } else { + m[name] = nil + } + } + } else { + if rvField.IsValid() { m[name] = rvField.Interface() + } else { + m[name] = nil } } - default: - return nil } + default: + return nil } - return m } + return m } // MapStrStr converts to map[string]string. diff --git a/util/gconv/gconv_slice_any.go b/util/gconv/gconv_slice_any.go index dd52d09fb..8da1d5d5f 100644 --- a/util/gconv/gconv_slice_any.go +++ b/util/gconv/gconv_slice_any.go @@ -11,16 +11,16 @@ import ( "reflect" ) +// apiInterfaces is used for type assert api for Interfaces. +type apiInterfaces interface { + Interfaces() []interface{} +} + // SliceAny is alias of Interfaces. func SliceAny(i interface{}) []interface{} { return Interfaces(i) } -// Type assert api for Interfaces. -type apiInterfaces interface { - Interfaces() []interface{} -} - // Interfaces converts to []interface{}. func Interfaces(i interface{}) []interface{} { if i == nil { diff --git a/util/gconv/gconv_slice_float.go b/util/gconv/gconv_slice_float.go index fcf039a77..eac270fc8 100644 --- a/util/gconv/gconv_slice_float.go +++ b/util/gconv/gconv_slice_float.go @@ -6,6 +6,11 @@ package gconv +// apiFloats is used for type assert api for Floats. +type apiFloats interface { + Floats() []float64 +} + // SliceFloat is alias of Floats. func SliceFloat(i interface{}) []float64 { return Floats(i) @@ -31,85 +36,89 @@ func Float32s(i interface{}) []float32 { if i == nil { return nil } - if r, ok := i.([]float32); ok { - return r - } else { - var array []float32 - switch value := i.(type) { - case []string: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - case []int: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - case []int8: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - case []int16: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - case []int32: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - case []int64: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - case []uint: - for _, v := range value { - array = append(array, Float32(v)) - } - case []uint8: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - case []uint16: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - case []uint32: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - case []uint64: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - case []bool: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - case []float64: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - case []interface{}: - array = make([]float32, len(value)) - for k, v := range value { - array[k] = Float32(v) - } - default: - return []float32{Float32(i)} + var array []float32 + switch value := i.(type) { + case []string: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) } - return array + case []int: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) + } + case []int8: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) + } + case []int16: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) + } + case []int32: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) + } + case []int64: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) + } + case []uint: + for _, v := range value { + array = append(array, Float32(v)) + } + case []uint8: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) + } + case []uint16: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) + } + case []uint32: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) + } + case []uint64: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) + } + case []bool: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) + } + case []float32: + array = value + case []float64: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) + } + case []interface{}: + array = make([]float32, len(value)) + for k, v := range value { + array[k] = Float32(v) + } + default: + if v, ok := i.(apiFloats); ok { + return Float32s(v.Floats()) + } + if v, ok := i.(apiInterfaces); ok { + return Float32s(v.Interfaces()) + } + return []float32{Float32(i)} } + return array } // Float64s converts to []float64. @@ -117,83 +126,88 @@ func Float64s(i interface{}) []float64 { if i == nil { return nil } - if r, ok := i.([]float64); ok { - return r - } else { - var array []float64 - switch value := i.(type) { - case []string: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - case []int: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - case []int8: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - case []int16: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - case []int32: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - case []int64: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - case []uint: - for _, v := range value { - array = append(array, Float64(v)) - } - case []uint8: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - case []uint16: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - case []uint32: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - case []uint64: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - case []bool: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - case []float32: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - case []interface{}: - array = make([]float64, len(value)) - for k, v := range value { - array[k] = Float64(v) - } - default: - return []float64{Float64(i)} + var array []float64 + switch value := i.(type) { + case []string: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) } - return array + case []int: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) + } + case []int8: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) + } + case []int16: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) + } + case []int32: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) + } + case []int64: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) + } + case []uint: + for _, v := range value { + array = append(array, Float64(v)) + } + case []uint8: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) + } + case []uint16: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) + } + case []uint32: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) + } + case []uint64: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) + } + case []bool: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) + } + case []float32: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) + } + case []float64: + array = value + case []interface{}: + array = make([]float64, len(value)) + for k, v := range value { + array[k] = Float64(v) + } + default: + if v, ok := i.(apiFloats); ok { + return v.Floats() + } + if v, ok := i.(apiInterfaces); ok { + return Floats(v.Interfaces()) + } + return []float64{Float64(i)} } + return array + } diff --git a/util/gconv/gconv_slice_int.go b/util/gconv/gconv_slice_int.go index a6bb81567..64fe50517 100644 --- a/util/gconv/gconv_slice_int.go +++ b/util/gconv/gconv_slice_int.go @@ -6,6 +6,11 @@ package gconv +// apiInts is used for type assert api for Ints. +type apiInts interface { + Ints() []int +} + // SliceInt is alias of Ints. func SliceInt(i interface{}) []int { return Ints(i) @@ -26,95 +31,99 @@ func Ints(i interface{}) []int { if i == nil { return nil } - if r, ok := i.([]int); ok { - return r - } else { - var array []int - switch value := i.(type) { - case []string: - array = make([]int, len(value)) - for k, v := range value { - array[k] = Int(v) - } - case []int8: - array = make([]int, len(value)) - for k, v := range value { - array[k] = int(v) - } - case []int16: - array = make([]int, len(value)) - for k, v := range value { - array[k] = int(v) - } - case []int32: - array = make([]int, len(value)) - for k, v := range value { - array[k] = int(v) - } - case []int64: - array = make([]int, len(value)) - for k, v := range value { - array[k] = int(v) - } - case []uint: - array = make([]int, len(value)) - for k, v := range value { - array[k] = int(v) - } - case []uint8: - array = make([]int, len(value)) - for k, v := range value { - array[k] = int(v) - } - case []uint16: - array = make([]int, len(value)) - for k, v := range value { - array[k] = int(v) - } - case []uint32: - array = make([]int, len(value)) - for k, v := range value { - array[k] = int(v) - } - case []uint64: - array = make([]int, len(value)) - for k, v := range value { - array[k] = int(v) - } - case []bool: - array = make([]int, len(value)) - for k, v := range value { - if v { - array[k] = 1 - } else { - array[k] = 0 - } - } - case []float32: - array = make([]int, len(value)) - for k, v := range value { - array[k] = Int(v) - } - case []float64: - array = make([]int, len(value)) - for k, v := range value { - array[k] = Int(v) - } - case []interface{}: - array = make([]int, len(value)) - for k, v := range value { - array[k] = Int(v) - } - case [][]byte: - array = make([]int, len(value)) - for k, v := range value { - array[k] = Int(v) - } - default: - return []int{Int(i)} + var array []int + switch value := i.(type) { + case []string: + array = make([]int, len(value)) + for k, v := range value { + array[k] = Int(v) } - return array + case []int: + array = value + case []int8: + array = make([]int, len(value)) + for k, v := range value { + array[k] = int(v) + } + case []int16: + array = make([]int, len(value)) + for k, v := range value { + array[k] = int(v) + } + case []int32: + array = make([]int, len(value)) + for k, v := range value { + array[k] = int(v) + } + case []int64: + array = make([]int, len(value)) + for k, v := range value { + array[k] = int(v) + } + case []uint: + array = make([]int, len(value)) + for k, v := range value { + array[k] = int(v) + } + case []uint8: + array = make([]int, len(value)) + for k, v := range value { + array[k] = int(v) + } + case []uint16: + array = make([]int, len(value)) + for k, v := range value { + array[k] = int(v) + } + case []uint32: + array = make([]int, len(value)) + for k, v := range value { + array[k] = int(v) + } + case []uint64: + array = make([]int, len(value)) + for k, v := range value { + array[k] = int(v) + } + case []bool: + array = make([]int, len(value)) + for k, v := range value { + if v { + array[k] = 1 + } else { + array[k] = 0 + } + } + case []float32: + array = make([]int, len(value)) + for k, v := range value { + array[k] = Int(v) + } + case []float64: + array = make([]int, len(value)) + for k, v := range value { + array[k] = Int(v) + } + case []interface{}: + array = make([]int, len(value)) + for k, v := range value { + array[k] = Int(v) + } + case [][]byte: + array = make([]int, len(value)) + for k, v := range value { + array[k] = Int(v) + } + default: + if v, ok := i.(apiInts); ok { + return v.Ints() + } + if v, ok := i.(apiInterfaces); ok { + return Ints(v.Interfaces()) + } + return []int{Int(i)} } + return array } // Int32s converts to []int32. @@ -122,95 +131,99 @@ func Int32s(i interface{}) []int32 { if i == nil { return nil } - if r, ok := i.([]int32); ok { - return r - } else { - var array []int32 - switch value := i.(type) { - case []string: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = Int32(v) - } - case []int: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = int32(v) - } - case []int8: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = int32(v) - } - case []int16: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = int32(v) - } - case []int64: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = int32(v) - } - case []uint: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = int32(v) - } - case []uint8: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = int32(v) - } - case []uint16: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = int32(v) - } - case []uint32: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = int32(v) - } - case []uint64: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = int32(v) - } - case []bool: - array = make([]int32, len(value)) - for k, v := range value { - if v { - array[k] = 1 - } else { - array[k] = 0 - } - } - case []float32: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = Int32(v) - } - case []float64: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = Int32(v) - } - case []interface{}: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = Int32(v) - } - case [][]byte: - array = make([]int32, len(value)) - for k, v := range value { - array[k] = Int32(v) - } - default: - return []int32{Int32(i)} + var array []int32 + switch value := i.(type) { + case []string: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = Int32(v) } - return array + case []int: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = int32(v) + } + case []int8: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = int32(v) + } + case []int16: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = int32(v) + } + case []int32: + array = value + case []int64: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = int32(v) + } + case []uint: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = int32(v) + } + case []uint8: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = int32(v) + } + case []uint16: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = int32(v) + } + case []uint32: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = int32(v) + } + case []uint64: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = int32(v) + } + case []bool: + array = make([]int32, len(value)) + for k, v := range value { + if v { + array[k] = 1 + } else { + array[k] = 0 + } + } + case []float32: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = Int32(v) + } + case []float64: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = Int32(v) + } + case []interface{}: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = Int32(v) + } + case [][]byte: + array = make([]int32, len(value)) + for k, v := range value { + array[k] = Int32(v) + } + default: + if v, ok := i.(apiInts); ok { + return Int32s(v.Ints()) + } + if v, ok := i.(apiInterfaces); ok { + return Int32s(v.Interfaces()) + } + return []int32{Int32(i)} } + return array } // Int64s converts to []int64. @@ -218,93 +231,97 @@ func Int64s(i interface{}) []int64 { if i == nil { return nil } - if r, ok := i.([]int64); ok { - return r - } else { - var array []int64 - switch value := i.(type) { - case []string: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = Int64(v) - } - case []int: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = int64(v) - } - case []int8: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = int64(v) - } - case []int16: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = int64(v) - } - case []int32: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = int64(v) - } - case []uint: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = int64(v) - } - case []uint8: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = int64(v) - } - case []uint16: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = int64(v) - } - case []uint32: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = int64(v) - } - case []uint64: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = int64(v) - } - case []bool: - array = make([]int64, len(value)) - for k, v := range value { - if v { - array[k] = 1 - } else { - array[k] = 0 - } - } - case []float32: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = Int64(v) - } - case []float64: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = Int64(v) - } - case []interface{}: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = Int64(v) - } - case [][]byte: - array = make([]int64, len(value)) - for k, v := range value { - array[k] = Int64(v) - } - default: - return []int64{Int64(i)} + var array []int64 + switch value := i.(type) { + case []string: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = Int64(v) } - return array + case []int: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = int64(v) + } + case []int8: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = int64(v) + } + case []int16: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = int64(v) + } + case []int32: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = int64(v) + } + case []int64: + array = value + case []uint: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = int64(v) + } + case []uint8: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = int64(v) + } + case []uint16: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = int64(v) + } + case []uint32: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = int64(v) + } + case []uint64: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = int64(v) + } + case []bool: + array = make([]int64, len(value)) + for k, v := range value { + if v { + array[k] = 1 + } else { + array[k] = 0 + } + } + case []float32: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = Int64(v) + } + case []float64: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = Int64(v) + } + case []interface{}: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = Int64(v) + } + case [][]byte: + array = make([]int64, len(value)) + for k, v := range value { + array[k] = Int64(v) + } + default: + if v, ok := i.(apiInts); ok { + return Int64s(v.Ints()) + } + if v, ok := i.(apiInterfaces); ok { + return Int64s(v.Interfaces()) + } + return []int64{Int64(i)} } + return array } diff --git a/util/gconv/gconv_slice_str.go b/util/gconv/gconv_slice_str.go index 7f847f46b..b830fe61b 100644 --- a/util/gconv/gconv_slice_str.go +++ b/util/gconv/gconv_slice_str.go @@ -6,6 +6,11 @@ package gconv +// apiStrings is used for type assert api for Strings. +type apiStrings interface { + Strings() []string +} + // SliceStr is alias of Strings. func SliceStr(i interface{}) []string { return Strings(i) @@ -16,89 +21,93 @@ func Strings(i interface{}) []string { if i == nil { return nil } - if r, ok := i.([]string); ok { - return r - } else { - var array []string - switch value := i.(type) { - case []int: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []int8: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []int16: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []int32: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []int64: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []uint: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []uint8: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []uint16: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []uint32: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []uint64: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []bool: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []float32: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []float64: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case []interface{}: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - case [][]byte: - array = make([]string, len(value)) - for k, v := range value { - array[k] = String(v) - } - default: - return []string{String(i)} + var array []string + switch value := i.(type) { + case []int: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) } - return array + case []int8: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []int16: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []int32: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []int64: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []uint: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []uint8: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []uint16: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []uint32: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []uint64: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []bool: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []float32: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []float64: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []interface{}: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + case []string: + array = value + case [][]byte: + array = make([]string, len(value)) + for k, v := range value { + array[k] = String(v) + } + default: + if v, ok := i.(apiStrings); ok { + return v.Strings() + } + if v, ok := i.(apiInterfaces); ok { + return Strings(v.Interfaces()) + } + return []string{String(i)} } + return array } diff --git a/util/gconv/gconv_slice_uint.go b/util/gconv/gconv_slice_uint.go index 808b3fc23..f9095e618 100644 --- a/util/gconv/gconv_slice_uint.go +++ b/util/gconv/gconv_slice_uint.go @@ -6,6 +6,11 @@ package gconv +// apiUints is used for type assert api for Uints. +type apiUints interface { + Uints() []uint +} + // SliceUint is alias of Uints. func SliceUint(i interface{}) []uint { return Uints(i) @@ -26,90 +31,95 @@ func Uints(i interface{}) []uint { if i == nil { return nil } - if r, ok := i.([]uint); ok { - return r - } else { - var array []uint - switch value := i.(type) { - case []string: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = Uint(v) - } - case []int8: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = uint(v) - } - case []int16: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = uint(v) - } - case []int32: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = uint(v) - } - case []int64: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = uint(v) - } - case []uint8: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = uint(v) - } - case []uint16: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = uint(v) - } - case []uint32: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = uint(v) - } - case []uint64: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = uint(v) - } - case []bool: - array = make([]uint, len(value)) - for k, v := range value { - if v { - array[k] = 1 - } else { - array[k] = 0 - } - } - case []float32: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = Uint(v) - } - case []float64: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = Uint(v) - } - case []interface{}: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = Uint(v) - } - case [][]byte: - array = make([]uint, len(value)) - for k, v := range value { - array[k] = Uint(v) - } - default: - return []uint{Uint(i)} + + var array []uint + switch value := i.(type) { + case []string: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = Uint(v) } - return array + case []int8: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = uint(v) + } + case []int16: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = uint(v) + } + case []int32: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = uint(v) + } + case []int64: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = uint(v) + } + case []uint: + array = value + case []uint8: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = uint(v) + } + case []uint16: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = uint(v) + } + case []uint32: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = uint(v) + } + case []uint64: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = uint(v) + } + case []bool: + array = make([]uint, len(value)) + for k, v := range value { + if v { + array[k] = 1 + } else { + array[k] = 0 + } + } + case []float32: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = Uint(v) + } + case []float64: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = Uint(v) + } + case []interface{}: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = Uint(v) + } + case [][]byte: + array = make([]uint, len(value)) + for k, v := range value { + array[k] = Uint(v) + } + default: + if v, ok := i.(apiUints); ok { + return v.Uints() + } + if v, ok := i.(apiInterfaces); ok { + return Uints(v.Interfaces()) + } + return []uint{Uint(i)} } + return array } // Uint32s converts to []uint32. @@ -117,90 +127,94 @@ func Uint32s(i interface{}) []uint32 { if i == nil { return nil } - if r, ok := i.([]uint32); ok { - return r - } else { - var array []uint32 - switch value := i.(type) { - case []string: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = Uint32(v) - } - case []int8: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = uint32(v) - } - case []int16: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = uint32(v) - } - case []int32: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = uint32(v) - } - case []int64: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = uint32(v) - } - case []uint: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = uint32(v) - } - case []uint8: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = uint32(v) - } - case []uint16: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = uint32(v) - } - case []uint64: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = uint32(v) - } - case []bool: - array = make([]uint32, len(value)) - for k, v := range value { - if v { - array[k] = 1 - } else { - array[k] = 0 - } - } - case []float32: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = Uint32(v) - } - case []float64: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = Uint32(v) - } - case []interface{}: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = Uint32(v) - } - case [][]byte: - array = make([]uint32, len(value)) - for k, v := range value { - array[k] = Uint32(v) - } - default: - return []uint32{Uint32(i)} + var array []uint32 + switch value := i.(type) { + case []string: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = Uint32(v) } - return array + case []int8: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = uint32(v) + } + case []int16: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = uint32(v) + } + case []int32: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = uint32(v) + } + case []int64: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = uint32(v) + } + case []uint: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = uint32(v) + } + case []uint8: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = uint32(v) + } + case []uint16: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = uint32(v) + } + case []uint32: + array = value + case []uint64: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = uint32(v) + } + case []bool: + array = make([]uint32, len(value)) + for k, v := range value { + if v { + array[k] = 1 + } else { + array[k] = 0 + } + } + case []float32: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = Uint32(v) + } + case []float64: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = Uint32(v) + } + case []interface{}: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = Uint32(v) + } + case [][]byte: + array = make([]uint32, len(value)) + for k, v := range value { + array[k] = Uint32(v) + } + default: + if v, ok := i.(apiUints); ok { + return Uint32s(v.Uints()) + } + if v, ok := i.(apiInterfaces); ok { + return Uint32s(v.Interfaces()) + } + return []uint32{Uint32(i)} } + return array } // Uint64s converts to []uint64. @@ -208,88 +222,92 @@ func Uint64s(i interface{}) []uint64 { if i == nil { return nil } - if r, ok := i.([]uint64); ok { - return r - } else { - var array []uint64 - switch value := i.(type) { - case []string: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = Uint64(v) - } - case []int8: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = uint64(v) - } - case []int16: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = uint64(v) - } - case []int32: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = uint64(v) - } - case []int64: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = uint64(v) - } - case []uint: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = uint64(v) - } - case []uint8: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = uint64(v) - } - case []uint16: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = uint64(v) - } - case []uint32: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = uint64(v) - } - case []bool: - array = make([]uint64, len(value)) - for k, v := range value { - if v { - array[k] = 1 - } else { - array[k] = 0 - } - } - case []float32: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = Uint64(v) - } - case []float64: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = Uint64(v) - } - case []interface{}: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = Uint64(v) - } - case [][]byte: - array = make([]uint64, len(value)) - for k, v := range value { - array[k] = Uint64(v) - } - default: - return []uint64{Uint64(i)} + var array []uint64 + switch value := i.(type) { + case []string: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = Uint64(v) } - return array + case []int8: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = uint64(v) + } + case []int16: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = uint64(v) + } + case []int32: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = uint64(v) + } + case []int64: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = uint64(v) + } + case []uint: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = uint64(v) + } + case []uint8: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = uint64(v) + } + case []uint16: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = uint64(v) + } + case []uint32: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = uint64(v) + } + case []uint64: + array = value + case []bool: + array = make([]uint64, len(value)) + for k, v := range value { + if v { + array[k] = 1 + } else { + array[k] = 0 + } + } + case []float32: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = Uint64(v) + } + case []float64: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = Uint64(v) + } + case []interface{}: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = Uint64(v) + } + case [][]byte: + array = make([]uint64, len(value)) + for k, v := range value { + array[k] = Uint64(v) + } + default: + if v, ok := i.(apiUints); ok { + return Uint64s(v.Uints()) + } + if v, ok := i.(apiInterfaces); ok { + return Uint64s(v.Interfaces()) + } + return []uint64{Uint64(i)} } + return array } diff --git a/util/gconv/gconv_z_unit_map_test.go b/util/gconv/gconv_z_unit_map_test.go index 5f4176c0e..d9d7a9aff 100644 --- a/util/gconv/gconv_z_unit_map_test.go +++ b/util/gconv/gconv_z_unit_map_test.go @@ -7,6 +7,7 @@ package gconv_test import ( + "github.com/gogf/gf/util/gutil" "testing" "github.com/gogf/gf/frame/g" @@ -54,7 +55,7 @@ func Test_Map_Slice(t *testing.T) { }) } -func Test_Map_StructWithGconvTag(t *testing.T) { +func Test_Map_StructWithGConvTag(t *testing.T) { gtest.C(t, func(t *gtest.T) { type User struct { Uid int @@ -179,7 +180,7 @@ func Test_Map_PrivateAttribute(t *testing.T) { }) } -func Test_Map_StructInherit(t *testing.T) { +func Test_MapDeep1(t *testing.T) { type Ids struct { Id int `c:"id"` Uid int `c:"uid"` @@ -216,6 +217,84 @@ func Test_Map_StructInherit(t *testing.T) { }) } +func Test_MapDeep2(t *testing.T) { + type A struct { + F string + G string + } + + type B struct { + A + H string + } + + type C struct { + A A + F string + } + + type D struct { + I A + F string + } + + gtest.C(t, func(t *gtest.T) { + b := new(B) + c := new(C) + d := new(D) + mb := gconv.MapDeep(b) + mc := gconv.MapDeep(c) + md := gconv.MapDeep(d) + t.Assert(gutil.MapContains(mb, "F"), true) + t.Assert(gutil.MapContains(mb, "G"), true) + t.Assert(gutil.MapContains(mb, "H"), true) + t.Assert(gutil.MapContains(mc, "A"), true) + t.Assert(gutil.MapContains(mc, "F"), true) + t.Assert(gutil.MapContains(mc, "G"), false) + t.Assert(gutil.MapContains(md, "F"), true) + t.Assert(gutil.MapContains(md, "I"), true) + t.Assert(gutil.MapContains(md, "H"), false) + t.Assert(gutil.MapContains(md, "G"), false) + }) +} + +func Test_MapDeepWithAttributeTag(t *testing.T) { + type Ids struct { + Id int `c:"id"` + Uid int `c:"uid"` + } + type Base struct { + Ids `json:"ids"` + CreateTime string `c:"create_time"` + } + type User struct { + Base `json:"base"` + Passport string `c:"passport"` + Password string `c:"password"` + Nickname string `c:"nickname"` + } + gtest.C(t, func(t *gtest.T) { + user := new(User) + user.Id = 100 + user.Nickname = "john" + user.CreateTime = "2019" + m := gconv.Map(user) + t.Assert(m["id"], "") + t.Assert(m["nickname"], user.Nickname) + t.Assert(m["create_time"], "") + }) + gtest.C(t, func(t *gtest.T) { + user := new(User) + user.Id = 100 + user.Nickname = "john" + user.CreateTime = "2019" + m := gconv.MapDeep(user) + t.Assert(m["base"].(map[string]interface{})["ids"].(map[string]interface{})["id"], user.Id) + t.Assert(m["nickname"], user.Nickname) + t.Assert(m["base"].(map[string]interface{})["create_time"], user.CreateTime) + }) +} + func Test_MapToMap(t *testing.T) { type User struct { Id int @@ -454,3 +533,56 @@ func Test_MapToMapsDeep(t *testing.T) { t.Assert(m["200"][1].Name, "jim") }) } + +func Test_MapToMapsDeepWithTag(t *testing.T) { + type Ids struct { + Id int + Uid int + } + type Base struct { + Ids `json:"ids"` + Time string + } + type User struct { + Base `json:"base"` + Name string + } + params := g.MapIntAny{ + 100: g.Slice{ + g.Map{"id": 1, "name": "john"}, + g.Map{"id": 2, "name": "smith"}, + }, + 200: g.Slice{ + g.Map{"id": 3, "name": "green"}, + g.Map{"id": 4, "name": "jim"}, + }, + } + gtest.C(t, func(t *gtest.T) { + m := make(map[string][]*User) + err := gconv.MapToMaps(params, &m) + t.Assert(err, nil) + t.Assert(len(m), 2) + t.Assert(m["100"][0].Id, 0) + t.Assert(m["100"][1].Id, 0) + t.Assert(m["100"][0].Name, "john") + t.Assert(m["100"][1].Name, "smith") + t.Assert(m["200"][0].Id, 0) + t.Assert(m["200"][1].Id, 0) + t.Assert(m["200"][0].Name, "green") + t.Assert(m["200"][1].Name, "jim") + }) + gtest.C(t, func(t *gtest.T) { + m := make(map[string][]*User) + err := gconv.MapToMapsDeep(params, &m) + t.Assert(err, nil) + t.Assert(len(m), 2) + t.Assert(m["100"][0].Id, 1) + t.Assert(m["100"][1].Id, 2) + t.Assert(m["100"][0].Name, "john") + t.Assert(m["100"][1].Name, "smith") + t.Assert(m["200"][0].Id, 3) + t.Assert(m["200"][1].Id, 4) + t.Assert(m["200"][0].Name, "green") + t.Assert(m["200"][1].Name, "jim") + }) +} diff --git a/util/gutil/gutil.go b/util/gutil/gutil.go index 32487fd6a..f3942747e 100644 --- a/util/gutil/gutil.go +++ b/util/gutil/gutil.go @@ -16,21 +16,13 @@ func Throw(exception interface{}) { panic(exception) } -// TryCatch implements try...catch... logistics. +// TryCatch implements try...catch... logistics using internal panic...recover. func TryCatch(try func(), catch ...func(exception interface{})) { - if len(catch) > 0 { - // If is given, it's used to handle the exception. - defer func() { - if e := recover(); e != nil { - catch[0](e) - } - }() - } else { - // If no function passed, it filters the exception. - defer func() { - recover() - }() - } + defer func() { + if e := recover(); e != nil && len(catch) > 0 { + catch[0](e) + } + }() try() } diff --git a/util/gutil/gutil_map.go b/util/gutil/gutil_map.go index 960b7282b..bc014dc1b 100644 --- a/util/gutil/gutil_map.go +++ b/util/gutil/gutil_map.go @@ -7,18 +7,12 @@ package gutil import ( - "regexp" - "strings" + "github.com/gogf/gf/internal/utils" ) -var ( - // replaceCharReg is the regular expression object for replacing chars in map keys. - replaceCharReg, _ = regexp.Compile(`[\-\.\_\s]+`) -) - -// CopyMap does a shallow copy from map to for most commonly used map type +// MapCopy does a shallow copy from map to for most commonly used map type // map[string]interface{}. -func CopyMap(data map[string]interface{}) (copy map[string]interface{}) { +func MapCopy(data map[string]interface{}) (copy map[string]interface{}) { copy = make(map[string]interface{}, len(data)) for k, v := range data { copy[k] = v @@ -26,24 +20,67 @@ func CopyMap(data map[string]interface{}) (copy map[string]interface{}) { return } +// MapContains checks whether map contains . +func MapContains(data map[string]interface{}, key string) (ok bool) { + _, ok = data[key] + return +} + +// MapDelete deletes all from map . +func MapDelete(data map[string]interface{}, key ...string) { + if data == nil { + return + } + for _, v := range key { + delete(data, v) + } +} + +// MapMerge merges all map from to map . +func MapMerge(dst map[string]interface{}, src ...map[string]interface{}) { + if dst == nil { + return + } + for _, m := range src { + for k, v := range m { + dst[k] = v + } + } +} + +// MapMergeCopy creates and returns a new map which merges all map from . +func MapMergeCopy(src ...map[string]interface{}) (copy map[string]interface{}) { + copy = make(map[string]interface{}) + for _, m := range src { + for k, v := range m { + copy[k] = v + } + } + return +} + // MapPossibleItemByKey tries to find the possible key-value pair for given key with or without // cases or chars '-'/'_'/'.'/' '. // // Note that this function might be of low performance. -func MapPossibleItemByKey(data map[string]interface{}, key string) (string, interface{}) { +func MapPossibleItemByKey(data map[string]interface{}, key string) (foundKey string, foundValue interface{}) { if v, ok := data[key]; ok { return key, v } - replacedKey := replaceCharReg.ReplaceAllString(key, "") - if v, ok := data[replacedKey]; ok { - return replacedKey, v - } - // Loop for check. + // Loop checking. for k, v := range data { - // Remove all special chars and compare with case insensitive. - if strings.EqualFold(replaceCharReg.ReplaceAllString(k, ""), replacedKey) { + if utils.EqualFoldWithoutChars(k, key) { return k, v } } return "", nil } + +// MapContainsPossibleKey checks if the given is contained in given map . +// It checks the key with or without cases or chars '-'/'_'/'.'/' '. +func MapContainsPossibleKey(data map[string]interface{}, key string) bool { + if k, _ := MapPossibleItemByKey(data, key); k != "" { + return true + } + return false +} diff --git a/util/gutil/gutil_z_bench_test.go b/util/gutil/gutil_z_bench_test.go index c68f38421..2276c021f 100644 --- a/util/gutil/gutil_z_bench_test.go +++ b/util/gutil/gutil_z_bench_test.go @@ -12,6 +12,15 @@ import ( "testing" ) +var ( + m1 = map[string]interface{}{ + "k1": "v1", + } + m2 = map[string]interface{}{ + "k2": "v2", + } +) + func Benchmark_TryCatch(b *testing.B) { for i := 0; i < b.N; i++ { TryCatch(func() { @@ -21,3 +30,9 @@ func Benchmark_TryCatch(b *testing.B) { }) } } + +func Benchmark_MapMergeCopy(b *testing.B) { + for i := 0; i < b.N; i++ { + MapMergeCopy(m1, m2) + } +} diff --git a/util/gutil/gutil_z_unit_map_test.go b/util/gutil/gutil_z_unit_map_test.go new file mode 100755 index 000000000..e71b032ce --- /dev/null +++ b/util/gutil/gutil_z_unit_map_test.go @@ -0,0 +1,120 @@ +// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gutil_test + +import ( + "github.com/gogf/gf/frame/g" + "testing" + + "github.com/gogf/gf/test/gtest" + "github.com/gogf/gf/util/gutil" +) + +func Test_MapCopy(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + m1 := g.Map{ + "k1": "v1", + } + m2 := gutil.MapCopy(m1) + m2["k2"] = "v2" + + t.Assert(m1["k1"], "v1") + t.Assert(m1["k2"], nil) + t.Assert(m2["k1"], "v1") + t.Assert(m2["k2"], "v2") + }) +} + +func Test_MapContains(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + m1 := g.Map{ + "k1": "v1", + } + t.Assert(gutil.MapContains(m1, "k1"), true) + t.Assert(gutil.MapContains(m1, "K1"), false) + t.Assert(gutil.MapContains(m1, "k2"), false) + }) +} + +func Test_MapMerge(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + m1 := g.Map{ + "k1": "v1", + } + m2 := g.Map{ + "k2": "v2", + } + m3 := g.Map{ + "k3": "v3", + } + gutil.MapMerge(m1, m2, m3, nil) + t.Assert(m1["k1"], "v1") + t.Assert(m1["k2"], "v2") + t.Assert(m1["k3"], "v3") + t.Assert(m2["k1"], nil) + t.Assert(m3["k1"], nil) + }) +} + +func Test_MapMergeCopy(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + m1 := g.Map{ + "k1": "v1", + } + m2 := g.Map{ + "k2": "v2", + } + m3 := g.Map{ + "k3": "v3", + } + m := gutil.MapMergeCopy(m1, m2, m3, nil) + t.Assert(m["k1"], "v1") + t.Assert(m["k2"], "v2") + t.Assert(m["k3"], "v3") + t.Assert(m1["k1"], "v1") + t.Assert(m1["k2"], nil) + t.Assert(m2["k1"], nil) + t.Assert(m3["k1"], nil) + }) +} + +func Test_MapPossibleItemByKey(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + m := g.Map{ + "name": "guo", + "NickName": "john", + } + k, v := gutil.MapPossibleItemByKey(m, "NAME") + t.Assert(k, "name") + t.Assert(v, "guo") + + k, v = gutil.MapPossibleItemByKey(m, "nick name") + t.Assert(k, "NickName") + t.Assert(v, "john") + + k, v = gutil.MapPossibleItemByKey(m, "none") + t.Assert(k, "") + t.Assert(v, nil) + }) +} + +func Test_MapContainsPossibleKey(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + m := g.Map{ + "name": "guo", + "NickName": "john", + } + t.Assert(gutil.MapContainsPossibleKey(m, "name"), true) + t.Assert(gutil.MapContainsPossibleKey(m, "NAME"), true) + t.Assert(gutil.MapContainsPossibleKey(m, "nickname"), true) + t.Assert(gutil.MapContainsPossibleKey(m, "nick name"), true) + t.Assert(gutil.MapContainsPossibleKey(m, "nick_name"), true) + t.Assert(gutil.MapContainsPossibleKey(m, "nick-name"), true) + t.Assert(gutil.MapContainsPossibleKey(m, "nick.name"), true) + t.Assert(gutil.MapContainsPossibleKey(m, "none"), false) + }) +} diff --git a/util/gutil/gutil_z_unit_test.go b/util/gutil/gutil_z_unit_test.go index c9b4a0ec5..dbd6d0469 100755 --- a/util/gutil/gutil_z_unit_test.go +++ b/util/gutil/gutil_z_unit_test.go @@ -19,18 +19,9 @@ func Test_Dump(t *testing.T) { 100: 100, }) }) - - gtest.C(t, func(t *gtest.T) { - gutil.Dump(map[string]interface{}{"": func() {}}) - }) - - gtest.C(t, func(t *gtest.T) { - gutil.Dump([]byte("gutil Dump test")) - }) } func Test_TryCatch(t *testing.T) { - gtest.C(t, func(t *gtest.T) { gutil.TryCatch(func() { panic("gutil TryCatch test") diff --git a/util/gvalid/gvalid_check_struct.go b/util/gvalid/gvalid_check_struct.go index 6bc7e7b5a..a1ddca186 100644 --- a/util/gvalid/gvalid_check_struct.go +++ b/util/gvalid/gvalid_check_struct.go @@ -20,16 +20,15 @@ var ( // 校验struct对象属性,object参数也可以是一个指向对象的指针,返回值同CheckMap方法。 // struct的数据校验结果信息是顺序的。 -func CheckStruct(object interface{}, rules interface{}, msgs ...CustomMsg) *Error { - params := make(map[string]interface{}) - checkRules := make(map[string]string) - customMsgs := make(CustomMsg) - // 字段别名记录,用于msgs覆盖struct tag的 - fieldAliases := make(map[string]string) - // 返回的顺序规则 - errorRules := make([]string, 0) - // 返回的校验错误 - errorMaps := make(ErrorMap) +func CheckStruct(object interface{}, rules interface{}, messages ...CustomMsg) *Error { + var ( + params = make(map[string]interface{}) + checkRules = make(map[string]string) + customMessage = make(CustomMsg) + fieldAliases = make(map[string]string) // Alias names for overwriting struct tag names. + errorRules = make([]string, 0) // Sequence rules. + errorMaps = make(ErrorMap) // Returned error + ) // 解析rules参数 switch v := rules.(type) { // 支持校验错误顺序: []sequence tag @@ -52,10 +51,10 @@ func CheckStruct(object interface{}, rules interface{}, msgs ...CustomMsg) *Erro continue } array := strings.Split(v, ":") - if _, ok := customMsgs[name]; !ok { - customMsgs[name] = make(map[string]string) + if _, ok := customMessage[name]; !ok { + customMessage[name] = make(map[string]string) } - customMsgs[name].(map[string]string)[strings.TrimSpace(array[0])] = strings.TrimSpace(msgArray[k]) + customMessage[name].(map[string]string)[strings.TrimSpace(array[0])] = strings.TrimSpace(msgArray[k]) } } checkRules[name] = rule @@ -114,22 +113,22 @@ func CheckStruct(object interface{}, rules interface{}, msgs ...CustomMsg) *Erro continue } array := strings.Split(v, ":") - if _, ok := customMsgs[name]; !ok { - customMsgs[name] = make(map[string]string) + if _, ok := customMessage[name]; !ok { + customMessage[name] = make(map[string]string) } - customMsgs[name].(map[string]string)[strings.TrimSpace(array[0])] = strings.TrimSpace(msgArray[k]) + customMessage[name].(map[string]string)[strings.TrimSpace(array[0])] = strings.TrimSpace(msgArray[k]) } } } // 自定义错误消息,非必须参数,优先级比rules参数中以及struct tag中定义的错误消息更高 - if len(msgs) > 0 && len(msgs[0]) > 0 { - for k, v := range msgs[0] { + if len(messages) > 0 && len(messages[0]) > 0 { + for k, v := range messages[0] { if a, ok := fieldAliases[k]; ok { // 属性的别名存在时,覆盖别名的错误信息 - customMsgs[a] = v + customMessage[a] = v } else { - customMsgs[k] = v + customMessage[k] = v } } } @@ -144,7 +143,7 @@ func CheckStruct(object interface{}, rules interface{}, msgs ...CustomMsg) *Erro if v, ok := params[key]; ok { value = v } - if e := Check(value, rule, customMsgs[key], params); e != nil { + if e := Check(value, rule, customMessage[key], params); e != nil { _, item := e.FirstItem() // 如果值为nil|"",并且不需要require*验证时,其他验证失效 if value == nil || gconv.String(value) == "" { diff --git a/version.go b/version.go index eea388e19..e7c69e151 100644 --- a/version.go +++ b/version.go @@ -1,4 +1,4 @@ package gf -const VERSION = "v1.12.1" +const VERSION = "v1.12.2" const AUTHORS = "john"