mirror of
https://gitee.com/johng/gf
synced 2026-06-08 10:37:44 +08:00
Compare commits
72 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d604d198ab | |||
| 36791d2f48 | |||
| 08f9cffed9 | |||
| 783c0ba846 | |||
| 7ad4f61564 | |||
| adf06a2b0d | |||
| d6aa2b2512 | |||
| 0a8af94610 | |||
| 2c27c0f58a | |||
| 4172eae87e | |||
| 26f2c61068 | |||
| f97bed2607 | |||
| 8ef7155c70 | |||
| 2c6e8f88fb | |||
| 25068b1e83 | |||
| 1f36eb3a9a | |||
| a9ed577d05 | |||
| 782d614082 | |||
| 0629c00b07 | |||
| b90d5bb205 | |||
| cbc824c80a | |||
| 0c9be40b86 | |||
| c96abd706d | |||
| 0ae5872783 | |||
| 2cff10e0d2 | |||
| cab78f557d | |||
| 04353aa1a5 | |||
| 35121a66e9 | |||
| e726ed2c19 | |||
| 503446afc7 | |||
| 2063f662d3 | |||
| d7381399aa | |||
| d05b497cdb | |||
| ef919be587 | |||
| fff31e0f4f | |||
| cdd6fc7c1e | |||
| 74bc36a2dc | |||
| 48328ae52c | |||
| a86f4f8e23 | |||
| 0a1e048268 | |||
| 6fc5efd6ba | |||
| 2d795b593d | |||
| 20628ec75c | |||
| 10d1ccb009 | |||
| fcc37c9581 | |||
| 43cd391543 | |||
| 18d2df33f7 | |||
| a85daa5617 | |||
| 48dc4ce3e2 | |||
| d07bac89a0 | |||
| 5d32ad6bc4 | |||
| 397b0a3e7e | |||
| 259961632d | |||
| cb1d6382ec | |||
| 8714a69a13 | |||
| 3ae0ea2de7 | |||
| 1879a9f4c7 | |||
| 3938717b04 | |||
| 1208b688f1 | |||
| 0ad7ee5a32 | |||
| 7a4e68e6b9 | |||
| 71222b247f | |||
| 95db811943 | |||
| 2dbc817132 | |||
| 7a8bd96edc | |||
| c5e9686a95 | |||
| c914edf616 | |||
| 656bfcb6bd | |||
| 7434dfe6fa | |||
| e67aa63a50 | |||
| d5e46f2b42 | |||
| 09e6f10b60 |
33
.travis.yml
Normal file
33
.travis.yml
Normal file
@ -0,0 +1,33 @@
|
||||
language: go
|
||||
|
||||
go:
|
||||
- "1.11.x"
|
||||
|
||||
branches:
|
||||
only:
|
||||
- master
|
||||
- develop
|
||||
|
||||
env:
|
||||
- GITEE_GF=$GOPATH/src/gitee.com/johng/gf GO111MODULE=on
|
||||
|
||||
services:
|
||||
- mysql
|
||||
|
||||
before_install:
|
||||
- pwd
|
||||
|
||||
install:
|
||||
- pwd
|
||||
- mkdir -p $GITEE_GF
|
||||
- cp * $GITEE_GF -R
|
||||
- cd $GITEE_GF
|
||||
|
||||
script:
|
||||
- cd g && go test -v ./... -race -coverprofile=coverage.txt -covermode=atomic
|
||||
|
||||
after_success:
|
||||
- bash <(curl -s https://codecov.io/bash)
|
||||
|
||||
|
||||
|
||||
66
README.MD
66
README.MD
@ -1,6 +1,18 @@
|
||||
<div align=center>
|
||||
<img src="https://gfer.me/cover.png" width="150"/>
|
||||
</div>
|
||||
# GoFrame
|
||||
<img align="right" height="150px" src="https://gfer.me/cover.png">
|
||||
|
||||
[](https://godoc.org/github.com/johng-cn/gf)
|
||||
[](https://travis-ci.org/johng-cn/gf)
|
||||
[](https://goreportcard.com/report/github.com/johng-cn/gf)
|
||||
[](https://gfer.me)
|
||||
[](https://github.com/johng-cn/gf)
|
||||
[](https://github.com/johng-cn/gf)
|
||||
[](https://github.com/johng-cn/gf/releases)
|
||||
|
||||
<!--
|
||||
[](https://codecov.io/gh/johng-cn/gf)
|
||||
[](https://www.codetriage.com/johng-cn/gf)
|
||||
-->
|
||||
|
||||
`GF(GoFrame)` is a modular, lightweight, loosely coupled, high performance application development framework written in Go. Supporting graceful server, hot updates, multi-domain, multi-port, multi-service, HTTP/HTTPS, dynamic/hook routing and many more features. Providing a series of core components and dozens of practical modules.
|
||||
|
||||
@ -8,6 +20,11 @@
|
||||
```
|
||||
go get -u gitee.com/johng/gf
|
||||
```
|
||||
or use `go.mod`
|
||||
```
|
||||
require gitee.com/johng/gf latest
|
||||
```
|
||||
|
||||
# Limitation
|
||||
```
|
||||
golang version >= 1.9.2
|
||||
@ -41,6 +58,47 @@ func main() {
|
||||
}
|
||||
```
|
||||
|
||||
[View More..](https://gfer.me/start/index)
|
||||
|
||||
|
||||
# License
|
||||
|
||||
GF is licensed under the [MIT License](LICENSE), 100% free and open-source.
|
||||
`GF` is licensed under the [MIT License](LICENSE), 100% free and open-source, forever.
|
||||
|
||||
# Contributors(TOP 10)
|
||||
|
||||
<a href="https://gitee.com/johng" target="_blank" title="John"><img src="https://gitee.com/uploads/27/1309327_johng.png?1530630243" width="60" align="left"></a>
|
||||
|
||||
<a href="https://gitee.com/wenzi1" target="_blank" title="蚊子"><img src="https://images.gitee.com/uploads/22/1923122_wenzi1.png" width="60" align="left"></a>
|
||||
|
||||
<a href="https://gitee.com/zseeker" target="_blank" title="zseeker"><img src="https://gfer.me/images/contributors/zseeker.png" width="60" align="left"></a>
|
||||
|
||||
<a href="https://gitee.com/ymrjqyy" target="_blank" title="一墨染尽青衣颜"><img src="https://images.gitee.com/uploads/27/876827_ymrjqyy.png" width="60" align="left"></a>
|
||||
|
||||
<a href="https://github.com/chenyang351" target="_blank" title="chenyang351"><img src="https://avatars1.githubusercontent.com/u/30063958?s=60&v=4" width="60" align="left"></a>
|
||||
|
||||
<a href="https://gitee.com/wxkj" target="_blank" title="wxkj"><img src="https://gitee.com/uploads/56/91356_wxkj.png" width="60" align="left"></a>
|
||||
|
||||
<a href="https://github.com/wxkj001" target="_blank" title="3wxkj001
|
||||
"><img src="https://avatars0.githubusercontent.com/u/7794279?s=60&v=4" width="60" align="left"></a>
|
||||
|
||||
<a href="https://gitee.com/zhangjinfu" target="_blank" title="张金富"><img src="https://images.gitee.com/uploads/63/356163_zhangjinfu.png" width="60" align="left"></a>
|
||||
|
||||
<a href="https://gitee.com/garfieldkwong" target="_blank" title="GarfieldKwong"><img src="https://gfer.me/images/contributors/garfieldkwong.png" width="60" align="left"></a>
|
||||
|
||||
<a href="https://gitee.com/qq1054000800" target="_blank" title="hello"><img src="https://gitee.com/uploads/9/2209_qq1054000800.jpg" width="60" align="left"></a>
|
||||
|
||||
<br /><br /><br />
|
||||
|
||||
# Donators
|
||||
|
||||
<a href="https://gitee.com/zfan_codes" target="_blank" title="范钟"><img src="https://images.gitee.com/uploads/32/2044832_zfan_codes.png" width="60" align="left"></a>
|
||||
|
||||
<a href="https://gitee.com/hailaz" target="_blank" title="HaiLaz"><img src="https://gitee.com/uploads/87/1273187_hailaz.png" width="60" align="left"></a>
|
||||
|
||||
<a href="https://gitee.com/mg91" target="_blank" title="mg91"><img src="https://images.gitee.com/uploads/30/1410930_mg91.png" width="60" align="left"></a>
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
67
README_ZH.MD
67
README_ZH.MD
@ -1,18 +1,43 @@
|
||||
<div align=center>
|
||||
<img src="https://gfer.me/cover.png" width="150"/>
|
||||
</div>
|
||||
# GoFrame
|
||||
<img align="right" height="150px" src="https://gfer.me/cover.png">
|
||||
|
||||
[](https://godoc.org/github.com/johng-cn/gf)
|
||||
[](https://travis-ci.org/johng-cn/gf)
|
||||
[](https://goreportcard.com/report/github.com/johng-cn/gf)
|
||||
[](https://gfer.me)
|
||||
[](https://github.com/johng-cn/gf)
|
||||
[](https://github.com/johng-cn/gf)
|
||||
[](https://github.com/johng-cn/gf/releases)
|
||||
|
||||
<!--
|
||||
[](https://codecov.io/gh/johng-cn/gf)
|
||||
[](https://www.codetriage.com/johng-cn/gf)
|
||||
-->
|
||||
|
||||
`GF(Go Frame)`是一款模块化、松耦合、轻量级、高性能的Go应用开发框架。支持热重启、热更新、多域名、多端口、多服务、HTTP/HTTPS、动态路由等特性
|
||||
,并提供了Web服务开发的系列核心组件,如:Router、Cookie、Session、服务注册、配置管理、模板引擎、数据校验、分页管理、数据库ORM等等等等,
|
||||
并且提供了数十个内置核心开发模块集,如:缓存、日志、时间、命令行、二进制、文件锁、内存锁、对象池、连接池、数据编码、进程管理、进程通信、文件监控、定时任务、TCP/UDP组件、
|
||||
并发安全容器等等等等等等。
|
||||
|
||||
|
||||
# 特点
|
||||
* 模块化、松耦合设计;
|
||||
* 丰富实用的开发模块;
|
||||
* 详尽的开发文档及示例;
|
||||
* 完善的本地中文化支持;
|
||||
* 致力于项目的通用方案;
|
||||
* 更适合企业及团队使用;
|
||||
* 更多请查阅文档及源码;
|
||||
|
||||
# 安装
|
||||
```html
|
||||
go get -u gitee.com/johng/gf
|
||||
```
|
||||
|
||||
或者
|
||||
`go.mod`
|
||||
```
|
||||
require gitee.com/johng/gf latest
|
||||
```
|
||||
# 限制
|
||||
```shell
|
||||
golang版本 >= 1.9.2
|
||||
@ -24,6 +49,7 @@ golang版本 >= 1.9.2
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
# 文档
|
||||
[https://gfer.me](https://gfer.me)
|
||||
|
||||
@ -44,4 +70,35 @@ func main() {
|
||||
})
|
||||
s.Run()
|
||||
}
|
||||
```
|
||||
```
|
||||
|
||||
[更多..](https://gfer.me/start/index)
|
||||
|
||||
|
||||
# 协议
|
||||
|
||||
`GF` 使用非常友好的 [MIT](LICENSE) 开源协议进行发布,永久`100%`开源免费。
|
||||
|
||||
# 贡献者(TOP 10)
|
||||
|
||||
<a href="https://gitee.com/johng" target="_blank" title="John"><img src="https://gitee.com/uploads/27/1309327_johng.png" width="60" align="left"></a>
|
||||
<a href="https://gitee.com/wenzi1" target="_blank" title="蚊子"><img src="https://images.gitee.com/uploads/22/1923122_wenzi1.png" width="60" align="left"></a>
|
||||
<a href="https://gitee.com/zseeker" target="_blank" title="zseeker"><img src="https://gfer.me/images/contributors/zseeker.png" width="60" align="left"></a>
|
||||
<a href="https://gitee.com/ymrjqyy" target="_blank" title="一墨染尽青衣颜"><img src="https://images.gitee.com/uploads/27/876827_ymrjqyy.png" width="60" align="left"></a>
|
||||
<a href="https://github.com/chenyang351" target="_blank" title="chenyang351"><img src="https://avatars1.githubusercontent.com/u/30063958?s=60&v=4" width="60" align="left"></a>
|
||||
<a href="https://gitee.com/wxkj" target="_blank" title="wxkj"><img src="https://gitee.com/uploads/56/91356_wxkj.png" width="60" align="left"></a>
|
||||
<a href="https://github.com/wxkj001" target="_blank" title="3wxkj001
|
||||
"><img src="https://avatars0.githubusercontent.com/u/7794279?s=60&v=4" width="60" align="left"></a>
|
||||
<a href="https://gitee.com/zhangjinfu" target="_blank" title="张金富"><img src="https://images.gitee.com/uploads/63/356163_zhangjinfu.png" width="60" align="left"></a>
|
||||
<a href="https://gitee.com/garfieldkwong" target="_blank" title="GarfieldKwong"><img src="https://gfer.me/images/contributors/garfieldkwong.png" width="60" align="left"></a>
|
||||
<a href="https://gitee.com/qq1054000800" target="_blank" title="hello"><img src="https://gitee.com/uploads/9/2209_qq1054000800.jpg" width="60" align="left"></a>
|
||||
|
||||
<br /><br /><br />
|
||||
|
||||
# 捐赠者
|
||||
|
||||
<a href="https://gitee.com/zfan_codes" target="_blank" title="范钟"><img src="https://images.gitee.com/uploads/32/2044832_zfan_codes.png" width="60" align="left"></a>
|
||||
|
||||
<a href="https://gitee.com/hailaz" target="_blank" title="HaiLaz"><img src="https://gitee.com/uploads/87/1273187_hailaz.png" width="60" align="left"></a>
|
||||
|
||||
<a href="https://gitee.com/mg91" target="_blank" title="mg91"><img src="https://images.gitee.com/uploads/30/1410930_mg91.png" width="60" align="left"></a>
|
||||
37
RELEASE.MD
37
RELEASE.MD
@ -1,3 +1,40 @@
|
||||
# `v1.3.8` (2018-12-26)
|
||||
|
||||
## 新特性
|
||||
1. 对`gform`完成重构,以提高扩展性,并修复部分细节问题、完善单元测试用例([https://gfer.me/database/orm/index](https://gfer.me/database/orm/index));
|
||||
1. `WebServer`路由注册新增分组路由特性([https://gfer.me/net/ghttp/group](https://gfer.me/net/ghttp/group));
|
||||
1. `WebServer`新增`Rewrite`路由重写特性([https://gfer.me/net/ghttp/static](https://gfer.me/net/ghttp/static));
|
||||
1. 增加框架运行时对开发环境的自动识别;
|
||||
1. 增加了`Travis CI`自动化构建/测试;
|
||||
|
||||
## 新功能
|
||||
1. 改进`WebServer`静态文件服务功能,增加`SetStaticPath`/`AddStaticPath`方法([https://gfer.me/net/ghttp/static](https://gfer.me/net/ghttp/static));
|
||||
1. `gform`新增`Filter`链式操作方法,用于过滤参数中的非表字段键值对([https://gfer.me/database/orm/linkop](https://gfer.me/database/orm/linkop));
|
||||
1. `gcache`新增`Data`方法,用以获取所有的缓存数据项;
|
||||
1. `gredis`增加`GetConn`方法获取原生redis连接对象;
|
||||
|
||||
## 功能改进
|
||||
1. 改进`gform`的`Where`方法,支持`slice`类型的参数,并更方便地支持`in`操作查询([https://gfer.me/database/orm/linkop](https://gfer.me/database/orm/linkop));
|
||||
1. 改进`gproc`进程间通信数据结构,将`pid`字段从`16bit`扩展为`24bit`;
|
||||
1. 改进`gconv`/`gmap`/`garray`,增加若干操作方法;
|
||||
1. 改进`gview`模板引擎中的`date`内置函数,当给定的时间戳为空时打印当前的系统时间;
|
||||
1. 改进`gview`模板引擎中,当打印的变量不存在时,显示为空(标准库默认显示为`<no value>`);
|
||||
1. 改进`WebServer`,去掉`HANGUP`的信号监听,避免程序通过`nohup`运行时产生异常退出问题;
|
||||
1. 改进`gcache`性能,并完善基准测试;
|
||||
|
||||
## Bug Fix
|
||||
1. 修复`gcache`在非LRU特性开启时的缓存关闭资源竞争问题,并修复`doSetWithLockCheck`内部方法的返回值问题;
|
||||
1. 修复`grand.intn`内部方法在`x86`架构下的随机数位溢出问题;
|
||||
1. 修复`gbinary`中`Int`方法针对`[]byte`参数长度自动匹配造成的字节长度溢出问题;
|
||||
1. 修复`gjson`由于官方标准库`json`不支持`map[interface{}]*`类型造成的Go变量编码问题;
|
||||
1. 修复`garray`中部分方法的数据竞争问题,修复二分查找排序问题;
|
||||
1. 修复`ghttp.Request.GetVar`方法获取参数问题;
|
||||
1. 修复`gform`的数据库连接池不起作用的问题;
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# `v1.2.11` (2018-11-26)
|
||||
## 新特性
|
||||
1. `ORM`新增对`SQLServer`及`Oracle`的支持([https://gfer.me/database/orm/database](https://gfer.me/database/orm/database));
|
||||
|
||||
8
TODO.MD
8
TODO.MD
@ -42,6 +42,14 @@
|
||||
1. gtcp提供简便的包发送/接收方法(SendPkg/RecvPkg)以解决常见的TCP通信粘包问题,并完善文档(参考:https://www.cnblogs.com/kex1n/p/6502002.html);
|
||||
1. gfile对于文件的读写强行使用了gfpool,在某些场景下不合适,需要考虑剥离开,并为开发者提供单独的指针池文件操作特性;
|
||||
1. 路由增加不区分大小写得匹配方式;
|
||||
1. str_ireplace: http://php.net/manual/en/function.str-ireplace.php
|
||||
1. strpos/stripos/strrpos/strripos: http://php.net/manual/en/function.stripos.php
|
||||
1. 改进WebServer获取POST参数处理逻辑,当提交非form数据时,例如json数据,针对某些方法可以直接解析;
|
||||
1. WebServer增加可选择的路由覆盖配置,默认情况下不覆盖;
|
||||
1. gkafka这个包比较重,未来从框架中剥离出来;
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@ -153,16 +153,14 @@ func (a *SortedIntArray) binSearch(value int, lock bool) (index int, result int)
|
||||
max := len(a.array) - 1
|
||||
mid := 0
|
||||
cmp := -2
|
||||
for {
|
||||
for min <= max {
|
||||
mid = int((min + max) / 2)
|
||||
cmp = a.compareFunc(value, a.array[mid])
|
||||
switch cmp {
|
||||
case -1 : max = mid - 1
|
||||
case 0 :
|
||||
case 1 : min = mid + 1
|
||||
}
|
||||
if cmp == 0 || min >= max {
|
||||
break
|
||||
case 0 :
|
||||
return mid, cmp
|
||||
}
|
||||
}
|
||||
return mid, cmp
|
||||
|
||||
@ -146,16 +146,14 @@ func (a *SortedArray) binSearch(value interface{}, lock bool)(index int, result
|
||||
max := len(a.array) - 1
|
||||
mid := 0
|
||||
cmp := -2
|
||||
for {
|
||||
for min <= max {
|
||||
mid = int((min + max) / 2)
|
||||
cmp = a.compareFunc(value, a.array[mid])
|
||||
switch cmp {
|
||||
case -1 : max = mid - 1
|
||||
case 0 :
|
||||
case 1 : min = mid + 1
|
||||
}
|
||||
if cmp == 0 || min >= max {
|
||||
break
|
||||
case 0 :
|
||||
return mid, cmp
|
||||
}
|
||||
}
|
||||
return mid, cmp
|
||||
|
||||
@ -147,16 +147,14 @@ func (a *SortedStringArray) binSearch(value string, lock bool) (index int, resul
|
||||
max := len(a.array) - 1
|
||||
mid := 0
|
||||
cmp := -2
|
||||
for {
|
||||
for min <= max {
|
||||
mid = int((min + max) / 2)
|
||||
cmp = a.compareFunc(value, a.array[mid])
|
||||
switch cmp {
|
||||
case -1 : max = mid - 1
|
||||
case 0 :
|
||||
case 1 : min = mid + 1
|
||||
}
|
||||
if cmp == 0 || min >= max {
|
||||
break
|
||||
case 0 :
|
||||
return mid, cmp
|
||||
}
|
||||
}
|
||||
return mid, cmp
|
||||
|
||||
@ -9,18 +9,76 @@
|
||||
package garray_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/container/garray"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"gitee.com/johng/gf/g/util/gtest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
||||
func TestArray_Unique(t *testing.T) {
|
||||
func Test_IntArray_Unique(t *testing.T) {
|
||||
expect := []int{1, 2, 3, 4, 5, 6}
|
||||
array := garray.NewIntArray(0, 0)
|
||||
array.Append(1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6)
|
||||
array.Unique()
|
||||
if fmt.Sprint(array.Slice()) != fmt.Sprint(expect) {
|
||||
t.Errorf("get: %v, expect: %v\n", array.Slice(), expect)
|
||||
}
|
||||
gtest.Assert(array.Slice(), expect)
|
||||
}
|
||||
|
||||
func Test_SortedIntArray1(t *testing.T) {
|
||||
expect := []int{0,1,2,3,4,5,6,7,8,9,10}
|
||||
array := garray.NewSortedIntArray(0)
|
||||
for i := 10; i > -1; i-- {
|
||||
array.Add(i)
|
||||
}
|
||||
gtest.Assert(array.Slice(), expect)
|
||||
}
|
||||
|
||||
func Test_SortedIntArray2(t *testing.T) {
|
||||
expect := []int{0,1,2,3,4,5,6,7,8,9,10}
|
||||
array := garray.NewSortedIntArray(0)
|
||||
for i := 0; i <= 10; i++ {
|
||||
array.Add(i)
|
||||
}
|
||||
gtest.Assert(array.Slice(), expect)
|
||||
}
|
||||
|
||||
func Test_SortedStringArray1(t *testing.T) {
|
||||
expect := []string{"0","1","10","2","3","4","5","6","7","8","9"}
|
||||
array := garray.NewSortedStringArray(0)
|
||||
for i := 10; i > -1; i-- {
|
||||
array.Add(gconv.String(i))
|
||||
}
|
||||
gtest.Assert(array.Slice(), expect)
|
||||
}
|
||||
|
||||
func Test_SortedStringArray2(t *testing.T) {
|
||||
expect := []string{"0","1","10","2","3","4","5","6","7","8","9"}
|
||||
array := garray.NewSortedStringArray(0)
|
||||
for i := 0; i <= 10; i++ {
|
||||
array.Add(gconv.String(i))
|
||||
}
|
||||
gtest.Assert(array.Slice(), expect)
|
||||
}
|
||||
|
||||
func Test_SortedArray1(t *testing.T) {
|
||||
expect := []string{"0","1","10","2","3","4","5","6","7","8","9"}
|
||||
array := garray.NewSortedArray(0, func(v1, v2 interface{}) int {
|
||||
return strings.Compare(gconv.String(v1), gconv.String(v2))
|
||||
})
|
||||
for i := 10; i > -1; i-- {
|
||||
array.Add(gconv.String(i))
|
||||
}
|
||||
gtest.Assert(array.Slice(), expect)
|
||||
}
|
||||
|
||||
func Test_SortedArray2(t *testing.T) {
|
||||
expect := []string{"0","1","10","2","3","4","5","6","7","8","9"}
|
||||
array := garray.NewSortedArray(0, func(v1, v2 interface{}) int {
|
||||
return strings.Compare(gconv.String(v1), gconv.String(v2))
|
||||
})
|
||||
for i := 0; i <= 10; i++ {
|
||||
array.Add(gconv.String(i))
|
||||
}
|
||||
gtest.Assert(array.Slice(), expect)
|
||||
}
|
||||
@ -10,6 +10,7 @@ package gvar
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g/container/gtype"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"time"
|
||||
)
|
||||
@ -92,6 +93,8 @@ func (v *Var) Interfaces() []interface{} { return gconv.Interfaces(v.Val()
|
||||
func (v *Var) Time(format...string) time.Time { return gconv.Time(v.Val(), format...) }
|
||||
func (v *Var) TimeDuration() time.Duration { return gconv.TimeDuration(v.Val()) }
|
||||
|
||||
func (v *Var) GTime(format...string) *gtime.Time { return gconv.GTime(v.Val(), format...) }
|
||||
|
||||
// 将变量转换为对象,注意 objPointer 参数必须为struct指针
|
||||
func (v *Var) Struct(objPointer interface{}, attrMapping...map[string]string) error {
|
||||
return gconv.Struct(v.Val(), objPointer, attrMapping...)
|
||||
|
||||
@ -6,7 +6,10 @@
|
||||
|
||||
package gvar
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 只读变量接口
|
||||
type VarRead interface {
|
||||
@ -34,5 +37,6 @@ type VarRead interface {
|
||||
Interfaces() []interface{}
|
||||
Time(format ...string) time.Time
|
||||
TimeDuration() time.Duration
|
||||
GTime(format...string) *gtime.Time
|
||||
Struct(objPointer interface{}, attrMapping ...map[string]string) error
|
||||
}
|
||||
@ -1,11 +1,11 @@
|
||||
package gdes
|
||||
package gdes_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/encoding/gdes"
|
||||
"gitee.com/johng/gf/g/crypto/gdes"
|
||||
)
|
||||
|
||||
func TestDesECB(t *testing.T){
|
||||
|
||||
@ -35,6 +35,7 @@ func EncryptFile(path string) string {
|
||||
if e != nil {
|
||||
return ""
|
||||
}
|
||||
defer f.Close()
|
||||
h := md5.New()
|
||||
_, e = io.Copy(h, f)
|
||||
if e != nil {
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://gitee.com/johng/gf.
|
||||
|
||||
// 数据库ORM.
|
||||
// Package gdb provides ORM features for popular relationship databases/数据库ORM.
|
||||
// 默认内置支持MySQL, 其他数据库需要手动import对应的数据库引擎第三方包.
|
||||
package gdb
|
||||
|
||||
@ -12,7 +12,6 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/container/gmap"
|
||||
"gitee.com/johng/gf/g/container/gring"
|
||||
"gitee.com/johng/gf/g/container/gtype"
|
||||
"gitee.com/johng/gf/g/container/gvar"
|
||||
@ -22,39 +21,42 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
OPTION_INSERT = 0
|
||||
OPTION_REPLACE = 1
|
||||
OPTION_SAVE = 2
|
||||
OPTION_IGNORE = 3
|
||||
)
|
||||
|
||||
// 数据库操作接口
|
||||
type Link interface {
|
||||
// 打开数据库连接,建立数据库操作对象
|
||||
Open(c *ConfigNode) (*sql.DB, error)
|
||||
type DB interface {
|
||||
// 建立数据库连接方法(开发者一般不需要直接调用)
|
||||
Open(config *ConfigNode) (*sql.DB, error)
|
||||
|
||||
// SQL操作方法
|
||||
Query(q string, args ...interface{}) (*sql.Rows, error)
|
||||
Exec(q string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(q string) (*sql.Stmt, error)
|
||||
// SQL操作方法 API
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
Exec(sql string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(sql string, execOnMaster...bool) (*sql.Stmt, error)
|
||||
|
||||
// 内部实现API的方法(不同数据库可覆盖这些方法实现自定义的操作)
|
||||
doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error)
|
||||
doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error)
|
||||
doPrepare(link dbLink, query string) (*sql.Stmt, error)
|
||||
doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error)
|
||||
doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error)
|
||||
doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error)
|
||||
doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error)
|
||||
|
||||
// 数据库查询
|
||||
GetAll(q string, args ...interface{}) (Result, error)
|
||||
GetOne(q string, args ...interface{}) (Record, error)
|
||||
GetValue(q string, args ...interface{}) (Value, error)
|
||||
GetAll(query string, args ...interface{}) (Result, error)
|
||||
GetOne(query string, args ...interface{}) (Record, error)
|
||||
GetValue(query string, args ...interface{}) (Value, error)
|
||||
GetCount(query string, args ...interface{}) (int, error)
|
||||
GetStruct(obj interface{}, query string, args ...interface{}) error
|
||||
|
||||
// Ping
|
||||
// 创建底层数据库master/slave链接对象
|
||||
Master() (*sql.DB, error)
|
||||
Slave() (*sql.DB, error)
|
||||
|
||||
// Ping
|
||||
PingMaster() error
|
||||
PingSlave() error
|
||||
|
||||
// 连接属性设置
|
||||
SetMaxIdleConns(n int)
|
||||
SetMaxOpenConns(n int)
|
||||
SetConnMaxLifetime(n int)
|
||||
|
||||
// 开启事务操作
|
||||
Begin() (*Tx, error)
|
||||
Begin() (*TX, error)
|
||||
|
||||
// 数据表插入/更新/保存操作
|
||||
Insert(table string, data Map) (sql.Result, error)
|
||||
@ -74,25 +76,40 @@ type Link interface {
|
||||
Table(tables string) *Model
|
||||
From(tables string) *Model
|
||||
|
||||
// 内部方法
|
||||
insert(table string, data Map, option uint8) (sql.Result, error)
|
||||
batchInsert(table string, list List, batch int, option uint8) (sql.Result, error)
|
||||
// 设置管理
|
||||
SetDebug(debug bool)
|
||||
SetSchema(schema string)
|
||||
GetQueriedSqls() []*Sql
|
||||
PrintQueriedSqls()
|
||||
SetMaxIdleConns(n int)
|
||||
SetMaxOpenConns(n int)
|
||||
SetConnMaxLifetime(n int)
|
||||
|
||||
getQuoteCharLeft() string
|
||||
getQuoteCharRight() string
|
||||
handleSqlBeforeExec(q *string) *string
|
||||
// 内部方法接口
|
||||
getCache() (*gcache.Cache)
|
||||
getChars() (charLeft string, charRight string)
|
||||
getDebug() bool
|
||||
filterFields(table string, data map[string]interface{}) map[string]interface{}
|
||||
getTableFields(table string) (map[string]string, error)
|
||||
handleSqlBeforeExec(sql string) string
|
||||
}
|
||||
|
||||
// 执行底层数据库操作的核心接口
|
||||
type dbLink interface {
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
Exec(sql string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(sql string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
// 数据库链接对象
|
||||
type Db struct {
|
||||
link Link // 底层数据库类型管理对象
|
||||
type dbBase struct {
|
||||
db DB // 数据库对象
|
||||
group string // 配置分组名称
|
||||
charl string // SQL安全符号(左)
|
||||
charr string // SQL安全符号(右)
|
||||
debug *gtype.Bool // (默认关闭)是否开启调试模式,当开启时会启用一些调试特性
|
||||
sqls *gring.Ring // (debug=true时有效)已执行的SQL列表
|
||||
cache *gcache.Cache // 数据库缓存,包括底层连接池对象缓存及查询缓存;需要注意的是,事务查询不支持查询缓存
|
||||
maxIdleConnCount *gtype.Int // 连接池最大限制的连接数
|
||||
schema *gtype.String // 手动切换的数据库名称
|
||||
maxIdleConnCount *gtype.Int // 连接池最大限制的连接数
|
||||
maxOpenConnCount *gtype.Int // 连接池最大打开的连接数
|
||||
maxConnLifetime *gtype.Int // (单位秒)连接对象可重复使用的时间长度
|
||||
}
|
||||
@ -104,7 +121,7 @@ type Sql struct {
|
||||
Error error // 执行结果(nil为成功)
|
||||
Start int64 // 执行开始时间(毫秒)
|
||||
End int64 // 执行结束时间(毫秒)
|
||||
Func string // 执行方法名称
|
||||
Func string // 执行方法
|
||||
}
|
||||
|
||||
// 返回数据表记录值
|
||||
@ -117,27 +134,22 @@ type Record map[string]Value
|
||||
type Result []Record
|
||||
|
||||
// 关联数组,绑定一条数据表记录(使用别名)
|
||||
type Map = map[string]interface{}
|
||||
type Map = map[string]interface{}
|
||||
|
||||
// 关联数组列表(索引从0开始的数组),绑定多条记录(使用别名)
|
||||
type List = []Map
|
||||
|
||||
var (
|
||||
// 支持的数据库类型map
|
||||
driverMap = make(map[string]interface{})
|
||||
// 数据库查询缓存对象map,使用数据库连接名称作为键名,键值为查询缓存对象
|
||||
dbCaches = gmap.NewStringInterfaceMap()
|
||||
const (
|
||||
OPTION_INSERT = 0
|
||||
OPTION_REPLACE = 1
|
||||
OPTION_SAVE = 2
|
||||
OPTION_IGNORE = 3
|
||||
// 默认的连接池连接存活时间(秒)
|
||||
gDEFAULT_CONN_MAX_LIFE_TIME = 30
|
||||
)
|
||||
func init() {
|
||||
driverMap["mysql"] = linkMysql
|
||||
driverMap["oracle"] = linkOracle
|
||||
driverMap["sqlite"] = linkSqlite
|
||||
driverMap["pgsql"] = linkPgsql
|
||||
driverMap["mssql"] = linkMssql
|
||||
}
|
||||
|
||||
// 使用默认/指定分组配置进行连接,数据库集群配置项:default
|
||||
func New(groupName ...string) (*Db, error) {
|
||||
func New(groupName ...string) (db DB, err error) {
|
||||
group := config.d
|
||||
if len(groupName) > 0 {
|
||||
group = groupName[0]
|
||||
@ -150,24 +162,30 @@ func New(groupName ...string) (*Db, error) {
|
||||
}
|
||||
if _, ok := config.c[group]; ok {
|
||||
if node, err := getConfigNodeByGroup(group, true); err == nil {
|
||||
link, err := getLinkByType(node.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db := &Db {
|
||||
link : link,
|
||||
base := &dbBase {
|
||||
group : group,
|
||||
charl : link.getQuoteCharLeft(),
|
||||
charr : link.getQuoteCharRight(),
|
||||
debug : gtype.NewBool(),
|
||||
cache : gcache.New(),
|
||||
schema : gtype.NewString(),
|
||||
maxIdleConnCount : gtype.NewInt(),
|
||||
maxOpenConnCount : gtype.NewInt(),
|
||||
maxConnLifetime : gtype.NewInt(),
|
||||
maxConnLifetime : gtype.NewInt(gDEFAULT_CONN_MAX_LIFE_TIME),
|
||||
}
|
||||
db.cache = dbCaches.GetOrSetFuncLock(group, func() interface{} {
|
||||
return gcache.New()
|
||||
}).(*gcache.Cache)
|
||||
return db, nil
|
||||
switch node.Type {
|
||||
case "mysql":
|
||||
base.db = &dbMysql{dbBase : base}
|
||||
case "pgsql":
|
||||
base.db = &dbPgsql{dbBase : base}
|
||||
case "mssql":
|
||||
base.db = &dbMssql{dbBase : base}
|
||||
case "sqlite":
|
||||
base.db = &dbSqlite{dbBase : base}
|
||||
case "oracle":
|
||||
base.db = &dbOracle{dbBase : base}
|
||||
default:
|
||||
return nil, errors.New(fmt.Sprintf(`unsupported database type "%s"`, node.Type))
|
||||
}
|
||||
return base.db, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
@ -219,6 +237,13 @@ func getConfigNodeByPriority(cg ConfigGroup) *ConfigNode {
|
||||
for i := 0; i < len(cg); i++ {
|
||||
total += cg[i].Priority * 100
|
||||
}
|
||||
// 如果total为0表示所有连接都没有配置priority属性,那么默认都是1
|
||||
if total == 0 {
|
||||
for i := 0; i < len(cg); i++ {
|
||||
cg[i].Priority = 1
|
||||
total += cg[i].Priority * 100
|
||||
}
|
||||
}
|
||||
// 不能取到末尾的边界点
|
||||
r := grand.Rand(0, total)
|
||||
if r > 0 {
|
||||
@ -238,51 +263,36 @@ func getConfigNodeByPriority(cg ConfigGroup) *ConfigNode {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 根据配置的数据库;类型获得Link接口对象
|
||||
func getLinkByType(dbType string) (Link, error) {
|
||||
if dblink, ok := driverMap[dbType]; ok == false {
|
||||
return nil, errors.New(fmt.Sprintf("unsupported db type '%s'", dbType))
|
||||
} else {
|
||||
return dblink.(Link), nil
|
||||
}
|
||||
}
|
||||
|
||||
// 获得底层数据库链接对象
|
||||
func (db *Db) getSqlDb(master bool) (sqlDb *sql.DB, err error) {
|
||||
func (bs *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) {
|
||||
// 负载均衡
|
||||
node, err := getConfigNodeByGroup(db.group, master)
|
||||
node, err := getConfigNodeByGroup(bs.group, master)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 类型对象
|
||||
link, err := getLinkByType(node.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 默认值设定
|
||||
if node.Charset == "" {
|
||||
node.Charset = "utf8"
|
||||
}
|
||||
// 检查缓存连接池对象
|
||||
cacheKey := node.String()
|
||||
if v := db.cache.Get(cacheKey); v != nil {
|
||||
return v.(*sql.DB), nil
|
||||
}
|
||||
v := db.cache.GetOrSetFuncLock(node.String(), func() interface{} {
|
||||
sqlDb, err = link.Open(node)
|
||||
v := bs.cache.GetOrSetFuncLock(node.String(), func() interface{} {
|
||||
sqlDb, err = bs.db.Open(node)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if n := db.maxIdleConnCount.Val(); n > 0 {
|
||||
if n := bs.maxIdleConnCount.Val(); n > 0 {
|
||||
sqlDb.SetMaxIdleConns(n)
|
||||
} else if node.MaxIdleConnCount > 0 {
|
||||
sqlDb.SetMaxIdleConns(node.MaxIdleConnCount)
|
||||
}
|
||||
|
||||
if n := db.maxOpenConnCount.Val(); n > 0 {
|
||||
if n := bs.maxOpenConnCount.Val(); n > 0 {
|
||||
sqlDb.SetMaxOpenConns(n)
|
||||
} else if node.MaxOpenConnCount > 0 {
|
||||
sqlDb.SetMaxOpenConns(node.MaxOpenConnCount)
|
||||
}
|
||||
|
||||
if n := db.maxConnLifetime.Val(); n > 0 {
|
||||
if n := bs.maxConnLifetime.Val(); n > 0 {
|
||||
sqlDb.SetConnMaxLifetime(time.Duration(n) * time.Second)
|
||||
} else if node.MaxConnLifetime > 0 {
|
||||
sqlDb.SetConnMaxLifetime(time.Duration(node.MaxConnLifetime) * time.Second)
|
||||
@ -292,15 +302,24 @@ func (db *Db) getSqlDb(master bool) (sqlDb *sql.DB, err error) {
|
||||
if v != nil && sqlDb == nil {
|
||||
sqlDb = v.(*sql.DB)
|
||||
}
|
||||
// 是否手动选择数据库
|
||||
if v := bs.schema.Val(); v != "" {
|
||||
sqlDb.Exec("USE " + v)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 切换操作的数据库(注意该切换是全局的)
|
||||
func (bs *dbBase) SetSchema(schema string) {
|
||||
bs.schema.Set(schema)
|
||||
}
|
||||
|
||||
// 创建底层数据库master链接对象
|
||||
func (db *Db) Master() (*sql.DB, error) {
|
||||
return db.getSqlDb(true)
|
||||
func (bs *dbBase) Master() (*sql.DB, error) {
|
||||
return bs.getSqlDb(true)
|
||||
}
|
||||
|
||||
// 创建底层数据库slave链接对象
|
||||
func (db *Db) Slave() (*sql.DB, error) {
|
||||
return db.getSqlDb(false)
|
||||
func (bs *dbBase) Slave() (*sql.DB, error) {
|
||||
return bs.getSqlDb(false)
|
||||
}
|
||||
|
||||
@ -8,39 +8,29 @@
|
||||
package gdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"errors"
|
||||
"strings"
|
||||
"reflect"
|
||||
"database/sql"
|
||||
"gitee.com/johng/gf/g/util/gstr"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"gitee.com/johng/gf/g/container/gring"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/os/gcache"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/os/glog"
|
||||
"gitee.com/johng/gf/g/container/gvar"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"gitee.com/johng/gf/g/util/gregex"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
gDEFAULT_DEBUG_SQL_LENGTH = 1000 // 默认调试模式下记录的SQL条数
|
||||
)
|
||||
|
||||
// 是否开启调试服务
|
||||
func (db *Db) SetDebug(debug bool) {
|
||||
db.debug.Set(debug)
|
||||
if debug && db.sqls == nil {
|
||||
db.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取已经执行的SQL列表(仅在debug=true时有效)
|
||||
func (db *Db) GetQueriedSqls() []*Sql {
|
||||
if db.sqls == nil {
|
||||
func (bs *dbBase) GetQueriedSqls() []*Sql {
|
||||
if bs.sqls == nil {
|
||||
return nil
|
||||
}
|
||||
sqls := make([]*Sql, 0)
|
||||
db.sqls.Prev()
|
||||
db.sqls.RLockIteratorPrev(func(value interface{}) bool {
|
||||
bs.sqls.Prev()
|
||||
bs.sqls.RLockIteratorPrev(func(value interface{}) bool {
|
||||
if value == nil {
|
||||
return false
|
||||
}
|
||||
@ -51,8 +41,8 @@ func (db *Db) GetQueriedSqls() []*Sql {
|
||||
}
|
||||
|
||||
// 打印已经执行的SQL列表(仅在debug=true时有效)
|
||||
func (db *Db) PrintQueriedSqls() {
|
||||
sqls := db.GetQueriedSqls()
|
||||
func (bs *dbBase) PrintQueriedSqls() {
|
||||
sqls := bs.GetQueriedSqls()
|
||||
for k, v := range sqls {
|
||||
fmt.Println(len(sqls) - k, ":")
|
||||
fmt.Println(" Sql :", v.Sql)
|
||||
@ -61,143 +51,110 @@ func (db *Db) PrintQueriedSqls() {
|
||||
fmt.Println(" Start:", gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u"))
|
||||
fmt.Println(" End :", gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u"))
|
||||
fmt.Println(" Cost :", v.End - v.Start, "ms")
|
||||
fmt.Println(" Func :", v.Func)
|
||||
}
|
||||
}
|
||||
|
||||
// 打印SQL对象(仅在debug=true时有效)
|
||||
func (db *Db) printSql(v *Sql) {
|
||||
s := fmt.Sprintf("%s, %v, %s, %s, %d ms, %s", v.Sql, v.Args,
|
||||
gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u"),
|
||||
gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u"),
|
||||
v.End - v.Start, v.Func,
|
||||
)
|
||||
if v.Error != nil {
|
||||
s += "\nError: " + v.Error.Error()
|
||||
glog.Backtrace(true, 2).Error(s)
|
||||
} else {
|
||||
glog.Debug(s)
|
||||
}
|
||||
}
|
||||
|
||||
// 数据库sql查询操作,主要执行查询
|
||||
func (db *Db) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
var err error
|
||||
var rows *sql.Rows
|
||||
var slave *sql.DB
|
||||
slave, err = db.Slave();
|
||||
func (bs *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
link, err := bs.db.Slave()
|
||||
if err != nil {
|
||||
return nil,err
|
||||
}
|
||||
p := db.link.handleSqlBeforeExec(&query)
|
||||
if db.debug.Val() {
|
||||
militime1 := gtime.Millisecond()
|
||||
rows, err = slave.Query(*p, args ...)
|
||||
militime2 := gtime.Millisecond()
|
||||
s := &Sql{
|
||||
Sql : *p,
|
||||
return bs.db.doQuery(link, query, args...)
|
||||
}
|
||||
|
||||
// 数据库sql查询操作,主要执行查询
|
||||
func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
query = bs.db.handleSqlBeforeExec(query)
|
||||
if bs.db.getDebug() {
|
||||
mTime1 := gtime.Millisecond()
|
||||
rows, err = link.Query(query, args...)
|
||||
mTime2 := gtime.Millisecond()
|
||||
s := &Sql {
|
||||
Sql : query,
|
||||
Args : args,
|
||||
Error : err,
|
||||
Start : militime1,
|
||||
End : militime2,
|
||||
Func : "DB:Query",
|
||||
Start : mTime1,
|
||||
End : mTime2,
|
||||
}
|
||||
db.sqls.Put(s)
|
||||
db.printSql(s)
|
||||
bs.sqls.Put(s)
|
||||
printSql(s)
|
||||
} else {
|
||||
rows, err = slave.Query(*p, args ...)
|
||||
rows, err = link.Query(query, args ...)
|
||||
}
|
||||
if err == nil {
|
||||
return rows, nil
|
||||
} else {
|
||||
err = db.formatError(err, p, args...)
|
||||
err = formatError(err, query, args...)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行一条sql,并返回执行情况,主要用于非查询操作
|
||||
func (db *Db) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
var err error
|
||||
var result sql.Result
|
||||
var master *sql.DB
|
||||
master, err = db.Master();
|
||||
func (bs *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) {
|
||||
link, err := bs.db.Master()
|
||||
if err != nil {
|
||||
return nil,err
|
||||
}
|
||||
p := db.link.handleSqlBeforeExec(&query)
|
||||
if db.debug.Val() {
|
||||
militime1 := gtime.Millisecond()
|
||||
result, err = master.Exec(*p, args ...)
|
||||
militime2 := gtime.Millisecond()
|
||||
s := &Sql{
|
||||
Sql : *p,
|
||||
return bs.db.doExec(link, query, args...)
|
||||
}
|
||||
|
||||
// 执行一条sql,并返回执行情况,主要用于非查询操作
|
||||
func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
query = bs.db.handleSqlBeforeExec(query)
|
||||
if bs.db.getDebug() {
|
||||
mTime1 := gtime.Millisecond()
|
||||
result, err = link.Exec(query, args ...)
|
||||
mTime2 := gtime.Millisecond()
|
||||
s := &Sql{
|
||||
Sql : query,
|
||||
Args : args,
|
||||
Error : err,
|
||||
Start : militime1,
|
||||
End : militime2,
|
||||
Func : "DB:Exec",
|
||||
Start : mTime1,
|
||||
End : mTime2,
|
||||
}
|
||||
db.sqls.Put(s)
|
||||
db.printSql(s)
|
||||
bs.sqls.Put(s)
|
||||
printSql(s)
|
||||
} else {
|
||||
result, err = master.Exec(*p, args ...)
|
||||
result, err = link.Exec(query, args ...)
|
||||
}
|
||||
return result, db.formatError(err, p, args...)
|
||||
return result, formatError(err, query, args...)
|
||||
}
|
||||
|
||||
// 格式化错误信息
|
||||
func (db *Db) formatError(err error, query *string, args ...interface{}) error {
|
||||
if err != nil {
|
||||
errstr := fmt.Sprintf("DB ERROR: %s\n", err.Error())
|
||||
errstr += fmt.Sprintf("DB QUERY: %s\n", *query)
|
||||
if len(args) > 0 {
|
||||
errstr += fmt.Sprintf("DB PARAM: %v\n", args)
|
||||
// SQL预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作; 默认执行在Slave上, 通过第二个参数指定执行在Master上
|
||||
func (bs *dbBase) Prepare(query string, execOnMaster...bool) (*sql.Stmt, error) {
|
||||
err := (error)(nil)
|
||||
link := (dbLink)(nil)
|
||||
if len(execOnMaster) > 0 && execOnMaster[0] {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if link, err = bs.db.Slave(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = errors.New(errstr)
|
||||
}
|
||||
return err
|
||||
return bs.db.doPrepare(link, query)
|
||||
}
|
||||
|
||||
// SQL预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作
|
||||
func (bs *dbBase) doPrepare(link dbLink, query string) (*sql.Stmt, error) {
|
||||
return link.Prepare(query)
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询结果集,以列表结构返回
|
||||
func (db *Db) GetAll(query string, args ...interface{}) (Result, error) {
|
||||
// 执行sql
|
||||
rows, err := db.Query(query, args ...)
|
||||
func (bs *dbBase) GetAll(query string, args ...interface{}) (Result, error) {
|
||||
rows, err := bs.Query(query, args ...)
|
||||
if err != nil || rows == nil {
|
||||
return nil, err
|
||||
}
|
||||
// 列名称列表
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 返回结构组装
|
||||
values := make([]sql.RawBytes, len(columns))
|
||||
scanArgs := make([]interface{}, len(values))
|
||||
records := make(Result, 0)
|
||||
for i := range values {
|
||||
scanArgs[i] = &values[i]
|
||||
}
|
||||
for rows.Next() {
|
||||
err = rows.Scan(scanArgs...)
|
||||
if err != nil {
|
||||
return records, err
|
||||
}
|
||||
row := make(Record)
|
||||
// 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址
|
||||
for i, col := range values {
|
||||
v := make([]byte, len(col))
|
||||
copy(v, col)
|
||||
row[columns[i]] = gvar.New(v, false)
|
||||
}
|
||||
records = append(records, row)
|
||||
}
|
||||
return records, nil
|
||||
defer rows.Close()
|
||||
return rowsToResult(rows)
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询结果记录,以关联数组结构返回
|
||||
func (db *Db) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
list, err := db.GetAll(query, args ...)
|
||||
func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
list, err := bs.GetAll(query, args ...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -208,18 +165,17 @@ func (db *Db) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询结果记录,自动映射数据到给定的struct对象中
|
||||
func (db *Db) GetStruct(obj interface{}, query string, args ...interface{}) error {
|
||||
one, err := db.GetOne(query, args...)
|
||||
func (bs *dbBase) GetStruct(obj interface{}, query string, args ...interface{}) error {
|
||||
one, err := bs.GetOne(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return one.ToStruct(obj)
|
||||
}
|
||||
|
||||
|
||||
// 数据库查询,获取查询字段值
|
||||
func (db *Db) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
one, err := db.GetOne(query, args ...)
|
||||
func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
one, err := bs.GetOne(query, args ...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -230,44 +186,20 @@ func (db *Db) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询数量
|
||||
func (db *Db) GetCount(query string, args ...interface{}) (int, error) {
|
||||
val, err := db.GetValue(query, args ...)
|
||||
func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) {
|
||||
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
|
||||
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
|
||||
}
|
||||
value, err := bs.GetValue(query, args ...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return gconv.Int(val), nil
|
||||
}
|
||||
|
||||
// 数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作
|
||||
func (db *Db) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) {
|
||||
s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables)
|
||||
if condition != nil {
|
||||
s += fmt.Sprintf("WHERE %s ", db.formatCondition(condition))
|
||||
}
|
||||
if len(groupBy) > 0 {
|
||||
s += fmt.Sprintf("GROUP BY %s ", groupBy)
|
||||
}
|
||||
if len(orderBy) > 0 {
|
||||
s += fmt.Sprintf("ORDER BY %s ", orderBy)
|
||||
}
|
||||
if limit > 0 {
|
||||
s += fmt.Sprintf("LIMIT %d,%d ", first, limit)
|
||||
}
|
||||
return db.GetAll(s, args ... )
|
||||
}
|
||||
|
||||
// sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作
|
||||
func (db *Db) Prepare(query string) (*sql.Stmt, error) {
|
||||
if master, err := db.Master(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return master.Prepare(query)
|
||||
}
|
||||
return value.Int(), nil
|
||||
}
|
||||
|
||||
// ping一下,判断或保持数据库链接(master)
|
||||
func (db *Db) PingMaster() error {
|
||||
if master, err := db.Master(); err != nil {
|
||||
func (bs *dbBase) PingMaster() error {
|
||||
if master, err := bs.db.Master(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return master.Ping()
|
||||
@ -275,8 +207,8 @@ func (db *Db) PingMaster() error {
|
||||
}
|
||||
|
||||
// ping一下,判断或保持数据库链接(slave)
|
||||
func (db *Db) PingSlave() error {
|
||||
if slave, err := db.Slave(); err != nil {
|
||||
func (bs *dbBase) PingSlave() error {
|
||||
if slave, err := bs.db.Slave(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return slave.Ping()
|
||||
@ -285,13 +217,13 @@ func (db *Db) PingSlave() error {
|
||||
|
||||
// 事务操作,开启,会返回一个底层的事务操作对象链接如需要嵌套事务,那么可以使用该对象,否则请忽略
|
||||
// 只有在tx.Commit/tx.Rollback时,链接会自动Close
|
||||
func (db *Db) Begin() (*Tx, error) {
|
||||
if master, err := db.Master(); err != nil {
|
||||
func (bs *dbBase) Begin() (*TX, error) {
|
||||
if master, err := bs.db.Master(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
if tx, err := master.Begin(); err == nil {
|
||||
return &Tx {
|
||||
db : db,
|
||||
return &TX {
|
||||
db : bs.db,
|
||||
tx : tx,
|
||||
master : master,
|
||||
}, nil
|
||||
@ -301,17 +233,19 @@ func (db *Db) Begin() (*Tx, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 根据insert选项获得操作名称
|
||||
func (db *Db) getInsertOperationByOption(option uint8) string {
|
||||
oper := "INSERT"
|
||||
switch option {
|
||||
case OPTION_REPLACE:
|
||||
oper = "REPLACE"
|
||||
case OPTION_SAVE:
|
||||
case OPTION_IGNORE:
|
||||
oper = "INSERT IGNORE"
|
||||
}
|
||||
return oper
|
||||
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
|
||||
func (bs *dbBase) Insert(table string, data Map) (sql.Result, error) {
|
||||
return bs.db.doInsert(nil, table, data, OPTION_INSERT)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
func (bs *dbBase) Replace(table string, data Map) (sql.Result, error) {
|
||||
return bs.db.doInsert(nil, table, data, OPTION_REPLACE)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
func (bs *dbBase) Save(table string, data Map) (sql.Result, error) {
|
||||
return bs.db.doInsert(nil, table, data, OPTION_SAVE)
|
||||
}
|
||||
|
||||
// insert、replace, save, ignore操作
|
||||
@ -319,95 +253,102 @@ func (db *Db) getInsertOperationByOption(option uint8) string {
|
||||
// 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
// 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
// 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做
|
||||
func (db *Db) insert(table string, data Map, option uint8) (sql.Result, error) {
|
||||
func (bs *dbBase) doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error) {
|
||||
var fields []string
|
||||
var values []string
|
||||
var params []interface{}
|
||||
charl, charr := bs.db.getChars()
|
||||
for k, v := range data {
|
||||
fields = append(fields, db.charl + k + db.charr)
|
||||
fields = append(fields, charl + k + charr)
|
||||
values = append(values, "?")
|
||||
params = append(params, v)
|
||||
}
|
||||
operation := db.getInsertOperationByOption(option)
|
||||
operation := getInsertOperationByOption(option)
|
||||
updatestr := ""
|
||||
if option == OPTION_SAVE {
|
||||
var updates []string
|
||||
for k, _ := range data {
|
||||
updates = append(updates,
|
||||
fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
|
||||
db.charl, k, db.charr,
|
||||
db.charl, k, db.charr,
|
||||
charl, k, charr,
|
||||
charl, k, charr,
|
||||
),
|
||||
)
|
||||
}
|
||||
updatestr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
|
||||
}
|
||||
return db.Exec(
|
||||
fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
|
||||
operation, table, strings.Join(fields, ","),
|
||||
strings.Join(values, ","),
|
||||
updatestr),
|
||||
params...
|
||||
)
|
||||
if link == nil {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
|
||||
operation, table, strings.Join(fields, ","),
|
||||
strings.Join(values, ","), updatestr),
|
||||
params...)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
|
||||
func (db *Db) Insert(table string, data Map) (sql.Result, error) {
|
||||
return db.insert(table, data, OPTION_INSERT)
|
||||
// CURD操作:批量数据指定批次量写入
|
||||
func (bs *dbBase) BatchInsert(table string, list List, batch int) (sql.Result, error) {
|
||||
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_INSERT)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
func (db *Db) Replace(table string, data Map) (sql.Result, error) {
|
||||
return db.insert(table, data, OPTION_REPLACE)
|
||||
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
func (bs *dbBase) BatchReplace(table string, list List, batch int) (sql.Result, error) {
|
||||
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_REPLACE)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
func (db *Db) Save(table string, data Map) (sql.Result, error) {
|
||||
return db.insert(table, data, OPTION_SAVE)
|
||||
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
func (bs *dbBase) BatchSave(table string, list List, batch int) (sql.Result, error) {
|
||||
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_SAVE)
|
||||
}
|
||||
|
||||
// 批量写入数据
|
||||
func (db *Db) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) {
|
||||
func (bs *dbBase) doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error) {
|
||||
var keys []string
|
||||
var values []string
|
||||
var bvalues []string
|
||||
var params []interface{}
|
||||
var result sql.Result
|
||||
var size = len(list)
|
||||
// 判断长度
|
||||
if size < 1 {
|
||||
if len(list) < 1 {
|
||||
return result, errors.New("empty data list")
|
||||
}
|
||||
if link == nil {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
// 首先获取字段名称及记录长度
|
||||
for k, _ := range list[0] {
|
||||
keys = append(keys, k)
|
||||
values = append(values, "?")
|
||||
}
|
||||
keyStr := db.charl + strings.Join(keys, db.charl + "," + db.charr) + db.charr
|
||||
charl, charr := bs.db.getChars()
|
||||
keyStr := charl + strings.Join(keys, charl + "," + charr) + charr
|
||||
valueHolderStr := "(" + strings.Join(values, ",") + ")"
|
||||
// 操作判断
|
||||
operation := db.getInsertOperationByOption(option)
|
||||
operation := getInsertOperationByOption(option)
|
||||
updatestr := ""
|
||||
if option == OPTION_SAVE {
|
||||
var updates []string
|
||||
for _, k := range keys {
|
||||
updates = append(updates,
|
||||
fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
|
||||
db.charl, k, db.charr,
|
||||
db.charl, k, db.charr,
|
||||
charl, k, charr,
|
||||
charl, k, charr,
|
||||
),
|
||||
)
|
||||
}
|
||||
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
|
||||
}
|
||||
// 构造批量写入数据格式(注意map的遍历是无序的)
|
||||
for i := 0; i < size; i++ {
|
||||
for i := 0; i < len(list); i++ {
|
||||
for _, k := range keys {
|
||||
params = append(params, list[i][k])
|
||||
}
|
||||
bvalues = append(bvalues, valueHolderStr)
|
||||
if len(bvalues) == batch {
|
||||
r, err := db.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
|
||||
r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
|
||||
operation, table, keyStr, strings.Join(bvalues, ","),
|
||||
updatestr),
|
||||
params...)
|
||||
@ -421,7 +362,7 @@ func (db *Db) batchInsert(table string, list List, batch int, option uint8) (sql
|
||||
}
|
||||
// 处理最后不构成指定批量的数据
|
||||
if len(bvalues) > 0 {
|
||||
r, err := db.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
|
||||
r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
|
||||
operation, table, keyStr, strings.Join(bvalues, ","),
|
||||
updatestr),
|
||||
params...)
|
||||
@ -433,32 +374,28 @@ func (db *Db) batchInsert(table string, list List, batch int, option uint8) (sql
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CURD操作:批量数据指定批次量写入
|
||||
func (db *Db) BatchInsert(table string, list List, batch int) (sql.Result, error) {
|
||||
return db.batchInsert(table, list, batch, OPTION_INSERT)
|
||||
}
|
||||
|
||||
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
func (db *Db) BatchReplace(table string, list List, batch int) (sql.Result, error) {
|
||||
return db.batchInsert(table, list, batch, OPTION_REPLACE)
|
||||
}
|
||||
|
||||
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
func (db *Db) BatchSave(table string, list List, batch int) (sql.Result, error) {
|
||||
return db.batchInsert(table, list, batch, OPTION_SAVE)
|
||||
// CURD操作:数据更新,统一采用sql预处理
|
||||
// data参数支持字符串或者关联数组类型,内部会自行做判断处理
|
||||
func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
link, err := bs.db.Master()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bs.db.doUpdate(link, table, data, condition, args ...)
|
||||
}
|
||||
|
||||
// CURD操作:数据更新,统一采用sql预处理
|
||||
// data参数支持字符串或者关联数组类型,内部会自行做判断处理
|
||||
func (db *Db) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
var params []interface{}
|
||||
var updates string
|
||||
refValue := reflect.ValueOf(data)
|
||||
func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error) {
|
||||
params := ([]interface{})(nil)
|
||||
updates := ""
|
||||
charl, charr := bs.db.getChars()
|
||||
refValue := reflect.ValueOf(data)
|
||||
if refValue.Kind() == reflect.Map {
|
||||
var fields []string
|
||||
keys := refValue.MapKeys()
|
||||
for _, k := range keys {
|
||||
fields = append(fields, fmt.Sprintf("%s%s%s=?", db.charl, k, db.charr))
|
||||
fields = append(fields, fmt.Sprintf("%s%s%s=?", charl, k, charr))
|
||||
params = append(params, gconv.String(refValue.MapIndex(k).Interface()))
|
||||
}
|
||||
updates = strings.Join(fields, ",")
|
||||
@ -468,34 +405,65 @@ func (db *Db) Update(table string, data interface{}, condition interface{}, args
|
||||
for _, v := range args {
|
||||
params = append(params, gconv.String(v))
|
||||
}
|
||||
return db.Exec(fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, db.formatCondition(condition)), params...)
|
||||
if link == nil {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
newWhere, newArgs := formatCondition(condition, params)
|
||||
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, newWhere), newArgs...)
|
||||
}
|
||||
|
||||
// CURD操作:删除数据
|
||||
func (db *Db) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
return db.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, db.formatCondition(condition)), args...)
|
||||
func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
|
||||
link, err := bs.db.Master()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bs.db.doDelete(link, table, condition, args ...)
|
||||
}
|
||||
|
||||
// 格式化SQL查询条件
|
||||
func (db *Db) formatCondition(condition interface{}) (where string) {
|
||||
if reflect.ValueOf(condition).Kind() == reflect.Map {
|
||||
ks := reflect.ValueOf(condition).MapKeys()
|
||||
vs := reflect.ValueOf(condition)
|
||||
for _, k := range ks {
|
||||
key := gconv.String(k.Interface())
|
||||
value := gconv.String(vs.MapIndex(k).Interface())
|
||||
isNum := gstr.IsNumeric(value)
|
||||
if len(where) > 0 {
|
||||
where += " AND "
|
||||
}
|
||||
if isNum || value == "?" {
|
||||
where += key + "=" + value
|
||||
} else {
|
||||
where += key + "='" + value + "'"
|
||||
// CURD操作:删除数据
|
||||
func (bs *dbBase) doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
|
||||
newWhere, newArgs := formatCondition(condition, args)
|
||||
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, newWhere), newArgs...)
|
||||
}
|
||||
|
||||
// 获得缓存对象
|
||||
func (bs *dbBase) getCache() *gcache.Cache {
|
||||
return bs.cache
|
||||
}
|
||||
|
||||
// 将map的数据按照fields进行过滤,只保留与表字段同名的数据
|
||||
func (bs *dbBase) filterFields(table string, data map[string]interface{}) map[string]interface{} {
|
||||
if fields, err := bs.db.getTableFields(table); err == nil {
|
||||
for k, _ := range data {
|
||||
if _, ok := fields[k]; !ok {
|
||||
delete(data, k)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
where += gconv.String(condition)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// 获得指定表表的数据结构,构造成map哈希表返回,其中键名为表字段名称,键值暂无用途(默认为字段数据类型).
|
||||
func (bs *dbBase) getTableFields(table string) (fields map[string]string, err error) {
|
||||
// 缓存不存在时会查询数据表结构,缓存后不过期,直至程序重启(重新部署)
|
||||
v := bs.cache.GetOrSetFunc("table_fields_" + table, func() interface{} {
|
||||
result := (Result)(nil)
|
||||
charl, charr := bs.db.getChars()
|
||||
result, err = bs.GetAll(fmt.Sprintf(`SHOW COLUMNS FROM %s%s%s`, charl, table, charr))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
fields = make(map[string]string)
|
||||
for _, m := range result {
|
||||
fields[m["Field"].String()] = m["Type"].String()
|
||||
}
|
||||
return fields
|
||||
}, 0)
|
||||
if err == nil {
|
||||
fields = v.(map[string]string)
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -9,6 +9,7 @@ package gdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/container/gring"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@ -114,6 +115,13 @@ func AddDefaultConfigGroup (nodes ConfigGroup) {
|
||||
AddConfigGroup(DEFAULT_GROUP_NAME, nodes)
|
||||
}
|
||||
|
||||
// 添加一台数据库服务器配置
|
||||
func GetConfig (group string) ConfigGroup {
|
||||
config.RLock()
|
||||
defer config.RUnlock()
|
||||
return config.c[group]
|
||||
}
|
||||
|
||||
// 设置默认链接的数据库链接配置项(默认是 default)
|
||||
func SetDefaultGroup (groupName string) {
|
||||
config.Lock()
|
||||
@ -122,19 +130,19 @@ func SetDefaultGroup (groupName string) {
|
||||
}
|
||||
|
||||
// 设置数据库连接池中空闲链接的大小
|
||||
func (db *Db) SetMaxIdleConns(n int) {
|
||||
db.maxIdleConnCount.Set(n)
|
||||
func (bs *dbBase) SetMaxIdleConns(n int) {
|
||||
bs.maxIdleConnCount.Set(n)
|
||||
}
|
||||
|
||||
// 设置数据库连接池最大打开的链接数量
|
||||
func (db *Db) SetMaxOpenConns(n int) {
|
||||
db.maxOpenConnCount.Set(n)
|
||||
func (bs *dbBase) SetMaxOpenConns(n int) {
|
||||
bs.maxOpenConnCount.Set(n)
|
||||
}
|
||||
|
||||
// 设置数据库连接可重复利用的时间,超过该时间则被关闭废弃
|
||||
// 如果 d <= 0 表示该链接会一直重复利用
|
||||
func (db *Db) SetConnMaxLifetime(n int) {
|
||||
db.maxConnLifetime.Set(n)
|
||||
func (bs *dbBase) SetConnMaxLifetime(n int) {
|
||||
bs.maxConnLifetime.Set(n)
|
||||
}
|
||||
|
||||
// 节点配置转换为字符串
|
||||
@ -146,4 +154,17 @@ func (node *ConfigNode) String() string {
|
||||
node.Name, node.Type, node.Role, node.Charset,
|
||||
node.MaxIdleConnCount, node.MaxOpenConnCount, node.MaxConnLifetime,
|
||||
)
|
||||
}
|
||||
|
||||
// 是否开启调试服务
|
||||
func (bs *dbBase) SetDebug(debug bool) {
|
||||
bs.debug.Set(debug)
|
||||
if debug && bs.sqls == nil {
|
||||
bs.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取是否开启调试服务
|
||||
func (bs *dbBase) getDebug() bool {
|
||||
return bs.debug.Val()
|
||||
}
|
||||
158
g/database/gdb/gdb_func.go
Normal file
158
g/database/gdb/gdb_func.go
Normal file
@ -0,0 +1,158 @@
|
||||
// Copyright 2017-2018 gf Author(https://gitee.com/johng/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://gitee.com/johng/gf.
|
||||
|
||||
package gdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/container/gvar"
|
||||
"gitee.com/johng/gf/g/os/glog"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"gitee.com/johng/gf/g/util/gregex"
|
||||
"gitee.com/johng/gf/g/util/gstr"
|
||||
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 将数据查询的列表数据*sql.Rows转换为Result类型
|
||||
func rowsToResult(rows *sql.Rows) (Result, error) {
|
||||
// 列名称列表
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 返回结构组装
|
||||
values := make([]sql.RawBytes, len(columns))
|
||||
scanArgs := make([]interface{}, len(values))
|
||||
records := make(Result, 0)
|
||||
for i := range values {
|
||||
scanArgs[i] = &values[i]
|
||||
}
|
||||
for rows.Next() {
|
||||
err = rows.Scan(scanArgs...)
|
||||
if err != nil {
|
||||
return records, err
|
||||
}
|
||||
row := make(Record)
|
||||
// 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址
|
||||
for i, col := range values {
|
||||
if col == nil {
|
||||
row[columns[i]] = gvar.New(nil, false)
|
||||
} else {
|
||||
v := make([]byte, len(col))
|
||||
copy(v, col)
|
||||
row[columns[i]] = gvar.New(v, false)
|
||||
}
|
||||
}
|
||||
records = append(records, row)
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// 格式化SQL查询条件
|
||||
func formatCondition(where interface{}, args []interface{}) (string, []interface{}) {
|
||||
// 条件字符串处理
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
if reflect.ValueOf(where).Kind() == reflect.Map {
|
||||
ks := reflect.ValueOf(where).MapKeys()
|
||||
vs := reflect.ValueOf(where)
|
||||
for _, k := range ks {
|
||||
key := gconv.String(k.Interface())
|
||||
value := gconv.String(vs.MapIndex(k).Interface())
|
||||
if buffer.Len() > 0 {
|
||||
buffer.WriteString(" AND ")
|
||||
}
|
||||
if gstr.IsNumeric(value) || value == "?" {
|
||||
buffer.WriteString(key + "=" + value)
|
||||
} else {
|
||||
buffer.WriteString(key + "='" + value + "'")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
buffer.Write(gconv.Bytes(where))
|
||||
}
|
||||
if buffer.Len() == 0 {
|
||||
buffer.WriteString("1")
|
||||
}
|
||||
// 查询条件处理
|
||||
newWhere := buffer.String()
|
||||
newArgs := make([]interface{}, 0)
|
||||
if len(args) > 0 {
|
||||
for index, arg := range args {
|
||||
rv := reflect.ValueOf(arg)
|
||||
kind := rv.Kind()
|
||||
if kind == reflect.Ptr {
|
||||
rv = rv.Elem()
|
||||
kind = rv.Kind()
|
||||
}
|
||||
switch kind {
|
||||
case reflect.Slice: fallthrough
|
||||
case reflect.Array:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
newArgs = append(newArgs, rv.Index(i).Interface())
|
||||
}
|
||||
counter := 0
|
||||
newWhere, _ = gregex.ReplaceStringFunc(`\?`, newWhere, func(s string) string {
|
||||
counter++
|
||||
if counter == index + 1 {
|
||||
return "?" + strings.Repeat(",?", rv.Len() - 1)
|
||||
}
|
||||
return s
|
||||
})
|
||||
default:
|
||||
newArgs = append(newArgs, arg)
|
||||
}
|
||||
}
|
||||
}
|
||||
return newWhere, newArgs
|
||||
}
|
||||
|
||||
// 打印SQL对象(仅在debug=true时有效)
|
||||
func printSql(v *Sql) {
|
||||
s := fmt.Sprintf("%s, %v, %s, %s, %d ms, %s", v.Sql, v.Args,
|
||||
gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u"),
|
||||
gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u"),
|
||||
v.End - v.Start,
|
||||
v.Func,
|
||||
)
|
||||
if v.Error != nil {
|
||||
s += "\nError: " + v.Error.Error()
|
||||
glog.Backtrace(true, 2).Error(s)
|
||||
} else {
|
||||
glog.Debug(s)
|
||||
}
|
||||
}
|
||||
|
||||
// 格式化错误信息
|
||||
func formatError(err error, query string, args ...interface{}) error {
|
||||
if err != nil {
|
||||
errstr := fmt.Sprintf("DB ERROR: %s\n", err.Error())
|
||||
errstr += fmt.Sprintf("DB QUERY: %s\n", query)
|
||||
if len(args) > 0 {
|
||||
errstr += fmt.Sprintf("DB PARAM: %v\n", args)
|
||||
}
|
||||
err = errors.New(errstr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 根据insert选项获得操作名称
|
||||
func getInsertOperationByOption(option int) string {
|
||||
oper := "INSERT"
|
||||
switch option {
|
||||
case OPTION_REPLACE:
|
||||
oper = "REPLACE"
|
||||
case OPTION_SAVE:
|
||||
case OPTION_IGNORE:
|
||||
oper = "INSERT IGNORE"
|
||||
}
|
||||
return oper
|
||||
}
|
||||
@ -12,12 +12,15 @@ import (
|
||||
"database/sql"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 数据库链式操作模型对象
|
||||
type Model struct {
|
||||
tx *Tx // 数据库事务对象
|
||||
db *Db // 数据库操作对象
|
||||
db DB // 数据库操作对象
|
||||
tx *TX // 数据库事务对象
|
||||
tablesInit string // 初始化Model时的表名称(可以是多个)
|
||||
tables string // 数据库操作表
|
||||
fields string // 操作字段
|
||||
where string // 操作条件
|
||||
@ -28,123 +31,213 @@ type Model struct {
|
||||
limit int // 分页条数
|
||||
data interface{} // 操作记录(支持Map/List/string类型)
|
||||
batch int // 批量操作条数
|
||||
filter bool // 是否按照表字段过滤data参数
|
||||
cacheEnabled bool // 当前SQL操作是否开启查询缓存功能
|
||||
cacheTime int // 查询缓存时间
|
||||
cacheName string // 查询缓存名称
|
||||
}
|
||||
|
||||
// 链式操作,数据表字段,可支持多个表,以半角逗号连接
|
||||
func (db *Db) Table(tables string) (*Model) {
|
||||
return &Model{
|
||||
db: db,
|
||||
tables: tables,
|
||||
fields: "*",
|
||||
func (bs *dbBase) Table(tables string) (*Model) {
|
||||
return &Model {
|
||||
db : bs.db,
|
||||
tablesInit : tables,
|
||||
tables : tables,
|
||||
fields : "*",
|
||||
}
|
||||
}
|
||||
|
||||
// 链式操作,数据表字段,可支持多个表,以半角逗号连接
|
||||
func (db *Db) From(tables string) (*Model) {
|
||||
return db.Table(tables)
|
||||
func (bs *dbBase) From(tables string) (*Model) {
|
||||
return bs.db.Table(tables)
|
||||
}
|
||||
|
||||
// (事务)链式操作,数据表字段,可支持多个表,以半角逗号连接
|
||||
func (tx *Tx) Table(tables string) (*Model) {
|
||||
func (tx *TX) Table(tables string) (*Model) {
|
||||
return &Model{
|
||||
db: tx.db,
|
||||
tx: tx,
|
||||
tables: tables,
|
||||
db : tx.db,
|
||||
tx : tx,
|
||||
tablesInit : tables,
|
||||
tables : tables,
|
||||
}
|
||||
}
|
||||
|
||||
// (事务)链式操作,数据表字段,可支持多个表,以半角逗号连接
|
||||
func (tx *Tx) From(tables string) (*Model) {
|
||||
func (tx *TX) From(tables string) (*Model) {
|
||||
return tx.Table(tables)
|
||||
}
|
||||
|
||||
// 克隆一个当前对象
|
||||
func (md *Model) Clone() *Model {
|
||||
newModel := (*Model)(nil)
|
||||
if md.tx != nil {
|
||||
newModel = md.tx.Table(md.tablesInit)
|
||||
} else {
|
||||
newModel = md.db.Table(md.tablesInit)
|
||||
}
|
||||
*newModel = *md
|
||||
return newModel
|
||||
}
|
||||
|
||||
// 链式操作,左联表
|
||||
func (md *Model) LeftJoin(joinTable string, on string) (*Model) {
|
||||
md.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", joinTable, on)
|
||||
return md
|
||||
model := md.Clone()
|
||||
model.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", joinTable, on)
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,右联表
|
||||
func (md *Model) RightJoin(joinTable string, on string) (*Model) {
|
||||
md.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", joinTable, on)
|
||||
return md
|
||||
model := md.Clone()
|
||||
model.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", joinTable, on)
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,内联表
|
||||
func (md *Model) InnerJoin(joinTable string, on string) (*Model) {
|
||||
md.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", joinTable, on)
|
||||
return md
|
||||
model := md.Clone()
|
||||
model.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", joinTable, on)
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,查询字段
|
||||
func (md *Model) Fields(fields string) (*Model) {
|
||||
md.fields = fields
|
||||
return md
|
||||
model := md.Clone()
|
||||
model.fields = fields
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,过滤字段
|
||||
func (md *Model) Filter() (*Model) {
|
||||
model := md.Clone()
|
||||
model.filter = true
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,condition,支持string & gdb.Map
|
||||
func (md *Model) Where(where interface{}, args ...interface{}) (*Model) {
|
||||
md.where = md.db.formatCondition(where)
|
||||
md.whereArgs = append(md.whereArgs, args...)
|
||||
return md
|
||||
model := md.Clone()
|
||||
newWhere, newArgs := formatCondition(where, args)
|
||||
model.where = newWhere
|
||||
model.whereArgs = append(model.whereArgs, newArgs...)
|
||||
// 支持 Where("uid", 1)这种格式
|
||||
if len(args) == 1 && strings.Index(model.where , "?") < 0 {
|
||||
model.where += "=?"
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,添加AND条件到Where中
|
||||
func (md *Model) And(where interface{}, args ...interface{}) (*Model) {
|
||||
md.where += " AND " + md.db.formatCondition(where)
|
||||
md.whereArgs = append(md.whereArgs, args...)
|
||||
return md
|
||||
model := md.Clone()
|
||||
newWhere, newArgs := formatCondition(where, args)
|
||||
model.where += " AND " + newWhere
|
||||
model.whereArgs = append(model.whereArgs, newArgs...)
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,添加OR条件到Where中
|
||||
func (md *Model) Or(where interface{}, args ...interface{}) (*Model) {
|
||||
md.where += " OR " + md.db.formatCondition(where)
|
||||
md.whereArgs = append(md.whereArgs, args...)
|
||||
return md
|
||||
model := md.Clone()
|
||||
newWhere, newArgs := formatCondition(where, args)
|
||||
model.where += " OR " + newWhere
|
||||
model.whereArgs = append(model.whereArgs, newArgs...)
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,group by
|
||||
func (md *Model) GroupBy(groupBy string) (*Model) {
|
||||
md.groupBy = groupBy
|
||||
return md
|
||||
model := md.Clone()
|
||||
model.groupBy = groupBy
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,order by
|
||||
func (md *Model) OrderBy(orderBy string) (*Model) {
|
||||
md.orderBy = orderBy
|
||||
return md
|
||||
model := md.Clone()
|
||||
model.orderBy = orderBy
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,limit
|
||||
func (md *Model) Limit(start int, limit int) (*Model) {
|
||||
md.start = start
|
||||
md.limit = limit
|
||||
return md
|
||||
model := md.Clone()
|
||||
model.start = start
|
||||
model.limit = limit
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,翻页
|
||||
// @author ymrjqyy
|
||||
func (md *Model) ForPage(page, limit int) (*Model) {
|
||||
md.start = (page - 1) * limit
|
||||
md.limit = limit
|
||||
return md
|
||||
model := md.Clone()
|
||||
model.start = (page - 1) * limit
|
||||
model.limit = limit
|
||||
return model
|
||||
}
|
||||
|
||||
// 设置批处理的大小
|
||||
func (md *Model) Batch(batch int) *Model {
|
||||
model := md.Clone()
|
||||
model.batch = batch
|
||||
return model
|
||||
}
|
||||
|
||||
// 查询缓存/清除缓存操作,需要注意的是,事务查询不支持缓存。
|
||||
// 当time < 0时表示清除缓存, time=0时表示不过期, time > 0时表示过期时间,time过期时间单位:秒;
|
||||
// name表示自定义的缓存名称,便于业务层精准定位缓存项(如果业务层需要手动清理时,必须指定缓存名称),
|
||||
// 例如:查询缓存时设置名称,清理缓存时可以给定清理的缓存名称进行精准清理。
|
||||
func (md *Model) Cache(time int, name ... string) *Model {
|
||||
model := md.Clone()
|
||||
model.cacheTime = time
|
||||
if len(name) > 0 {
|
||||
model.cacheName = name[0]
|
||||
}
|
||||
// 查询缓存特性不支持事务操作
|
||||
if model.tx == nil {
|
||||
model.cacheEnabled = true
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,操作数据记录项,可以是string/Map, 也可以是:key,value,key,value,...
|
||||
func (md *Model) Data(data ...interface{}) (*Model) {
|
||||
model := md.Clone()
|
||||
if len(data) > 1 {
|
||||
m := make(map[string]interface{})
|
||||
for i := 0; i < len(data); i += 2 {
|
||||
m[gconv.String(data[i])] = data[i+1]
|
||||
}
|
||||
md.data = m
|
||||
model.data = m
|
||||
} else {
|
||||
md.data = data[0]
|
||||
switch data[0].(type) {
|
||||
case List:
|
||||
model.data = data[0]
|
||||
case Map:
|
||||
model.data = data[0]
|
||||
default:
|
||||
rv := reflect.ValueOf(data[0])
|
||||
kind := rv.Kind()
|
||||
if kind == reflect.Ptr {
|
||||
rv = rv.Elem()
|
||||
kind = rv.Kind()
|
||||
}
|
||||
switch kind {
|
||||
case reflect.Slice: fallthrough
|
||||
case reflect.Array:
|
||||
list := make(List, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
list[i] = gconv.Map(rv.Index(i).Interface())
|
||||
}
|
||||
model.data = list
|
||||
case reflect.Map:
|
||||
model.data = gconv.Map(data[0])
|
||||
default:
|
||||
model.data = data[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
return md
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作, CURD - Insert/BatchInsert
|
||||
@ -163,16 +256,24 @@ func (md *Model) Insert() (result sql.Result, err error) {
|
||||
if md.batch > 0 {
|
||||
batch = md.batch
|
||||
}
|
||||
if md.filter {
|
||||
for k, m := range list {
|
||||
list[k] = md.db.filterFields(md.tables, m)
|
||||
}
|
||||
}
|
||||
if md.tx == nil {
|
||||
return md.db.BatchInsert(md.tables, list, batch)
|
||||
} else {
|
||||
return md.tx.BatchInsert(md.tables, list, batch)
|
||||
}
|
||||
} else if dataMap, ok := md.data.(Map); ok {
|
||||
} else if data, ok := md.data.(Map); ok {
|
||||
if md.filter {
|
||||
data = md.db.filterFields(md.tables, data)
|
||||
}
|
||||
if md.tx == nil {
|
||||
return md.db.Insert(md.tables, dataMap)
|
||||
return md.db.Insert(md.tables, data)
|
||||
} else {
|
||||
return md.tx.Insert(md.tables, dataMap)
|
||||
return md.tx.Insert(md.tables, data)
|
||||
}
|
||||
}
|
||||
return nil, errors.New("inserting into table with invalid data type")
|
||||
@ -194,16 +295,24 @@ func (md *Model) Replace() (result sql.Result, err error) {
|
||||
if md.batch > 0 {
|
||||
batch = md.batch
|
||||
}
|
||||
if md.filter {
|
||||
for k, m := range list {
|
||||
list[k] = md.db.filterFields(md.tables, m)
|
||||
}
|
||||
}
|
||||
if md.tx == nil {
|
||||
return md.db.BatchReplace(md.tables, list, batch)
|
||||
} else {
|
||||
return md.tx.BatchReplace(md.tables, list, batch)
|
||||
}
|
||||
} else if dataMap, ok := md.data.(Map); ok {
|
||||
} else if data, ok := md.data.(Map); ok {
|
||||
if md.filter {
|
||||
data = md.db.filterFields(md.tables, data)
|
||||
}
|
||||
if md.tx == nil {
|
||||
return md.db.Insert(md.tables, dataMap)
|
||||
return md.db.Replace(md.tables, data)
|
||||
} else {
|
||||
return md.tx.Insert(md.tables, dataMap)
|
||||
return md.tx.Replace(md.tables, data)
|
||||
}
|
||||
}
|
||||
return nil, errors.New("replacing into table with invalid data type")
|
||||
@ -225,16 +334,24 @@ func (md *Model) Save() (result sql.Result, err error) {
|
||||
if md.batch > 0 {
|
||||
batch = md.batch
|
||||
}
|
||||
if md.filter {
|
||||
for k, m := range list {
|
||||
list[k] = md.db.filterFields(md.tables, m)
|
||||
}
|
||||
}
|
||||
if md.tx == nil {
|
||||
return md.db.BatchSave(md.tables, list, batch)
|
||||
} else {
|
||||
return md.tx.BatchSave(md.tables, list, batch)
|
||||
}
|
||||
} else if dataMap, ok := md.data.(Map); ok {
|
||||
} else if data, ok := md.data.(Map); ok {
|
||||
if md.filter {
|
||||
data = md.db.filterFields(md.tables, data)
|
||||
}
|
||||
if md.tx == nil {
|
||||
return md.db.Save(md.tables, dataMap)
|
||||
return md.db.Save(md.tables, data)
|
||||
} else {
|
||||
return md.tx.Save(md.tables, dataMap)
|
||||
return md.tx.Save(md.tables, data)
|
||||
}
|
||||
}
|
||||
return nil, errors.New("saving into table with invalid data type")
|
||||
@ -250,6 +367,13 @@ func (md *Model) Update() (result sql.Result, err error) {
|
||||
if md.data == nil {
|
||||
return nil, errors.New("updating table with empty data")
|
||||
}
|
||||
if md.filter {
|
||||
if data, ok := md.data.(Map); ok {
|
||||
if md.filter {
|
||||
md.data = md.db.filterFields(md.tables, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
if md.tx == nil {
|
||||
return md.db.Update(md.tables, md.data, md.where, md.whereArgs ...)
|
||||
} else {
|
||||
@ -264,9 +388,6 @@ func (md *Model) Delete() (result sql.Result, err error) {
|
||||
md.checkAndRemoveCache()
|
||||
}
|
||||
}()
|
||||
if md.where == "" {
|
||||
return nil, errors.New("where is required while deleting")
|
||||
}
|
||||
if md.tx == nil {
|
||||
return md.db.Delete(md.tables, md.where, md.whereArgs...)
|
||||
} else {
|
||||
@ -274,36 +395,14 @@ func (md *Model) Delete() (result sql.Result, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 设置批处理的大小
|
||||
func (md *Model) Batch(batch int) *Model {
|
||||
md.batch = batch
|
||||
return md
|
||||
}
|
||||
|
||||
// 查询缓存/清除缓存操作,需要注意的是,事务查询不支持缓存。
|
||||
// 当time < 0时表示清除缓存, time=0时表示不过期, time > 0时表示过期时间,time过期时间单位:秒;
|
||||
// name表示自定义的缓存名称,便于业务层精准定位缓存项(如果业务层需要手动清理时,必须指定缓存名称),
|
||||
// 例如:查询缓存时设置名称,清理缓存时可以给定清理的缓存名称进行精准清理。
|
||||
func (md *Model) Cache(time int, name ... string) *Model {
|
||||
md.cacheTime = time
|
||||
if len(name) > 0 {
|
||||
md.cacheName = name[0]
|
||||
}
|
||||
// 查询缓存特性不支持事务操作
|
||||
if md.tx == nil {
|
||||
md.cacheEnabled = true
|
||||
}
|
||||
return md
|
||||
}
|
||||
|
||||
// 链式操作,select
|
||||
func (md *Model) Select() (Result, error) {
|
||||
return md.getAll(md.getFormattedSql(), md.whereArgs...)
|
||||
return md.All()
|
||||
}
|
||||
|
||||
// 链式操作,查询所有记录
|
||||
func (md *Model) All() (Result, error) {
|
||||
return md.Select()
|
||||
return md.getAll(md.getFormattedSql(), md.whereArgs...)
|
||||
}
|
||||
|
||||
// 链式操作,查询单条记录
|
||||
@ -342,6 +441,9 @@ func (md *Model) Struct(obj interface{}) error {
|
||||
// 链式操作,查询数量,fields可以为空,也可以自定义查询字段,
|
||||
// 当给定自定义查询字段时,该字段必须为数量结果,否则会引起歧义,使用如:md.Fields("COUNT(id)")
|
||||
func (md *Model) Count() (int, error) {
|
||||
defer func(fields string) {
|
||||
md.fields = fields
|
||||
}(md.fields)
|
||||
if md.fields == "" || md.fields == "*" {
|
||||
md.fields = "COUNT(1)"
|
||||
} else {
|
||||
@ -364,29 +466,30 @@ func (md *Model) Count() (int, error) {
|
||||
}
|
||||
|
||||
// 查询操作,对底层SQL操作的封装
|
||||
func (md *Model) getAll(sql string, args ...interface{}) (result Result, err error) {
|
||||
var cacheKey string
|
||||
func (md *Model) getAll(query string, args ...interface{}) (result Result, err error) {
|
||||
cacheKey := ""
|
||||
// 查询缓存查询处理
|
||||
if md.cacheEnabled {
|
||||
cacheKey = md.cacheName
|
||||
if len(cacheKey) == 0 {
|
||||
cacheKey = sql + "/" + gconv.String(args)
|
||||
cacheKey = query + "/" + gconv.String(args)
|
||||
}
|
||||
if v := md.db.cache.Get(cacheKey); v != nil {
|
||||
if v := md.db.getCache().Get(cacheKey); v != nil {
|
||||
return v.(Result), nil
|
||||
}
|
||||
}
|
||||
|
||||
if md.tx == nil {
|
||||
result, err = md.db.GetAll(sql, args...)
|
||||
result, err = md.db.GetAll(query, args...)
|
||||
} else {
|
||||
result, err = md.tx.GetAll(sql, args...)
|
||||
result, err = md.tx.GetAll(query, args...)
|
||||
}
|
||||
// 查询缓存保存处理
|
||||
if len(cacheKey) > 0 && err == nil {
|
||||
if md.cacheTime < 0 {
|
||||
md.db.cache.Remove(cacheKey)
|
||||
md.db.getCache().Remove(cacheKey)
|
||||
} else {
|
||||
md.db.cache.Set(cacheKey, result, md.cacheTime*1000)
|
||||
md.db.getCache().Set(cacheKey, result, md.cacheTime*1000)
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
@ -395,7 +498,7 @@ func (md *Model) getAll(sql string, args ...interface{}) (result Result, err err
|
||||
// 检查是否需要查询查询缓存
|
||||
func (md *Model) checkAndRemoveCache() {
|
||||
if md.cacheEnabled && md.cacheTime < 0 && len(md.cacheName) > 0 {
|
||||
md.db.cache.Remove(md.cacheName)
|
||||
md.db.getCache().Remove(md.cacheName)
|
||||
}
|
||||
}
|
||||
|
||||
@ -424,11 +527,10 @@ func (md *Model) getFormattedSql() string {
|
||||
// @author ymrjqyy
|
||||
// @author 2018-08-15
|
||||
func (md *Model) Chunk(limit int, callback func(result Result, err error) bool) {
|
||||
var page = 1
|
||||
page := 1
|
||||
for {
|
||||
md.ForPage(page, limit)
|
||||
sqls := md.getFormattedSql()
|
||||
data, err := md.getAll(sqls, md.whereArgs...)
|
||||
data, err := md.getAll(md.getFormattedSql(), md.whereArgs...)
|
||||
if err != nil {
|
||||
callback(nil, err)
|
||||
break
|
||||
|
||||
@ -22,20 +22,19 @@ import (
|
||||
)
|
||||
|
||||
|
||||
var linkMssql = &dbmssql{}
|
||||
|
||||
// 数据库链接对象
|
||||
type dbmssql struct {
|
||||
Db
|
||||
type dbMssql struct {
|
||||
*dbBase
|
||||
}
|
||||
|
||||
// 创建SQL操作对象
|
||||
func (db *dbmssql) Open(c *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if c.Linkinfo != "" {
|
||||
source = c.Linkinfo
|
||||
func (db *dbMssql) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
source := ""
|
||||
if config.Linkinfo != "" {
|
||||
source = config.Linkinfo
|
||||
} else {
|
||||
source = fmt.Sprintf("user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable", c.User, c.Pass, c.Host, c.Port, c.Name)
|
||||
source = fmt.Sprintf("user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable",
|
||||
config.User, config.Pass, config.Host, config.Port, config.Name)
|
||||
}
|
||||
if db, err := sql.Open("sqlserver", source); err == nil {
|
||||
return db, nil
|
||||
@ -44,43 +43,38 @@ func (db *dbmssql) Open(c *ConfigNode) (*sql.DB, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 获得关键字操作符 - 左
|
||||
func (db *dbmssql) getQuoteCharLeft() string {
|
||||
return "\""
|
||||
}
|
||||
|
||||
// 获得关键字操作符 - 右
|
||||
func (db *dbmssql) getQuoteCharRight() string {
|
||||
return "\""
|
||||
// 获得关键字操作符
|
||||
func (db *dbMssql) getChars () (charLeft string, charRight string) {
|
||||
return "\"", "\""
|
||||
}
|
||||
|
||||
// 在执行sql之前对sql进行进一步处理
|
||||
func (db *dbmssql) handleSqlBeforeExec(q *string) *string {
|
||||
func (db *dbMssql) handleSqlBeforeExec(query string) string {
|
||||
index := 0
|
||||
str, _ := gregex.ReplaceStringFunc("\\?", *q, func(s string) string {
|
||||
str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string {
|
||||
index++
|
||||
return fmt.Sprintf("@p%d", index)
|
||||
})
|
||||
|
||||
str, _ = gregex.ReplaceString("\"", "", str)
|
||||
|
||||
return db.parseSql(&str)
|
||||
return db.parseSql(str)
|
||||
}
|
||||
|
||||
//将MYSQL的SQL语法转换为MSSQL的语法
|
||||
//1.由于mssql不支持limit写法所以需要对mysql中的limit用法做转换
|
||||
func (db *dbmssql) parseSql(sql *string) *string {
|
||||
func (db *dbMssql) parseSql(sql string) string {
|
||||
//下面的正则表达式匹配出SELECT和INSERT的关键字后分别做不同的处理,如有LIMIT则将LIMIT的关键字也匹配出
|
||||
patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))`
|
||||
if gregex.IsMatchString(patten, *sql) == false {
|
||||
if gregex.IsMatchString(patten, sql) == false {
|
||||
fmt.Println("not matched..")
|
||||
return sql
|
||||
}
|
||||
|
||||
res, err := gregex.MatchAllString(patten, *sql)
|
||||
res, err := gregex.MatchAllString(patten, sql)
|
||||
if err != nil {
|
||||
fmt.Println("MatchString error.", err)
|
||||
return nil
|
||||
return ""
|
||||
}
|
||||
|
||||
index := 0
|
||||
@ -96,17 +90,17 @@ func (db *dbmssql) parseSql(sql *string) *string {
|
||||
}
|
||||
|
||||
//不含LIMIT则不处理
|
||||
if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql) == false {
|
||||
if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) == false {
|
||||
break
|
||||
}
|
||||
|
||||
//判断SQL中是否含有order by
|
||||
selectStr := ""
|
||||
orderbyStr := ""
|
||||
haveOrderby := gregex.IsMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", *sql)
|
||||
haveOrderby := gregex.IsMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql)
|
||||
if haveOrderby {
|
||||
//取order by 前面的字符串
|
||||
queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)ORDER BY)", *sql)
|
||||
queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql)
|
||||
|
||||
if len(queryExpr) != 4 || strings.EqualFold(queryExpr[1], "SELECT") == false || strings.EqualFold(queryExpr[3], "ORDER BY") == false{
|
||||
break
|
||||
@ -114,13 +108,13 @@ func (db *dbmssql) parseSql(sql *string) *string {
|
||||
selectStr = queryExpr[2]
|
||||
|
||||
//取order by表达式的值
|
||||
orderbyExpr, _ := gregex.MatchString("((?i)ORDER BY)(.+)((?i)LIMIT)", *sql)
|
||||
orderbyExpr, _ := gregex.MatchString("((?i)ORDER BY)(.+)((?i)LIMIT)", sql)
|
||||
if len(orderbyExpr) != 4 || strings.EqualFold(orderbyExpr[1], "ORDER BY") == false || strings.EqualFold(orderbyExpr[3], "LIMIT") == false{
|
||||
break
|
||||
}
|
||||
orderbyStr = orderbyExpr[2]
|
||||
} else {
|
||||
queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql)
|
||||
queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql)
|
||||
if len(queryExpr) != 4 || strings.EqualFold(queryExpr[1], "SELECT") == false || strings.EqualFold(queryExpr[3], "LIMIT") == false{
|
||||
break
|
||||
}
|
||||
@ -144,14 +138,14 @@ func (db *dbmssql) parseSql(sql *string) *string {
|
||||
}
|
||||
|
||||
if haveOrderby {
|
||||
*sql = fmt.Sprintf("SELECT * FROM (SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d", orderbyStr, selectStr, first, limit)
|
||||
sql = fmt.Sprintf("SELECT * FROM (SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d", orderbyStr, selectStr, first, limit)
|
||||
} else {
|
||||
if first == 0 {
|
||||
first = limit
|
||||
} else {
|
||||
first = limit - first
|
||||
}
|
||||
*sql = fmt.Sprintf("SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ", first, limit, selectStr)
|
||||
sql = fmt.Sprintf("SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ", first, limit, selectStr)
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
@ -12,22 +12,19 @@ import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// MySQL接口对象
|
||||
var linkMysql = &dbmysql{}
|
||||
|
||||
|
||||
// 数据库链接对象
|
||||
type dbmysql struct {
|
||||
Db
|
||||
type dbMysql struct {
|
||||
*dbBase
|
||||
}
|
||||
|
||||
// 创建SQL操作对象,内部采用了lazy link处理
|
||||
func (db *dbmysql) Open (c *ConfigNode) (*sql.DB, error) {
|
||||
func (db *dbMysql) Open (config *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if c.Linkinfo != "" {
|
||||
source = c.Linkinfo
|
||||
if config.Linkinfo != "" {
|
||||
source = config.Linkinfo
|
||||
} else {
|
||||
source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", c.User, c.Pass, c.Host, c.Port, c.Name)
|
||||
source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&multiStatements=true",
|
||||
config.User, config.Pass, config.Host, config.Port, config.Name, config.Charset)
|
||||
}
|
||||
if db, err := sql.Open("mysql", source); err == nil {
|
||||
return db, nil
|
||||
@ -36,17 +33,12 @@ func (db *dbmysql) Open (c *ConfigNode) (*sql.DB, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 获得关键字操作符 - 左
|
||||
func (db *dbmysql) getQuoteCharLeft () string {
|
||||
return "`"
|
||||
}
|
||||
|
||||
// 获得关键字操作符 - 右
|
||||
func (db *dbmysql) getQuoteCharRight () string {
|
||||
return "`"
|
||||
// 获得关键字操作符
|
||||
func (db *dbMysql) getChars () (charLeft string, charRight string) {
|
||||
return "`", "`"
|
||||
}
|
||||
|
||||
// 在执行sql之前对sql进行进一步处理
|
||||
func (db *dbmysql) handleSqlBeforeExec(q *string) *string {
|
||||
return q
|
||||
func (db *dbMysql) handleSqlBeforeExec(query string) string {
|
||||
return query
|
||||
}
|
||||
@ -21,20 +21,18 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var linkOracle = &dboracle{}
|
||||
|
||||
// 数据库链接对象
|
||||
type dboracle struct {
|
||||
Db
|
||||
type dbOracle struct {
|
||||
*dbBase
|
||||
}
|
||||
|
||||
// 创建SQL操作对象
|
||||
func (db *dboracle) Open(c *ConfigNode) (*sql.DB, error) {
|
||||
func (db *dbOracle) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if c.Linkinfo != "" {
|
||||
source = c.Linkinfo
|
||||
if config.Linkinfo != "" {
|
||||
source = config.Linkinfo
|
||||
} else {
|
||||
source = fmt.Sprintf("%s/%s@%s", c.User, c.Pass, c.Name)
|
||||
source = fmt.Sprintf("%s/%s@%s", config.User, config.Pass, config.Name)
|
||||
}
|
||||
if db, err := sql.Open("oci8", source); err == nil {
|
||||
return db, nil
|
||||
@ -43,42 +41,37 @@ func (db *dboracle) Open(c *ConfigNode) (*sql.DB, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 获得关键字操作符 - 左
|
||||
func (db *dboracle) getQuoteCharLeft() string {
|
||||
return "\""
|
||||
}
|
||||
|
||||
// 获得关键字操作符 - 右
|
||||
func (db *dboracle) getQuoteCharRight() string {
|
||||
return "\""
|
||||
// 获得关键字操作符
|
||||
func (db *dbOracle) getChars () (charLeft string, charRight string) {
|
||||
return "\"", "\""
|
||||
}
|
||||
|
||||
// 在执行sql之前对sql进行进一步处理
|
||||
func (db *dboracle) handleSqlBeforeExec(q *string) *string {
|
||||
func (db *dbOracle) handleSqlBeforeExec(query string) string {
|
||||
index := 0
|
||||
str, _ := gregex.ReplaceStringFunc("\\?", *q, func(s string) string {
|
||||
str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string {
|
||||
index++
|
||||
return fmt.Sprintf(":%d", index)
|
||||
})
|
||||
|
||||
str, _ = gregex.ReplaceString("\"", "", str)
|
||||
|
||||
return db.parseSql(&str)
|
||||
return db.parseSql(str)
|
||||
}
|
||||
|
||||
//由于ORACLE中对LIMIT和批量插入的语法与MYSQL不一致,所以这里需要对LIMIT和批量插入做语法上的转换
|
||||
func (db *dboracle) parseSql(sql *string) *string {
|
||||
func (db *dbOracle) parseSql(sql string) string {
|
||||
//下面的正则表达式匹配出SELECT和INSERT的关键字后分别做不同的处理,如有LIMIT则将LIMIT的关键字也匹配出
|
||||
patten := `^\s*(?i)(SELECT)|(INSERT)|(LIMIT\s*(\d+)\s*,\s*(\d+))`
|
||||
if gregex.IsMatchString(patten, *sql) == false {
|
||||
if gregex.IsMatchString(patten, sql) == false {
|
||||
fmt.Println("not matched..")
|
||||
return sql
|
||||
}
|
||||
|
||||
res, err := gregex.MatchAllString(patten, *sql)
|
||||
res, err := gregex.MatchAllString(patten, sql)
|
||||
if err != nil {
|
||||
fmt.Println("MatchString error.", err)
|
||||
return nil
|
||||
return ""
|
||||
}
|
||||
|
||||
index := 0
|
||||
@ -94,11 +87,11 @@ func (db *dboracle) parseSql(sql *string) *string {
|
||||
}
|
||||
|
||||
//取limit前面的字符串
|
||||
if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql) == false {
|
||||
if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) == false {
|
||||
break
|
||||
}
|
||||
|
||||
queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql)
|
||||
queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql)
|
||||
if len(queryExpr) != 4 || strings.EqualFold(queryExpr[1], "SELECT") == false || strings.EqualFold(queryExpr[3], "LIMIT") == false{
|
||||
break
|
||||
}
|
||||
@ -118,10 +111,10 @@ func (db *dboracle) parseSql(sql *string) *string {
|
||||
}
|
||||
|
||||
//也可以使用between,据说这种写法的性能会比between好点,里层SQL中的ROWNUM_ >= limit可以缩小查询后的数据集规模
|
||||
*sql = fmt.Sprintf("SELECT * FROM (SELECT GFORM.*, ROWNUM ROWNUM_ FROM (%s %s) GFORM WHERE ROWNUM <= %d) WHERE ROWNUM_ >= %d", queryExpr[1], queryExpr[2], limit, first)
|
||||
sql = fmt.Sprintf("SELECT * FROM (SELECT GFORM.*, ROWNUM ROWNUM_ FROM (%s %s) GFORM WHERE ROWNUM <= %d) WHERE ROWNUM_ >= %d", queryExpr[1], queryExpr[2], limit, first)
|
||||
case "INSERT":
|
||||
//获取VALUE的值,匹配所有带括号的值,会将INSERT INTO后的值匹配到,所以下面的判断语句会判断数组长度是否小于3
|
||||
valueExpr, err := gregex.MatchAllString(`(\s*\(([^\(\)]*)\))`, *sql)
|
||||
valueExpr, err := gregex.MatchAllString(`(\s*\(([^\(\)]*)\))`, sql)
|
||||
if err != nil {
|
||||
return sql
|
||||
}
|
||||
@ -132,17 +125,17 @@ func (db *dboracle) parseSql(sql *string) *string {
|
||||
}
|
||||
|
||||
//获取INTO后面的值
|
||||
tableExpr, err := gregex.MatchString(`(?i)\s*(INTO\s+\w+\(([^\(\)]*)\))`, *sql)
|
||||
tableExpr, err := gregex.MatchString(`(?i)\s*(INTO\s+\w+\(([^\(\)]*)\))`, sql)
|
||||
if err != nil {
|
||||
return sql
|
||||
}
|
||||
tableExpr[0] = strings.TrimSpace(tableExpr[0])
|
||||
|
||||
*sql = "INSERT ALL"
|
||||
sql = "INSERT ALL"
|
||||
for i := 1; i < len(valueExpr); i++ {
|
||||
*sql += fmt.Sprintf(" %s VALUES%s", tableExpr[0], strings.TrimSpace(valueExpr[i][0]))
|
||||
sql += fmt.Sprintf(" %s VALUES%s", tableExpr[0], strings.TrimSpace(valueExpr[i][0]))
|
||||
}
|
||||
*sql += " SELECT 1 FROM DUAL"
|
||||
sql += " SELECT 1 FROM DUAL"
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
@ -18,22 +18,18 @@ import (
|
||||
// _ "gitee.com/johng/gf/third/github.com/lib/pq"
|
||||
// @todo 需要完善replace和save的操作覆盖
|
||||
|
||||
// PostgreSQL接口对象
|
||||
var linkPgsql = &dbpgsql{}
|
||||
|
||||
|
||||
// 数据库链接对象
|
||||
type dbpgsql struct {
|
||||
Db
|
||||
type dbPgsql struct {
|
||||
*dbBase
|
||||
}
|
||||
|
||||
// 创建SQL操作对象,内部采用了lazy link处理
|
||||
func (db *dbpgsql) Open (c *ConfigNode) (*sql.DB, error) {
|
||||
func (db *dbPgsql) Open (config *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if c.Linkinfo != "" {
|
||||
source = c.Linkinfo
|
||||
if config.Linkinfo != "" {
|
||||
source = config.Linkinfo
|
||||
} else {
|
||||
source = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s", c.User, c.Pass, c.Host, c.Port, c.Name)
|
||||
source = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s", config.User, config.Pass, config.Host, config.Port, config.Name)
|
||||
}
|
||||
if db, err := sql.Open("postgres", source); err == nil {
|
||||
return db, nil
|
||||
@ -42,23 +38,18 @@ func (db *dbpgsql) Open (c *ConfigNode) (*sql.DB, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 获得关键字操作符 - 左
|
||||
func (db *dbpgsql) getQuoteCharLeft () string {
|
||||
return "\""
|
||||
}
|
||||
|
||||
// 获得关键字操作符 - 右
|
||||
func (db *dbpgsql) getQuoteCharRight () string {
|
||||
return "\""
|
||||
// 获得关键字操作符
|
||||
func (db *dbPgsql) getChars () (charLeft string, charRight string) {
|
||||
return "\"", "\""
|
||||
}
|
||||
|
||||
// 在执行sql之前对sql进行进一步处理
|
||||
func (db *dbpgsql) handleSqlBeforeExec(q *string) *string {
|
||||
func (db *dbPgsql) handleSqlBeforeExec(query string) string {
|
||||
reg := regexp.MustCompile("\\?")
|
||||
index := 0
|
||||
str := reg.ReplaceAllStringFunc(*q, func (s string) string {
|
||||
str := reg.ReplaceAllStringFunc(query, func (s string) string {
|
||||
index ++
|
||||
return fmt.Sprintf("$%d", index)
|
||||
})
|
||||
return &str
|
||||
return str
|
||||
}
|
||||
@ -16,20 +16,18 @@ import (
|
||||
|
||||
// Sqlite接口对象
|
||||
// @author wxkj<wxscz@qq.com>
|
||||
var linkSqlite = &dbsqlite{}
|
||||
|
||||
|
||||
// 数据库链接对象
|
||||
type dbsqlite struct {
|
||||
Db
|
||||
type dbSqlite struct {
|
||||
*dbBase
|
||||
}
|
||||
|
||||
func (db *dbsqlite) Open(c *ConfigNode) (*sql.DB, error) {
|
||||
func (db *dbSqlite) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if c.Linkinfo != "" {
|
||||
source = c.Linkinfo
|
||||
if config.Linkinfo != "" {
|
||||
source = config.Linkinfo
|
||||
} else {
|
||||
source = c.Name
|
||||
source = config.Name
|
||||
}
|
||||
if db, err := sql.Open("sqlite3", source); err == nil {
|
||||
return db, nil
|
||||
@ -38,20 +36,14 @@ func (db *dbsqlite) Open(c *ConfigNode) (*sql.DB, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 获得关键字操作符 - 左
|
||||
func (db *dbsqlite) getQuoteCharLeft() string {
|
||||
return "`"
|
||||
}
|
||||
|
||||
// 获得关键字操作符 - 右
|
||||
func (db *dbsqlite) getQuoteCharRight() string {
|
||||
return "`"
|
||||
// 获得关键字操作符
|
||||
func (db *dbSqlite) getChars () (charLeft string, charRight string) {
|
||||
return "`", "`"
|
||||
}
|
||||
|
||||
// 在执行sql之前对sql进行进一步处理
|
||||
// @todo 需要增加对Save方法的支持,可使用正则来实现替换,
|
||||
// @todo 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE)
|
||||
func (db *dbsqlite) handleSqlBeforeExec(q *string) *string {
|
||||
|
||||
return q
|
||||
func (db *dbSqlite) handleSqlBeforeExec(query string) string {
|
||||
return query
|
||||
}
|
||||
@ -7,128 +7,55 @@
|
||||
package gdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"errors"
|
||||
"strings"
|
||||
"reflect"
|
||||
"database/sql"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"gitee.com/johng/gf/g/util/gregex"
|
||||
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
|
||||
"gitee.com/johng/gf/g/container/gvar"
|
||||
)
|
||||
|
||||
// 数据库事务对象
|
||||
type Tx struct {
|
||||
db *Db
|
||||
type TX struct {
|
||||
db DB
|
||||
tx *sql.Tx
|
||||
master *sql.DB
|
||||
}
|
||||
|
||||
// 事务操作,提交
|
||||
func (tx *Tx) Commit() error {
|
||||
func (tx *TX) Commit() error {
|
||||
return tx.tx.Commit()
|
||||
}
|
||||
|
||||
// 事务操作,回滚
|
||||
func (tx *Tx) Rollback() error {
|
||||
func (tx *TX) Rollback() error {
|
||||
return tx.tx.Rollback()
|
||||
}
|
||||
|
||||
// (事务)数据库sql查询操作,主要执行查询
|
||||
func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
var err error
|
||||
var rows *sql.Rows
|
||||
p := tx.db.link.handleSqlBeforeExec(&query)
|
||||
if tx.db.debug.Val() {
|
||||
militime1 := gtime.Millisecond()
|
||||
rows, err = tx.tx.Query(*p, args ...)
|
||||
militime2 := gtime.Millisecond()
|
||||
s := &Sql{
|
||||
Sql : *p,
|
||||
Args : args,
|
||||
Error : err,
|
||||
Start : militime1,
|
||||
End : militime2,
|
||||
Func : "TX:Query",
|
||||
}
|
||||
tx.db.sqls.Put(s)
|
||||
tx.db.printSql(s)
|
||||
} else {
|
||||
rows, err = tx.tx.Query(*p, args ...)
|
||||
}
|
||||
if err == nil {
|
||||
return rows, nil
|
||||
} else {
|
||||
err = tx.db.formatError(err, p, args...)
|
||||
}
|
||||
return nil, err
|
||||
func (tx *TX) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
return tx.db.doQuery(tx.tx, query, args...)
|
||||
}
|
||||
|
||||
// (事务)执行一条sql,并返回执行情况,主要用于非查询操作
|
||||
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
var err error
|
||||
var result sql.Result
|
||||
p := tx.db.link.handleSqlBeforeExec(&query)
|
||||
if tx.db.debug.Val() {
|
||||
militime1 := gtime.Millisecond()
|
||||
result, err = tx.tx.Exec(*p, args ...)
|
||||
militime2 := gtime.Millisecond()
|
||||
s := &Sql{
|
||||
Sql : *p,
|
||||
Args : args,
|
||||
Error : err,
|
||||
Start : militime1,
|
||||
End : militime2,
|
||||
Func : "TX:Exec",
|
||||
}
|
||||
tx.db.sqls.Put(s)
|
||||
tx.db.printSql(s)
|
||||
} else {
|
||||
result, err = tx.tx.Exec(*p, args ...)
|
||||
}
|
||||
return result, tx.db.formatError(err, p, args...)
|
||||
func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return tx.db.doExec(tx.tx, query, args...)
|
||||
}
|
||||
|
||||
// sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作
|
||||
func (tx *TX) Prepare(query string) (*sql.Stmt, error) {
|
||||
return tx.db.doPrepare(tx.tx, query)
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询结果集,以列表结构返回
|
||||
func (tx *Tx) GetAll(query string, args ...interface{}) (Result, error) {
|
||||
// 执行sql
|
||||
func (tx *TX) GetAll(query string, args ...interface{}) (Result, error) {
|
||||
rows, err := tx.Query(query, args ...)
|
||||
if err != nil || rows == nil {
|
||||
return nil, err
|
||||
}
|
||||
// 列名称列表
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 返回结构组装
|
||||
values := make([]sql.RawBytes, len(columns))
|
||||
scanArgs := make([]interface{}, len(values))
|
||||
records := make(Result, 0)
|
||||
for i := range values {
|
||||
scanArgs[i] = &values[i]
|
||||
}
|
||||
for rows.Next() {
|
||||
err = rows.Scan(scanArgs...)
|
||||
if err != nil {
|
||||
return records, err
|
||||
}
|
||||
row := make(Record)
|
||||
// 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址
|
||||
for i, col := range values {
|
||||
v := make([]byte, len(col))
|
||||
copy(v, col)
|
||||
row[columns[i]] = gvar.New(v, false)
|
||||
}
|
||||
//fmt.Printf("%p\n", row["typeid"])
|
||||
records = append(records, row)
|
||||
}
|
||||
return records, nil
|
||||
defer rows.Close()
|
||||
return rowsToResult(rows)
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询结果记录,以关联数组结构返回
|
||||
func (tx *Tx) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
func (tx *TX) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
list, err := tx.GetAll(query, args ...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -140,7 +67,7 @@ func (tx *Tx) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询结果记录,自动映射数据到给定的struct对象中
|
||||
func (tx *Tx) GetStruct(obj interface{}, query string, args ...interface{}) error {
|
||||
func (tx *TX) GetStruct(obj interface{}, query string, args ...interface{}) error {
|
||||
one, err := tx.GetOne(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -148,9 +75,8 @@ func (tx *Tx) GetStruct(obj interface{}, query string, args ...interface{}) erro
|
||||
return one.ToStruct(obj)
|
||||
}
|
||||
|
||||
|
||||
// 数据库查询,获取查询字段值
|
||||
func (tx *Tx) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
one, err := tx.GetOne(query, args ...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -162,186 +88,55 @@ func (tx *Tx) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询数量
|
||||
func (tx *Tx) GetCount(query string, args ...interface{}) (int, error) {
|
||||
val, err := tx.GetValue(query, args ...)
|
||||
func (tx *TX) GetCount(query string, args ...interface{}) (int, error) {
|
||||
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
|
||||
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
|
||||
}
|
||||
value, err := tx.GetValue(query, args ...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return gconv.Int(val), nil
|
||||
}
|
||||
|
||||
// 数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作
|
||||
func (tx *Tx) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) {
|
||||
s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables)
|
||||
if condition != nil {
|
||||
s += fmt.Sprintf("WHERE %s ", tx.db.formatCondition(condition))
|
||||
}
|
||||
if len(groupBy) > 0 {
|
||||
s += fmt.Sprintf("GROUP BY %s ", groupBy)
|
||||
}
|
||||
if len(orderBy) > 0 {
|
||||
s += fmt.Sprintf("ORDER BY %s ", orderBy)
|
||||
}
|
||||
if limit > 0 {
|
||||
s += fmt.Sprintf("LIMIT %d,%d ", first, limit)
|
||||
}
|
||||
return tx.GetAll(s, args ... )
|
||||
}
|
||||
|
||||
// sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作
|
||||
func (tx *Tx) Prepare(query string) (*sql.Stmt, error) {
|
||||
return tx.tx.Prepare(query)
|
||||
}
|
||||
|
||||
// insert、replace, save, ignore操作
|
||||
// 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
|
||||
// 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
// 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
// 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做
|
||||
func (tx *Tx) insert(table string, data Map, option uint8) (sql.Result, error) {
|
||||
var keys []string
|
||||
var values []string
|
||||
var params []interface{}
|
||||
for k, v := range data {
|
||||
keys = append(keys, tx.db.charl + k + tx.db.charr)
|
||||
values = append(values, "?")
|
||||
params = append(params, v)
|
||||
}
|
||||
operation := tx.db.getInsertOperationByOption(option)
|
||||
updatestr := ""
|
||||
if option == OPTION_SAVE {
|
||||
var updates []string
|
||||
for k, _ := range data {
|
||||
updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k))
|
||||
}
|
||||
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
|
||||
}
|
||||
return tx.Exec(
|
||||
fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
|
||||
operation, table, strings.Join(keys, ","),
|
||||
strings.Join(values, ","),
|
||||
updatestr),
|
||||
params...
|
||||
)
|
||||
return value.Int(), nil
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
|
||||
func (tx *Tx) Insert(table string, data Map) (sql.Result, error) {
|
||||
return tx.insert(table, data, OPTION_INSERT)
|
||||
func (tx *TX) Insert(table string, data Map) (sql.Result, error) {
|
||||
return tx.db.doInsert(tx.tx, table, data, OPTION_INSERT)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
func (tx *Tx) Replace(table string, data Map) (sql.Result, error) {
|
||||
return tx.insert(table, data, OPTION_REPLACE)
|
||||
func (tx *TX) Replace(table string, data Map) (sql.Result, error) {
|
||||
return tx.db.doInsert(tx.tx, table, data, OPTION_REPLACE)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
func (tx *Tx) Save(table string, data Map) (sql.Result, error) {
|
||||
return tx.insert(table, data, OPTION_SAVE)
|
||||
}
|
||||
|
||||
// 批量写入数据
|
||||
func (tx *Tx) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) {
|
||||
var keys []string
|
||||
var values []string
|
||||
var bvalues []string
|
||||
var params []interface{}
|
||||
var result sql.Result
|
||||
var size = len(list)
|
||||
// 判断长度
|
||||
if size < 1 {
|
||||
return result, errors.New("empty data list")
|
||||
}
|
||||
// 首先获取字段名称及记录长度
|
||||
for k, _ := range list[0] {
|
||||
keys = append(keys, k)
|
||||
values = append(values, "?")
|
||||
}
|
||||
keyStr := tx.db.charl + strings.Join(keys, tx.db.charl + "," + tx.db.charr) + tx.db.charr
|
||||
valueHolderStr := "(" + strings.Join(values, ",") + ")"
|
||||
// 操作判断
|
||||
operation := tx.db.getInsertOperationByOption(option)
|
||||
updatestr := ""
|
||||
if option == OPTION_SAVE {
|
||||
var updates []string
|
||||
for _, k := range keys {
|
||||
updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k))
|
||||
}
|
||||
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
|
||||
}
|
||||
// 构造批量写入数据格式(注意map的遍历是无序的)
|
||||
for i := 0; i < size; i++ {
|
||||
for _, k := range keys {
|
||||
params = append(params, list[i][k])
|
||||
}
|
||||
bvalues = append(bvalues, valueHolderStr)
|
||||
if len(bvalues) == batch {
|
||||
r, err := tx.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
|
||||
operation, table, keyStr, strings.Join(bvalues, ","),
|
||||
updatestr),
|
||||
params...)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
result = r
|
||||
params = params[:0]
|
||||
bvalues = bvalues[:0]
|
||||
}
|
||||
}
|
||||
// 处理最后不构成指定批量的数据
|
||||
if len(bvalues) > 0 {
|
||||
r, err := tx.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
|
||||
operation, table, keyStr, strings.Join(bvalues, ","),
|
||||
updatestr),
|
||||
params...)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
result = r
|
||||
}
|
||||
return result, nil
|
||||
func (tx *TX) Save(table string, data Map) (sql.Result, error) {
|
||||
return tx.db.doInsert(tx.tx, table, data, OPTION_SAVE)
|
||||
}
|
||||
|
||||
// CURD操作:批量数据指定批次量写入
|
||||
func (tx *Tx) BatchInsert(table string, list List, batch int) (sql.Result, error) {
|
||||
return tx.batchInsert(table, list, batch, OPTION_INSERT)
|
||||
func (tx *TX) BatchInsert(table string, list List, batch int) (sql.Result, error) {
|
||||
return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_INSERT)
|
||||
}
|
||||
|
||||
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
func (tx *Tx) BatchReplace(table string, list List, batch int) (sql.Result, error) {
|
||||
return tx.batchInsert(table, list, batch, OPTION_REPLACE)
|
||||
func (tx *TX) BatchReplace(table string, list List, batch int) (sql.Result, error) {
|
||||
return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_REPLACE)
|
||||
}
|
||||
|
||||
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
func (tx *Tx) BatchSave(table string, list List, batch int) (sql.Result, error) {
|
||||
return tx.batchInsert(table, list, batch, OPTION_SAVE)
|
||||
func (tx *TX) BatchSave(table string, list List, batch int) (sql.Result, error) {
|
||||
return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_SAVE)
|
||||
}
|
||||
|
||||
// CURD操作:数据更新,统一采用sql预处理
|
||||
// data参数支持字符串或者关联数组类型,内部会自行做判断处理
|
||||
func (tx *Tx) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
var params []interface{}
|
||||
var updates string
|
||||
refValue := reflect.ValueOf(data)
|
||||
if refValue.Kind() == reflect.Map {
|
||||
var fields []string
|
||||
keys := refValue.MapKeys()
|
||||
for _, k := range keys {
|
||||
fields = append(fields, fmt.Sprintf("%s%s%s=?", tx.db.charl, k, tx.db.charr))
|
||||
params = append(params, gconv.String(refValue.MapIndex(k).Interface()))
|
||||
updates = strings.Join(fields, ",")
|
||||
}
|
||||
} else {
|
||||
updates = gconv.String(data)
|
||||
}
|
||||
for _, v := range args {
|
||||
params = append(params, gconv.String(v))
|
||||
}
|
||||
return tx.Exec(fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, tx.db.formatCondition(condition)), params...)
|
||||
func (tx *TX) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
return tx.db.doUpdate(tx.tx, table, data, condition, args ...)
|
||||
}
|
||||
|
||||
// CURD操作:删除数据
|
||||
func (tx *Tx) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
return tx.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, tx.db.formatCondition(condition)), args...)
|
||||
func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
return tx.db.doDelete(tx.tx, table, condition, args ...)
|
||||
}
|
||||
|
||||
|
||||
52
g/database/gdb/gdb_unit_0_test.go
Normal file
52
g/database/gdb/gdb_unit_0_test.go
Normal file
@ -0,0 +1,52 @@
|
||||
package gdb_test
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g/database/gdb"
|
||||
"gitee.com/johng/gf/g/util/gtest"
|
||||
)
|
||||
|
||||
var (
|
||||
// 数据库对象/接口
|
||||
db gdb.DB
|
||||
)
|
||||
|
||||
// 初始化连接参数。
|
||||
// 测试前需要修改连接参数。
|
||||
func init() {
|
||||
gdb.AddDefaultConfigNode(gdb.ConfigNode{
|
||||
Host: "127.0.0.1",
|
||||
Port: "3306",
|
||||
User: "root",
|
||||
Pass: "",
|
||||
Name: "",
|
||||
Type: "mysql",
|
||||
Role: "master",
|
||||
Charset: "utf8",
|
||||
Priority: 1,
|
||||
})
|
||||
if r, err := gdb.New(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
db = r
|
||||
}
|
||||
// 准备测试数据结构
|
||||
if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS `test` CHARACTER SET UTF8"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
db.SetSchema("test")
|
||||
if _, err := db.Exec("DROP TABLE IF EXISTS `user`"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := db.Exec(`
|
||||
CREATE TABLE user (
|
||||
id int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT '用户ID',
|
||||
passport varchar(45) NOT NULL COMMENT '账号',
|
||||
password char(32) NOT NULL COMMENT '密码',
|
||||
nickname varchar(45) NOT NULL COMMENT '昵称',
|
||||
create_time timestamp NOT NULL COMMENT '创建时间/注册时间',
|
||||
PRIMARY KEY (id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
`); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
172
g/database/gdb/gdb_unit_1_test.go
Normal file
172
g/database/gdb/gdb_unit_1_test.go
Normal file
@ -0,0 +1,172 @@
|
||||
package gdb_test
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gtest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDbBase_Query(t *testing.T) {
|
||||
if _, err := db.Query("SELECT ?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := db.Query("ERROR"); err == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Exec(t *testing.T) {
|
||||
if _, err := db.Exec("SELECT ?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := db.Exec("ERROR"); err == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Prepare(t *testing.T) {
|
||||
st, err := db.Prepare("SELECT 100")
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
rows, err := st.Query()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
array, err := rows.Columns()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(array[0], "100")
|
||||
if err := rows.Close(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Insert(t *testing.T) {
|
||||
if _, err := db.Insert("user", g.Map{
|
||||
"id" : 1,
|
||||
"passport" : "t1",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T1",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_BatchInsert(t *testing.T) {
|
||||
if _, err := db.BatchInsert("user", g.List {
|
||||
{
|
||||
"id" : 2,
|
||||
"passport" : "t2",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T2",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
{
|
||||
"id" : 3,
|
||||
"passport" : "t3",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T3",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
}, 10); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Save(t *testing.T) {
|
||||
if _, err := db.Save("user", g.Map{
|
||||
"id" : 1,
|
||||
"passport" : "t1",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T11",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Replace(t *testing.T) {
|
||||
if _, err := db.Save("user", g.Map{
|
||||
"id" : 1,
|
||||
"passport" : "t1",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T111",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Update(t *testing.T) {
|
||||
if result, err := db.Update("user", "create_time='2010-10-10 00:00:01'", "id=3"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_GetAll(t *testing.T) {
|
||||
if result, err := db.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(len(result), 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_GetOne(t *testing.T) {
|
||||
if record, err := db.GetOne("SELECT * FROM user WHERE passport=?", "t1"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
if record == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
gtest.Assert(record["nickname"].String(), "T111")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_GetValue(t *testing.T) {
|
||||
if value, err := db.GetValue("SELECT id FROM user WHERE passport=?", "t3"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(value.Int(), 3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_GetCount(t *testing.T) {
|
||||
if count, err := db.GetCount("SELECT * FROM user"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(count, 3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_GetStruct(t *testing.T) {
|
||||
type User struct {
|
||||
Id int
|
||||
Passport string
|
||||
Password string
|
||||
NickName string
|
||||
CreateTime gtime.Time
|
||||
}
|
||||
user := new(User)
|
||||
if err := db.GetStruct(user, "SELECT * FROM user WHERE id=?", 3); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(user.CreateTime.String(), "2010-10-10 00:00:01")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Delete(t *testing.T) {
|
||||
if result, err := db.Delete("user", nil); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 3)
|
||||
}
|
||||
}
|
||||
|
||||
220
g/database/gdb/gdb_unit_2_test.go
Normal file
220
g/database/gdb/gdb_unit_2_test.go
Normal file
@ -0,0 +1,220 @@
|
||||
package gdb_test
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gtest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestModel_Insert(t *testing.T) {
|
||||
result, err := db.Table("user").Filter().Data(g.Map{
|
||||
"id" : 1,
|
||||
"uid" : 1,
|
||||
"passport" : "t1",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T1",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}).Insert()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.LastInsertId()
|
||||
gtest.Assert(n, 1)
|
||||
}
|
||||
|
||||
func TestModel_Batch(t *testing.T) {
|
||||
result, err := db.Table("user").Filter().Data(g.List{
|
||||
{
|
||||
"id" : 2,
|
||||
"uid" : 2,
|
||||
"passport" : "t2",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T2",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
{
|
||||
"id" : 3,
|
||||
"uid" : 3,
|
||||
"passport" : "t3",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T3",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
}).Batch(10).Insert()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 2)
|
||||
}
|
||||
|
||||
func TestModel_Replace(t *testing.T) {
|
||||
result, err := db.Table("user").Data(g.Map{
|
||||
"id" : 1,
|
||||
"passport" : "t11",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T11",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}).Replace()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 2)
|
||||
}
|
||||
|
||||
func TestModel_Save(t *testing.T) {
|
||||
result, err := db.Table("user").Data(g.Map{
|
||||
"id" : 1,
|
||||
"passport" : "t111",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T111",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}).Save()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 2)
|
||||
}
|
||||
|
||||
func TestModel_Update(t *testing.T) {
|
||||
result, err := db.Table("user").Data("passport", "t22").Where("passport=?", "t2").Update()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 1)
|
||||
}
|
||||
|
||||
func TestModel_Clone(t *testing.T) {
|
||||
md := db.Table("user").Where("id IN(?)", g.Slice{1,3})
|
||||
count, err := md.Count()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
record, err := md.OrderBy("id DESC").One()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
result, err := md.OrderBy("id ASC").All()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(count, 2)
|
||||
gtest.Assert(record["id"].Int(), 3)
|
||||
gtest.Assert(len(result), 2)
|
||||
gtest.Assert(result[0]["id"].Int(), 1)
|
||||
gtest.Assert(result[1]["id"].Int(), 3)
|
||||
}
|
||||
|
||||
func TestModel_All(t *testing.T) {
|
||||
result, err := db.Table("user").All()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(len(result), 3)
|
||||
}
|
||||
|
||||
func TestModel_One(t *testing.T) {
|
||||
record, err := db.Table("user").Where("id", 1).One()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if record == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
gtest.Assert(record["nickname"].String(), "T111")
|
||||
}
|
||||
|
||||
func TestModel_Value(t *testing.T) {
|
||||
value, err := db.Table("user").Fields("nickname").Where("id", 1).Value()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if value == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
gtest.Assert(value.String(), "T111")
|
||||
}
|
||||
|
||||
func TestModel_Count(t *testing.T) {
|
||||
count, err := db.Table("user").Count()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(count, 3)
|
||||
}
|
||||
|
||||
func TestModel_Select(t *testing.T) {
|
||||
result, err := db.Table("user").Select()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(len(result), 3)
|
||||
}
|
||||
|
||||
func TestModel_Struct(t *testing.T) {
|
||||
type User struct {
|
||||
Id int
|
||||
Passport string
|
||||
Password string
|
||||
NickName string
|
||||
CreateTime gtime.Time
|
||||
}
|
||||
user := new(User)
|
||||
err := db.Table("user").Where("id=1").Struct(user)
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(user.NickName, "T111")
|
||||
}
|
||||
|
||||
func TestModel_OrderBy(t *testing.T) {
|
||||
result, err := db.Table("user").OrderBy("id DESC").Select()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(len(result), 3)
|
||||
gtest.Assert(result[0]["nickname"].String(), "T3")
|
||||
}
|
||||
|
||||
func TestModel_GroupBy(t *testing.T) {
|
||||
result, err := db.Table("user").GroupBy("id").Select()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(len(result), 3)
|
||||
gtest.Assert(result[0]["nickname"].String(), "T111")
|
||||
}
|
||||
|
||||
func TestModel_Where1(t *testing.T) {
|
||||
result, err := db.Table("user").Where("id IN(?)", g.Slice{1,3}).OrderBy("id ASC").All()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(len(result), 2)
|
||||
gtest.Assert(result[0]["id"].Int(), 1)
|
||||
gtest.Assert(result[1]["id"].Int(), 3)
|
||||
}
|
||||
|
||||
func TestModel_Where2(t *testing.T) {
|
||||
result, err := db.Table("user").Where("nickname=? AND id IN(?)", "T3", g.Slice{1,3}).OrderBy("id ASC").All()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(len(result), 1)
|
||||
gtest.Assert(result[0]["id"].Int(), 3)
|
||||
}
|
||||
|
||||
func TestModel_Delete(t *testing.T) {
|
||||
result, err := db.Table("user").Delete()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 3)
|
||||
}
|
||||
|
||||
|
||||
372
g/database/gdb/gdb_unit_3_test.go
Normal file
372
g/database/gdb/gdb_unit_3_test.go
Normal file
@ -0,0 +1,372 @@
|
||||
package gdb_test
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gtest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTX_Query(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if rows, err := tx.Query("SELECT ?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
rows.Close()
|
||||
}
|
||||
if _, err := tx.Query("ERROR"); err == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Exec(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.Exec("SELECT ?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.Exec("ERROR"); err == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Commit(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Rollback(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Rollback(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Prepare(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
st, err := tx.Prepare("SELECT 100")
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
rows, err := st.Query()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
array, err := rows.Columns()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(array[0], "100")
|
||||
if err := rows.Close(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Insert(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.Insert("user", g.Map {
|
||||
"id" : 1,
|
||||
"passport" : "t1",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T1",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if n, err := db.Table("user").Count(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(n, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_BatchInsert(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.BatchInsert("user", g.List {
|
||||
{
|
||||
"id" : 2,
|
||||
"passport" : "t",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T2",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
{
|
||||
"id" : 3,
|
||||
"passport" : "t3",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T3",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
}, 10); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if n, err := db.Table("user").Count(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(n, 3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_BatchReplace(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.BatchReplace("user", g.List {
|
||||
{
|
||||
"id" : 2,
|
||||
"passport" : "t2",
|
||||
"password" : "p2",
|
||||
"nickname" : "T2",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
{
|
||||
"id" : 4,
|
||||
"passport" : "t4",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T4",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
}, 10); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
// 数据数量
|
||||
if n, err := db.Table("user").Count(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(n, 4)
|
||||
}
|
||||
// 检查replace后的数值
|
||||
if value, err := db.Table("user").Fields("password").Where("id", 2).Value(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(value.String(), "p2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_BatchSave(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.BatchSave("user", g.List {
|
||||
{
|
||||
"id" : 4,
|
||||
"passport" : "t4",
|
||||
"password" : "p4",
|
||||
"nickname" : "T4",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
}, 10); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
// 数据数量
|
||||
if n, err := db.Table("user").Count(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(n, 4)
|
||||
}
|
||||
// 检查replace后的数值
|
||||
if value, err := db.Table("user").Fields("password").Where("id", 4).Value(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(value.String(), "p4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Replace(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.Replace("user", g.Map {
|
||||
"id" : 1,
|
||||
"passport" : "t11",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T11",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Rollback(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if value, err := db.Table("user").Fields("nickname").Where("id", 1).Value(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(value.String(), "T1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Save(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.Save("user", g.Map {
|
||||
"id" : 1,
|
||||
"passport" : "t11",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T11",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if value, err := db.Table("user").Fields("nickname").Where("id", 1).Value(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(value.String(), "T11")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_GetAll(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if result, err := tx.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(len(result), 1)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_GetOne(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if record, err := tx.GetOne("SELECT * FROM user WHERE passport=?", "t2"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
if record == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
gtest.Assert(record["nickname"].String(), "T2")
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_GetValue(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if value, err := tx.GetValue("SELECT id FROM user WHERE passport=?", "t3"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(value.Int(), 3)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_GetCount(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if count, err := tx.GetCount("SELECT * FROM user"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(count, 4)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_GetStruct(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
type User struct {
|
||||
Id int
|
||||
Passport string
|
||||
Password string
|
||||
NickName string
|
||||
CreateTime gtime.Time
|
||||
}
|
||||
user := new(User)
|
||||
if err := tx.GetStruct(user, "SELECT * FROM user WHERE id=?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(user.NickName, "T11")
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Delete(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.Delete("user", nil); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if n, err := db.Table("user").Count(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(n, 0)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -4,16 +4,15 @@
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://gitee.com/johng/gf.
|
||||
|
||||
// Kafka Client.
|
||||
// Package gkafka provides producer and consumer client for kafka server/Kafka客户端.
|
||||
package gkafka
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g/os/glog"
|
||||
"time"
|
||||
"strings"
|
||||
"gitee.com/johng/gf/third/github.com/Shopify/sarama"
|
||||
"gitee.com/johng/gf/third/github.com/johng-cn/sarama-cluster"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -177,8 +176,6 @@ func (client *Client) Receive() (*Message, error) {
|
||||
case <-notifyChan:
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("unknown error")
|
||||
}
|
||||
|
||||
// Send data to kafka in synchronized way.
|
||||
|
||||
@ -487,7 +487,6 @@ func (j *Json) convertValue(value interface{}) interface{} {
|
||||
v, _ := Decode(b)
|
||||
return v
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// 用于Set方法中,对指针指向的内存地址进行赋值
|
||||
|
||||
@ -9,18 +9,18 @@
|
||||
package gins
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/container/gmap"
|
||||
"gitee.com/johng/gf/g/database/gdb"
|
||||
"gitee.com/johng/gf/g/database/gredis"
|
||||
"gitee.com/johng/gf/g/os/gcfg"
|
||||
"gitee.com/johng/gf/g/os/gcmd"
|
||||
"gitee.com/johng/gf/g/os/genv"
|
||||
"gitee.com/johng/gf/g/os/gfile"
|
||||
"gitee.com/johng/gf/g/os/gfsnotify"
|
||||
"gitee.com/johng/gf/g/os/glog"
|
||||
"gitee.com/johng/gf/g/os/gview"
|
||||
"gitee.com/johng/gf/g/os/gfile"
|
||||
"gitee.com/johng/gf/g/container/gmap"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"gitee.com/johng/gf/g/database/gdb"
|
||||
"gitee.com/johng/gf/g/os/gfsnotify"
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/database/gredis"
|
||||
"gitee.com/johng/gf/g/util/gregex"
|
||||
)
|
||||
|
||||
@ -76,9 +76,7 @@ func View(name...string) *gview.View {
|
||||
if path == "" {
|
||||
path = genv.Get("GF_VIEWPATH")
|
||||
if path == "" {
|
||||
if gfile.SelfDir() != gfile.TempDir() {
|
||||
path = gfile.SelfDir()
|
||||
}
|
||||
path = gfile.SelfDir()
|
||||
}
|
||||
}
|
||||
view := gview.New(path)
|
||||
@ -105,9 +103,7 @@ func Config(file...string) *gcfg.Config {
|
||||
if path == "" {
|
||||
path = genv.Get("GF_CFGPATH")
|
||||
if path == "" {
|
||||
if gfile.SelfDir() != gfile.TempDir() {
|
||||
path = gfile.SelfDir()
|
||||
}
|
||||
path = gfile.SelfDir()
|
||||
}
|
||||
}
|
||||
config := gcfg.New(path, configFile)
|
||||
@ -120,7 +116,7 @@ func Config(file...string) *gcfg.Config {
|
||||
}
|
||||
|
||||
// 数据库操作对象,使用了连接池
|
||||
func Database(name...string) *gdb.Db {
|
||||
func Database(name...string) gdb.DB {
|
||||
config := Config()
|
||||
group := gdb.DEFAULT_GROUP_NAME
|
||||
if len(name) > 0 {
|
||||
@ -128,65 +124,67 @@ func Database(name...string) *gdb.Db {
|
||||
}
|
||||
key := fmt.Sprintf("%s.%s", gFRAME_CORE_COMPONENT_NAME_DATABASE, group)
|
||||
db := instances.GetOrSetFuncLock(key, func() interface{} {
|
||||
m := config.GetMap("database")
|
||||
if m == nil {
|
||||
glog.Error(`database init failed: "database" node not found, is config file or configuration missing?`)
|
||||
return nil
|
||||
}
|
||||
for group, v := range m {
|
||||
cg := gdb.ConfigGroup{}
|
||||
if list, ok := v.([]interface{}); ok {
|
||||
for _, nodev := range list {
|
||||
node := gdb.ConfigNode{}
|
||||
nodem := nodev.(map[string]interface{})
|
||||
if value, ok := nodem["host"]; ok {
|
||||
node.Host = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["port"]; ok {
|
||||
node.Port = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["user"]; ok {
|
||||
node.User = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["pass"]; ok {
|
||||
node.Pass = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["name"]; ok {
|
||||
node.Name = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["type"]; ok {
|
||||
node.Type = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["role"]; ok {
|
||||
node.Role = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["charset"]; ok {
|
||||
node.Charset = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["priority"]; ok {
|
||||
node.Priority = gconv.Int(value)
|
||||
}
|
||||
if value, ok := nodem["linkinfo"]; ok {
|
||||
node.Linkinfo = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["max-idle"]; ok {
|
||||
node.MaxIdleConnCount = gconv.Int(value)
|
||||
}
|
||||
if value, ok := nodem["max-open"]; ok {
|
||||
node.MaxOpenConnCount = gconv.Int(value)
|
||||
}
|
||||
if value, ok := nodem["max-lifetime"]; ok {
|
||||
node.MaxConnLifetime = gconv.Int(value)
|
||||
}
|
||||
cg = append(cg, node)
|
||||
}
|
||||
if gdb.GetConfig(group) == nil {
|
||||
m := config.GetMap("database")
|
||||
if m == nil {
|
||||
glog.Error(`database init failed: "database" node not found, is config file or configuration missing?`)
|
||||
return nil
|
||||
}
|
||||
gdb.AddConfigGroup(group, cg)
|
||||
for group, v := range m {
|
||||
cg := gdb.ConfigGroup{}
|
||||
if list, ok := v.([]interface{}); ok {
|
||||
for _, nodev := range list {
|
||||
node := gdb.ConfigNode{}
|
||||
nodem := nodev.(map[string]interface{})
|
||||
if value, ok := nodem["host"]; ok {
|
||||
node.Host = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["port"]; ok {
|
||||
node.Port = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["user"]; ok {
|
||||
node.User = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["pass"]; ok {
|
||||
node.Pass = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["name"]; ok {
|
||||
node.Name = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["type"]; ok {
|
||||
node.Type = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["role"]; ok {
|
||||
node.Role = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["charset"]; ok {
|
||||
node.Charset = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["priority"]; ok {
|
||||
node.Priority = gconv.Int(value)
|
||||
}
|
||||
if value, ok := nodem["linkinfo"]; ok {
|
||||
node.Linkinfo = gconv.String(value)
|
||||
}
|
||||
if value, ok := nodem["max-idle"]; ok {
|
||||
node.MaxIdleConnCount = gconv.Int(value)
|
||||
}
|
||||
if value, ok := nodem["max-open"]; ok {
|
||||
node.MaxOpenConnCount = gconv.Int(value)
|
||||
}
|
||||
if value, ok := nodem["max-lifetime"]; ok {
|
||||
node.MaxConnLifetime = gconv.Int(value)
|
||||
}
|
||||
cg = append(cg, node)
|
||||
}
|
||||
}
|
||||
gdb.AddConfigGroup(group, cg)
|
||||
}
|
||||
// 使用gfsnotify进行文件监控,当配置文件有任何变化时,清空数据库配置缓存
|
||||
gfsnotify.Add(config.GetFilePath(), func(event *gfsnotify.Event) {
|
||||
instances.Remove(key)
|
||||
})
|
||||
}
|
||||
// 使用gfsnotify进行文件监控,当配置文件有任何变化时,清空数据库配置缓存
|
||||
gfsnotify.Add(config.GetFilePath(), func(event *gfsnotify.Event) {
|
||||
instances.Remove(key)
|
||||
})
|
||||
if db, err := gdb.New(name...); err == nil {
|
||||
return db
|
||||
} else {
|
||||
@ -195,7 +193,7 @@ func Database(name...string) *gdb.Db {
|
||||
return nil
|
||||
})
|
||||
if db != nil {
|
||||
return db.(*gdb.Db)
|
||||
return db.(gdb.DB)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -32,7 +32,7 @@ func (c *Controller) Init(r *ghttp.Request) {
|
||||
}
|
||||
|
||||
// 控制器结束请求接口方法
|
||||
func (c *Controller) Shut(r *ghttp.Request) {
|
||||
func (c *Controller) Shut() {
|
||||
|
||||
}
|
||||
|
||||
|
||||
23
g/g.go
23
g/g.go
@ -10,14 +10,27 @@ package g
|
||||
import "gitee.com/johng/gf/g/container/gvar"
|
||||
|
||||
// 框架动态变量,可以用该类型替代interface{}类型
|
||||
type Var = gvar.Var
|
||||
type Var = gvar.Var
|
||||
|
||||
// 常用map数据结构(使用别名)
|
||||
type Map = map[string]interface{}
|
||||
type Map = map[string]interface{}
|
||||
type MapStrStr = map[string]string
|
||||
type MapStrInt = map[string]int
|
||||
type MapIntStr = map[int]string
|
||||
type MapIntInt = map[int]int
|
||||
|
||||
// 常用list数据结构(使用别名)
|
||||
type List = []Map
|
||||
type List = []Map
|
||||
type ListStrStr = []map[string]string
|
||||
type ListStrInt = []map[string]int
|
||||
type ListIntStr = []map[int]string
|
||||
type ListIntInt = []map[int]int
|
||||
|
||||
|
||||
// 常用slice数据结构(使用别名)
|
||||
type Slice = []interface{}
|
||||
type Array = Slice
|
||||
type Slice = []interface{}
|
||||
type SliceStr = []string
|
||||
type SliceInt = []int
|
||||
type Array = Slice
|
||||
type ArrayStr = SliceStr
|
||||
type ArrayInt = SliceInt
|
||||
|
||||
@ -23,12 +23,12 @@ func Server(name...interface{}) *ghttp.Server {
|
||||
}
|
||||
|
||||
// TCPServer单例对象
|
||||
func TcpServer(name...interface{}) *gtcp.Server {
|
||||
func TCPServer(name...interface{}) *gtcp.Server {
|
||||
return gtcp.GetServer(name...)
|
||||
}
|
||||
|
||||
// UDPServer单例对象
|
||||
func UdpServer(name...interface{}) *gudp.Server {
|
||||
func UDPServer(name...interface{}) *gudp.Server {
|
||||
return gudp.GetServer(name...)
|
||||
}
|
||||
|
||||
@ -44,12 +44,12 @@ func Config(file...string) *gcfg.Config {
|
||||
}
|
||||
|
||||
// 数据库操作对象,使用了连接池
|
||||
func Database(name...string) *gdb.Db {
|
||||
func Database(name...string) gdb.DB {
|
||||
return gins.Database(name...)
|
||||
}
|
||||
|
||||
// (别名)Database
|
||||
func DB(name...string) *gdb.Db {
|
||||
func DB(name...string) gdb.DB {
|
||||
return gins.Database(name...)
|
||||
}
|
||||
|
||||
|
||||
@ -10,5 +10,5 @@ package ghttp
|
||||
// 控制器接口
|
||||
type Controller interface {
|
||||
Init(*Request)
|
||||
Shut(*Request)
|
||||
Shut()
|
||||
}
|
||||
|
||||
@ -29,7 +29,7 @@ func (r *Request) GetRequest(key string, def ... []string) []string {
|
||||
func (r *Request) GetRequestVar(key string, def ... interface{}) gvar.VarRead {
|
||||
value := r.GetRequest(key)
|
||||
if value != nil {
|
||||
return gvar.New(value, false)
|
||||
return gvar.New(value[0], false)
|
||||
}
|
||||
if len(def) > 0 {
|
||||
return gvar.New(def[0], false)
|
||||
|
||||
@ -15,6 +15,7 @@ import (
|
||||
"gitee.com/johng/gf/g/container/gtype"
|
||||
"gitee.com/johng/gf/g/os/gcache"
|
||||
"gitee.com/johng/gf/g/os/genv"
|
||||
"gitee.com/johng/gf/g/os/gfile"
|
||||
"gitee.com/johng/gf/g/os/glog"
|
||||
"gitee.com/johng/gf/g/os/gproc"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
@ -90,13 +91,13 @@ type (
|
||||
}
|
||||
|
||||
// pattern与回调函数的绑定map
|
||||
handlerMap map[string]*handlerItem
|
||||
handlerMap = map[string]*handlerItem
|
||||
|
||||
// HTTP注册函数
|
||||
HandlerFunc func(r *Request)
|
||||
HandlerFunc = func(r *Request)
|
||||
|
||||
// 文件描述符map
|
||||
listenerFdMap map[string]string
|
||||
listenerFdMap = map[string]string
|
||||
)
|
||||
|
||||
const (
|
||||
@ -254,6 +255,10 @@ func (s *Server) Start() error {
|
||||
}
|
||||
})
|
||||
}
|
||||
// 是否处于开发环境
|
||||
if gfile.MainPkgPath() != "" {
|
||||
glog.Debug("GF notices that you're in develop environment, so error logs are auto enabled to stdout.")
|
||||
}
|
||||
|
||||
// 打印展示路由表
|
||||
s.DumpRoutesMap()
|
||||
|
||||
@ -160,35 +160,46 @@ func (s *Server) searchStaticFile(uri string) (filePath string, isDir bool) {
|
||||
|
||||
// 初始化控制器
|
||||
func (s *Server) callServeHandler(h *handlerItem, r *Request) {
|
||||
if h.faddr == nil {
|
||||
c := reflect.New(h.ctype)
|
||||
s.niceCall(func() {
|
||||
c.MethodByName("Init").Call([]reflect.Value{reflect.ValueOf(r)})
|
||||
})
|
||||
s.niceCall(func() {
|
||||
c.MethodByName(h.fname).Call(nil)
|
||||
})
|
||||
s.niceCall(func() {
|
||||
c.MethodByName("Shut").Call(nil)
|
||||
})
|
||||
} else {
|
||||
if h.finit != nil {
|
||||
s.niceCall(func() {
|
||||
h.finit(r)
|
||||
})
|
||||
}
|
||||
s.niceCall(func() {
|
||||
h.faddr(r)
|
||||
})
|
||||
if h.fshut != nil {
|
||||
s.niceCall(func() {
|
||||
h.fshut(r)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 友好地调用方法
|
||||
func (s *Server) niceCall(f func()) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil && e != gEXCEPTION_EXIT {
|
||||
panic(e)
|
||||
}
|
||||
}()
|
||||
if h.faddr == nil {
|
||||
// 新建一个控制器对象处理请求
|
||||
c := reflect.New(h.ctype)
|
||||
c.MethodByName("Init").Call([]reflect.Value{reflect.ValueOf(r)})
|
||||
if !r.IsExited() {
|
||||
c.MethodByName(h.fname).Call(nil)
|
||||
c.MethodByName("Shut").Call([]reflect.Value{reflect.ValueOf(r)})
|
||||
}
|
||||
} else {
|
||||
// 是否有初始化及完成回调方法
|
||||
if h.finit != nil {
|
||||
h.finit(r)
|
||||
}
|
||||
if !r.IsExited() {
|
||||
h.faddr(r)
|
||||
if h.fshut != nil {
|
||||
h.fshut(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
f()
|
||||
}
|
||||
|
||||
// http server静态文件处理,path可以为相对路径也可以为绝对路径
|
||||
func (s *Server)serveFile(r *Request, path string) {
|
||||
func (s *Server) serveFile(r *Request, path string) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
r.Response.WriteStatus(http.StatusForbidden)
|
||||
|
||||
@ -9,6 +9,7 @@ package ghttp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/os/gfile"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@ -36,7 +37,7 @@ func (s *Server) handleErrorLog(error interface{}, r *Request) {
|
||||
r.Response.WriteStatus(http.StatusInternalServerError)
|
||||
|
||||
// 错误输出默认是开启的
|
||||
if !s.IsErrorLogEnabled() {
|
||||
if !s.IsErrorLogEnabled() && gfile.MainPkgPath() == "" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -56,5 +57,9 @@ func (s *Server) handleErrorLog(error interface{}, r *Request) {
|
||||
s.logger.Cat("error").Backtrace(true, 2).StdPrint(true).Error(content)
|
||||
} else {
|
||||
s.logger.Cat("error").Backtrace(true, 2).Error(content)
|
||||
// 开发环境下(MainPkgPath)自动输出错误信息到标准输出
|
||||
if gfile.MainPkgPath() != "" {
|
||||
s.logger.Cat("error").Backtrace(true, 2).StdPrint(true).Error(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,24 +20,28 @@ import (
|
||||
|
||||
|
||||
// 解析pattern
|
||||
func (s *Server)parsePattern(pattern string) (domain, method, uri string, err error) {
|
||||
uri = pattern
|
||||
func (s *Server)parsePattern(pattern string) (domain, method, path string, err error) {
|
||||
path = strings.TrimSpace(pattern)
|
||||
domain = gDEFAULT_DOMAIN
|
||||
method = gDEFAULT_METHOD
|
||||
if array, err := gregex.MatchString(`([a-zA-Z]+):(.+)`, pattern); len(array) > 1 && err == nil {
|
||||
method = array[1]
|
||||
uri = array[2]
|
||||
path = strings.TrimSpace(array[2])
|
||||
if v := strings.TrimSpace(array[1]); v != "" {
|
||||
method = v
|
||||
}
|
||||
}
|
||||
if array, err := gregex.MatchString(`(.+)@([\w\.\-]+)`, uri); len(array) > 1 && err == nil {
|
||||
uri = array[1]
|
||||
domain = array[2]
|
||||
if array, err := gregex.MatchString(`(.+)@([\w\.\-]+)`, path); len(array) > 1 && err == nil {
|
||||
path = strings.TrimSpace(array[1])
|
||||
if v := strings.TrimSpace(array[2]); v != "" {
|
||||
domain = v
|
||||
}
|
||||
}
|
||||
if uri == "" {
|
||||
if path == "" {
|
||||
err = errors.New("invalid pattern")
|
||||
}
|
||||
// 去掉末尾的"/"符号,与路由匹配时处理一致
|
||||
if uri != "/" {
|
||||
uri = strings.TrimRight(uri, "/")
|
||||
if path != "/" {
|
||||
path = strings.TrimRight(path, "/")
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -293,7 +297,6 @@ func (s *Server) patternToRegRule(rule string) (regrule string, names []string)
|
||||
regrule += `/[^/]+`
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
case '*':
|
||||
if len(v) > 1 {
|
||||
regrule += `/{0,1}(.*)`
|
||||
@ -303,7 +306,6 @@ func (s *Server) patternToRegRule(rule string) (regrule string, names []string)
|
||||
regrule += `/{0,1}.*`
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
// 特殊字符替换
|
||||
v = gstr.ReplaceByMap(v, map[string]string{
|
||||
|
||||
214
g/net/ghttp/ghttp_server_router_group.go
Normal file
214
g/net/ghttp/ghttp_server_router_group.go
Normal file
@ -0,0 +1,214 @@
|
||||
// Copyright 2018 gf Author(https://gitee.com/johng/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://gitee.com/johng/gf.
|
||||
// 分组路由管理.
|
||||
|
||||
package ghttp
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g/os/glog"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 分组路由对象
|
||||
type RouterGroup struct {
|
||||
server *Server // Server
|
||||
domain *Domain // Domain
|
||||
prefix string // URI前缀
|
||||
}
|
||||
|
||||
// 分组路由批量绑定项
|
||||
type GroupItem = []interface{}
|
||||
|
||||
// 获取分组路由对象
|
||||
func (s *Server) Group(prefix...string) *RouterGroup {
|
||||
if len(prefix) > 0 {
|
||||
return &RouterGroup{
|
||||
server : s,
|
||||
prefix : prefix[0],
|
||||
}
|
||||
}
|
||||
return &RouterGroup{}
|
||||
}
|
||||
|
||||
// 获取分组路由对象
|
||||
func (d *Domain) Group(prefix...string) *RouterGroup {
|
||||
if len(prefix) > 0 {
|
||||
return &RouterGroup{
|
||||
domain : d,
|
||||
prefix : prefix[0],
|
||||
}
|
||||
}
|
||||
return &RouterGroup{}
|
||||
}
|
||||
|
||||
// 执行分组路由批量绑定
|
||||
func (g *RouterGroup) Bind(group string, items []GroupItem) {
|
||||
for _, item := range items {
|
||||
if len(item) < 3 {
|
||||
glog.Fatalfln("invalid router item: %s", item)
|
||||
}
|
||||
if strings.EqualFold(gconv.String(item[0]), "REST") {
|
||||
g.bind("REST", gconv.String(item[0]) + ":" + gconv.String(item[1]), item[2])
|
||||
} else {
|
||||
if len(item) > 3 {
|
||||
g.bind("HANDLER", gconv.String(item[0]) + ":" + gconv.String(item[1]), item[2], item[3])
|
||||
} else {
|
||||
g.bind("HANDLER", gconv.String(item[0]) + ":" + gconv.String(item[1]), item[2])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 绑定所有的HTTP Method请求方式
|
||||
func (g *RouterGroup) ALL(pattern string, object interface{}, params...interface{}) {
|
||||
g.bind("HANDLER", gDEFAULT_METHOD + ":" + pattern, object, params...)
|
||||
}
|
||||
|
||||
func (g *RouterGroup) GET(pattern string, object interface{}, params...interface{}) {
|
||||
g.bind("HANDLER", "GET:" + pattern, object, params...)
|
||||
}
|
||||
|
||||
func (g *RouterGroup) PUT(pattern string, object interface{}, params...interface{}) {
|
||||
g.bind("HANDLER", "PUT:" + pattern, object, params...)
|
||||
}
|
||||
|
||||
func (g *RouterGroup) POST(pattern string, object interface{}, params...interface{}) {
|
||||
g.bind("HANDLER", "POST:" + pattern, object, params...)
|
||||
}
|
||||
|
||||
func (g *RouterGroup) DELETE(pattern string, object interface{}, params...interface{}) {
|
||||
g.bind("HANDLER", "DELETE:" + pattern, object, params...)
|
||||
}
|
||||
|
||||
func (g *RouterGroup) PATCH(pattern string, object interface{}, params...interface{}) {
|
||||
g.bind("HANDLER", "PATCH:" + pattern, object, params...)
|
||||
}
|
||||
|
||||
func (g *RouterGroup) HEAD(pattern string, object interface{}, params...interface{}) {
|
||||
g.bind("HANDLER", "HEAD:" + pattern, object, params...)
|
||||
}
|
||||
|
||||
func (g *RouterGroup) CONNECT(pattern string, object interface{}, params...interface{}) {
|
||||
g.bind("HANDLER", "CONNECT:" + pattern, object, params...)
|
||||
}
|
||||
|
||||
func (g *RouterGroup) OPTIONS(pattern string, object interface{}, params...interface{}) {
|
||||
g.bind("HANDLER", "OPTIONS:" + pattern, object, params...)
|
||||
}
|
||||
|
||||
func (g *RouterGroup) TRACE(pattern string, object interface{}, params...interface{}) {
|
||||
g.bind("HANDLER", "TRACE:" + pattern, object, params...)
|
||||
}
|
||||
|
||||
// REST路由注册
|
||||
func (g *RouterGroup) REST(pattern string, object interface{}) {
|
||||
g.bind("REST", pattern, object)
|
||||
}
|
||||
|
||||
// 执行路由绑定
|
||||
func (g *RouterGroup) bind(bindType string, pattern string, object interface{}, params...interface{}) {
|
||||
// 注册路由处理
|
||||
if len(g.prefix) > 0 {
|
||||
domain, method, path, err := g.server.parsePattern(pattern)
|
||||
if err != nil {
|
||||
glog.Fatalfln("invalid pattern: %s", pattern)
|
||||
}
|
||||
if bindType == "HANDLER" {
|
||||
pattern = g.server.serveHandlerKey(method, g.prefix + "/" + strings.TrimLeft(path, "/"), domain)
|
||||
} else {
|
||||
pattern = g.prefix + "/" + strings.TrimLeft(path, "/")
|
||||
}
|
||||
}
|
||||
methods := gconv.Strings(params)
|
||||
// 判断是否事件回调注册
|
||||
if _, ok := object.(HandlerFunc); ok && len(methods) > 0 {
|
||||
bindType = "HOOK"
|
||||
}
|
||||
switch bindType {
|
||||
case "HANDLER":
|
||||
if h, ok := object.(HandlerFunc); ok {
|
||||
if g.server != nil {
|
||||
g.server.BindHandler(pattern, h)
|
||||
} else {
|
||||
g.domain.BindHandler(pattern, h)
|
||||
}
|
||||
} else if g.isController(object) {
|
||||
if len(methods) > 0 {
|
||||
if g.server != nil {
|
||||
g.server.BindControllerMethod(pattern, object.(Controller), methods[0])
|
||||
} else {
|
||||
g.domain.BindControllerMethod(pattern, object.(Controller), methods[0])
|
||||
}
|
||||
} else {
|
||||
if g.server != nil {
|
||||
g.server.BindController(pattern, object.(Controller))
|
||||
} else {
|
||||
g.domain.BindController(pattern, object.(Controller))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if len(methods) > 0 {
|
||||
if g.server != nil {
|
||||
g.server.BindObjectMethod(pattern, object, methods[0])
|
||||
} else {
|
||||
g.domain.BindObjectMethod(pattern, object, methods[0])
|
||||
}
|
||||
} else {
|
||||
if g.server != nil {
|
||||
g.server.BindObject(pattern, object)
|
||||
} else {
|
||||
g.domain.BindObject(pattern, object)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "REST":
|
||||
if g.isController(object) {
|
||||
if g.server != nil {
|
||||
g.server.BindControllerRest(pattern, object.(Controller))
|
||||
} else {
|
||||
g.domain.BindControllerRest(pattern, object.(Controller))
|
||||
}
|
||||
} else {
|
||||
if g.server != nil {
|
||||
g.server.BindObjectRest(pattern, object)
|
||||
} else {
|
||||
g.domain.BindObjectRest(pattern, object)
|
||||
}
|
||||
}
|
||||
case "HOOK":
|
||||
if h, ok := object.(HandlerFunc); ok {
|
||||
if g.server != nil {
|
||||
g.server.BindHookHandler(pattern, methods[0], h)
|
||||
} else {
|
||||
g.domain.BindHookHandler(pattern, methods[0], h)
|
||||
}
|
||||
} else {
|
||||
glog.Fatalfln("invalid hook handler for pattern:%s", pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 判断给定对象是否控制器对象:
|
||||
// 控制器必须包含以下公开的属性对象:Request/Response/Server/Cookie/Session/View.
|
||||
func (g *RouterGroup) isController(value interface{}) bool {
|
||||
// 首先判断是否满足控制器接口定义
|
||||
if _, ok := value.(Controller); !ok {
|
||||
return false
|
||||
}
|
||||
// 其次检查控制器的必需属性
|
||||
v := reflect.ValueOf(value)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
if v.FieldByName("Request").IsValid() && v.FieldByName("Response").IsValid() &&
|
||||
v.FieldByName("Server").IsValid() && v.FieldByName("Cookie").IsValid() &&
|
||||
v.FieldByName("Session").IsValid() && v.FieldByName("View").IsValid() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@ -25,7 +25,6 @@ func (s *Server)BindHookHandler(pattern string, hook string, handler HandlerFunc
|
||||
fname : "",
|
||||
faddr : handler,
|
||||
}, hook)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 通过map批量绑定回调函数
|
||||
|
||||
@ -82,5 +82,4 @@ func (s *Server) Run() error {
|
||||
go s.handler(NewConnByNetConn(conn))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -85,7 +85,6 @@ func (c *Conn) Send(data []byte, retry...Retry) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 接收数据
|
||||
|
||||
@ -77,5 +77,4 @@ func (s *Server) Run() error {
|
||||
for {
|
||||
s.handler(NewConnByNetConn(conn))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -63,8 +63,8 @@ func newMemCache(lruCap...int) *memCache {
|
||||
closed : gtype.NewBool(),
|
||||
}
|
||||
if len(lruCap) > 0 {
|
||||
c.lru = newMemCacheLru(c)
|
||||
c.cap = lruCap[0]
|
||||
c.lru = newMemCacheLru(c)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
@ -370,7 +370,7 @@ func homeWindows() (string, error) {
|
||||
return home, nil
|
||||
}
|
||||
|
||||
// 获取入口函数文件所在目录(main包文件目录),
|
||||
// 获取入口函数文件所在目录(main包文件目录),
|
||||
// **仅对源码开发环境有效(即仅对生成该可执行文件的系统下有效)**
|
||||
func MainPkgPath() string {
|
||||
path := mainPkgPath.Val()
|
||||
@ -401,6 +401,7 @@ func MainPkgPath() string {
|
||||
if p == f {
|
||||
break
|
||||
}
|
||||
// 会自动扫描源码,寻找main包
|
||||
if paths, err := ScanDir(p, "*.go"); err == nil && len(paths) > 0 {
|
||||
for _, path := range paths {
|
||||
if gregex.IsMatchString(`package\s+main`, GetContents(path)) {
|
||||
|
||||
@ -105,8 +105,6 @@ func GetNextCharOffsetByPath(path string, char byte, start int64) int64 {
|
||||
if f, err := OpenWithFlagPerm(path, os.O_RDONLY, gDEFAULT_PERM); err == nil {
|
||||
defer f.Close()
|
||||
return GetNextCharOffset(f, char, start)
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
return -1
|
||||
}
|
||||
@ -124,8 +122,6 @@ func GetBinContentsTilCharByPath(path string, char byte, start int64) ([]byte, i
|
||||
if f, err := OpenWithFlagPerm(path, os.O_RDONLY, gDEFAULT_PERM); err == nil {
|
||||
defer f.Close()
|
||||
return GetBinContentsTilChar(f, char, start)
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
return nil, -1
|
||||
}
|
||||
@ -144,8 +140,6 @@ func GetBinContentsByTwoOffsetsByPath(path string, start int64, end int64) []byt
|
||||
if f, err := OpenWithFlagPerm(path, os.O_RDONLY, gDEFAULT_PERM); err == nil {
|
||||
defer f.Close()
|
||||
return GetBinContentsByTwoOffsets(f, start, end)
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -7,7 +7,6 @@
|
||||
package gfsnotify
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/container/glist"
|
||||
)
|
||||
|
||||
@ -32,8 +31,8 @@ func (w *Watcher) startWatchLoop() {
|
||||
return struct {}{}
|
||||
}, REPEAT_EVENT_FILTER_INTERVAL)
|
||||
|
||||
case err := <- w.watcher.Errors:
|
||||
fmt.Errorf("error: %s\n" + err.Error());
|
||||
case <- w.watcher.Errors:
|
||||
//fmt.Fprintf(os.Stderr, "[gfsnotify] error: %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@ -31,7 +31,7 @@ type Logger struct {
|
||||
file *gtype.String // 日志文件名称格式
|
||||
level *gtype.Int // 日志输出等级
|
||||
btSkip *gtype.Int // 错误产生时的backtrace回调信息skip条数
|
||||
btEnabled *gtype.Bool // 是否当打印错误时同时开启backtrace打印
|
||||
btStatus *gtype.Int // 是否当打印错误时同时开启backtrace打印(默认-1,表示默认打印逻辑 - 错误才打印)
|
||||
printHeader *gtype.Bool // 是否不打印前缀信息(时间,级别等)
|
||||
alsoStdPrint *gtype.Bool // 控制台打印开关,当输出到文件/自定义输出时也同时打印到终端
|
||||
}
|
||||
@ -65,7 +65,7 @@ func New() *Logger {
|
||||
file : gtype.NewString(gDEFAULT_FILE_FORMAT),
|
||||
level : gtype.NewInt(defaultLevel.Val()),
|
||||
btSkip : gtype.NewInt(),
|
||||
btEnabled : gtype.NewBool(true),
|
||||
btStatus : gtype.NewInt(-1),
|
||||
printHeader : gtype.NewBool(true),
|
||||
alsoStdPrint : gtype.NewBool(true),
|
||||
}
|
||||
@ -80,7 +80,7 @@ func (l *Logger) Clone() *Logger {
|
||||
file : l.file.Clone(),
|
||||
level : l.level.Clone(),
|
||||
btSkip : l.btSkip.Clone(),
|
||||
btEnabled : l.btEnabled.Clone(),
|
||||
btStatus : l.btStatus.Clone(),
|
||||
printHeader : l.printHeader.Clone(),
|
||||
alsoStdPrint : l.alsoStdPrint.Clone(),
|
||||
}
|
||||
@ -106,7 +106,12 @@ func (l *Logger) SetDebug(debug bool) {
|
||||
}
|
||||
|
||||
func (l *Logger) SetBacktrace(enabled bool) {
|
||||
l.btEnabled.Set(enabled)
|
||||
if enabled {
|
||||
l.btStatus.Set(1)
|
||||
} else {
|
||||
l.btStatus.Set(0)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 设置BacktraceSkip
|
||||
@ -136,6 +141,13 @@ func (l *Logger) getFilePointer() *gfpool.File {
|
||||
file, _ := gregex.ReplaceStringFunc(`{.+?}`, l.file.Val(), func(s string) string {
|
||||
return gtime.Now().Format(strings.Trim(s, "{}"))
|
||||
})
|
||||
// 如果日志目录不存在则创建目录路径
|
||||
if !gfile.Exists(path) {
|
||||
if err := gfile.Mkdir(path); err != nil {
|
||||
fmt.Fprintln(os.Stderr, fmt.Sprintf(`[glog] mkdir "%s" failed: %s`, path, err.Error()))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
fpath := path + gfile.Separator + file
|
||||
if fp, err := gfpool.Open(fpath, gDEFAULT_FILE_POOL_FLAGS, gDEFAULT_FPOOL_PERM, gDEFAULT_FPOOL_EXPIRE); err == nil {
|
||||
return fp
|
||||
@ -151,7 +163,7 @@ func (l *Logger) SetPath(path string) error {
|
||||
// 如果目录不存在,则递归创建
|
||||
if !gfile.Exists(path) {
|
||||
if err := gfile.Mkdir(path); err != nil {
|
||||
fmt.Fprintln(os.Stderr, fmt.Sprintf(`glog mkdir "%s" failed: %s`, path, err.Error()))
|
||||
fmt.Fprintln(os.Stderr, fmt.Sprintf(`[glog] mkdir "%s" failed: %s`, path, err.Error()))
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -220,24 +232,35 @@ func (l *Logger) stdPrint(s string) {
|
||||
// 核心打印数据方法(标准错误)
|
||||
func (l *Logger) errPrint(s string) {
|
||||
// 记录调用回溯信息
|
||||
if l.btEnabled.Val() {
|
||||
tracestr := l.GetBacktrace()
|
||||
if tracestr != "" {
|
||||
backtrace := "Backtrace:" + ln + tracestr
|
||||
if s[len(s) - 1] == byte('\n') {
|
||||
s = s + backtrace + ln
|
||||
} else {
|
||||
s = s + ln + backtrace + ln
|
||||
}
|
||||
}
|
||||
status := l.btStatus.Val()
|
||||
if status == -1 || status == 1 {
|
||||
s = l.appendBacktrace(s)
|
||||
}
|
||||
// 防止串日志情况,这里不使用stderr,而是使用stdout
|
||||
l.print(os.Stdout, s)
|
||||
}
|
||||
|
||||
// 输出内容中添加回溯信息
|
||||
func (l *Logger) appendBacktrace(s string, skip...int) string {
|
||||
trace := l.GetBacktrace(skip...)
|
||||
if trace != "" {
|
||||
backtrace := "Backtrace:" + ln + trace
|
||||
if len(s) > 0 {
|
||||
if s[len(s)-1] == byte('\n') {
|
||||
s = s + backtrace + ln
|
||||
} else {
|
||||
s = s + ln + backtrace + ln
|
||||
}
|
||||
} else {
|
||||
s = backtrace
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// 直接打印回溯信息,参数skip表示调用端往上多少级开始回溯
|
||||
func (l *Logger) PrintBacktrace(skip...int) {
|
||||
l.Println(l.GetBacktrace(skip...))
|
||||
l.Println(l.appendBacktrace("", skip...))
|
||||
}
|
||||
|
||||
// 获取文件调用回溯字符串,参数skip表示调用端往上多少级开始回溯
|
||||
|
||||
@ -121,7 +121,6 @@ func getShell() string {
|
||||
}
|
||||
return path
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// 获取当前系统默认shell执行指令的option参数
|
||||
@ -132,7 +131,6 @@ func getShellOption() string {
|
||||
default:
|
||||
return "-c"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// 从环境变量PATH中搜索可执行文件
|
||||
|
||||
@ -62,7 +62,10 @@ func startTcpListening() {
|
||||
|
||||
// TCP数据通信处理回调函数
|
||||
func tcpServiceHandler(conn *gtcp.Conn) {
|
||||
var retry = gtcp.Retry{3, 10}
|
||||
retry := gtcp.Retry {
|
||||
Count : 3,
|
||||
Interval: 10,
|
||||
}
|
||||
for {
|
||||
var result []byte
|
||||
buffer, err := conn.Recv(-1, retry)
|
||||
@ -97,32 +100,32 @@ func tcpServiceHandler(conn *gtcp.Conn) {
|
||||
}
|
||||
|
||||
// 数据解包,防止黏包
|
||||
// 数据格式:总长度(24bit)|发送进程PID(16bit)|接收进程PID(16bit)|分组长度(8bit)|分组名称(变长)|校验(32bit)|参数(变长)
|
||||
// 数据格式:总长度(24bit)|发送进程PID(24bit)|接收进程PID(24bit)|分组长度(8bit)|分组名称(变长)|校验(32bit)|参数(变长)
|
||||
func bufferToMsgs(buffer []byte) []*Msg {
|
||||
s := 0
|
||||
msgs := make([]*Msg, 0)
|
||||
for s < len(buffer) {
|
||||
// 长度解析及校验
|
||||
length := gbinary.DecodeToInt(buffer[s : s + 3])
|
||||
if length < 12 || length > len(buffer) {
|
||||
if length < 14 || length > len(buffer) {
|
||||
s++
|
||||
continue
|
||||
}
|
||||
// 分组信息解析
|
||||
groupLen := gbinary.DecodeToInt(buffer[s + 7 : s + 8])
|
||||
groupLen := gbinary.DecodeToInt(buffer[s + 9 : s + 10])
|
||||
// checksum校验(仅对参数做校验,提高校验效率)
|
||||
checksum1 := gbinary.DecodeToUint32(buffer[s + 8 + groupLen : s + 8 + groupLen + 4])
|
||||
checksum2 := gtcp.Checksum(buffer[s + 8 + groupLen + 4 : s + length])
|
||||
checksum1 := gbinary.DecodeToUint32(buffer[s + 10 + groupLen : s + 10 + groupLen + 4])
|
||||
checksum2 := gtcp.Checksum(buffer[s + 10 + groupLen + 4 : s + length])
|
||||
if checksum1 != checksum2 {
|
||||
s++
|
||||
continue
|
||||
}
|
||||
// 接收进程PID校验
|
||||
if Pid() == gbinary.DecodeToInt(buffer[s + 5 : s + 7]) {
|
||||
if Pid() == gbinary.DecodeToInt(buffer[s + 6 : s + 9]) {
|
||||
msgs = append(msgs, &Msg {
|
||||
Pid : gbinary.DecodeToInt(buffer[s + 3 : s + 5]),
|
||||
Data : buffer[s + 8 + groupLen + 4 : s + length],
|
||||
Group : string(buffer[s + 8 : s + 8 + groupLen]),
|
||||
Pid : gbinary.DecodeToInt(buffer[s + 3 : s + 6]),
|
||||
Data : buffer[s + 10 + groupLen + 4 : s + length],
|
||||
Group : string(buffer[s + 10 : s + 10 + groupLen]),
|
||||
})
|
||||
}
|
||||
s += length
|
||||
|
||||
@ -27,16 +27,16 @@ const (
|
||||
)
|
||||
|
||||
// 向指定gproc进程发送数据
|
||||
// 数据格式:总长度(24bit)|发送进程PID(16bit)|接收进程PID(16bit)|分组长度(8bit)|分组名称(变长)|校验(32bit)|参数(变长)
|
||||
// 数据格式:总长度(24bit)|发送进程PID(24bit)|接收进程PID(24bit)|分组长度(8bit)|分组名称(变长)|校验(32bit)|参数(变长)
|
||||
func Send(pid int, data []byte, group...string) error {
|
||||
groupName := gPROC_COMM_DEAFULT_GRUOP_NAME
|
||||
if len(group) > 0 {
|
||||
groupName = group[0]
|
||||
}
|
||||
buffer := make([]byte, 0)
|
||||
buffer = append(buffer, gbinary.EncodeByLength(3, len(groupName) + len(data) + 12)...)
|
||||
buffer = append(buffer, gbinary.EncodeByLength(2, Pid())...)
|
||||
buffer = append(buffer, gbinary.EncodeByLength(2, pid)...)
|
||||
buffer = append(buffer, gbinary.EncodeByLength(3, len(groupName) + len(data) + 14)...)
|
||||
buffer = append(buffer, gbinary.EncodeByLength(3, Pid())...)
|
||||
buffer = append(buffer, gbinary.EncodeByLength(3, pid)...)
|
||||
buffer = append(buffer, gbinary.EncodeByLength(1, len(groupName))...)
|
||||
buffer = append(buffer, []byte(groupName)...)
|
||||
buffer = append(buffer, gbinary.EncodeUint32(gtcp.Checksum(data))...)
|
||||
|
||||
@ -13,18 +13,18 @@ import (
|
||||
"gitee.com/johng/gf/g/os/grpool"
|
||||
)
|
||||
|
||||
func increment1() {
|
||||
func increment() {
|
||||
for i := 0; i < 1000000; i++ {}
|
||||
}
|
||||
|
||||
func BenchmarkGrpool_1(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
grpool.Add(increment1)
|
||||
grpool.Add(increment)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGoroutine_1(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
go increment1()
|
||||
go increment()
|
||||
}
|
||||
}
|
||||
@ -15,10 +15,6 @@ import (
|
||||
|
||||
var n = 500000
|
||||
|
||||
func increment2() {
|
||||
for i := 0; i < 1000000; i++ {}
|
||||
}
|
||||
|
||||
func BenchmarkGrpool2(b *testing.B) {
|
||||
b.N = n
|
||||
for i := 0; i < b.N; i++ {
|
||||
@ -1,35 +0,0 @@
|
||||
// Copyright 2017 gf Author(https://gitee.com/johng/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://gitee.com/johng/gf.
|
||||
|
||||
package grpool_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"runtime"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func increment() {
|
||||
for i := 0; i < 100000; i++ {}
|
||||
}
|
||||
|
||||
//func Test_GrpoolMemUsage(t *testing.T) {
|
||||
// for i := 0; i < n; i++ {
|
||||
// grpool.Add(increment)
|
||||
// }
|
||||
// mem := runtime.MemStats{}
|
||||
// runtime.ReadMemStats(&mem)
|
||||
// fmt.Println("mem usage:", mem.TotalAlloc/1024)
|
||||
//}
|
||||
|
||||
func Test_GroroutineMemUsage(t *testing.T) {
|
||||
for i := 0; i < n; i++ {
|
||||
go increment()
|
||||
}
|
||||
mem := runtime.MemStats{}
|
||||
runtime.ReadMemStats(&mem)
|
||||
fmt.Println("mem usage:", mem.TotalAlloc/1024)
|
||||
}
|
||||
@ -123,7 +123,7 @@ func (t *Time) ToZone(zone string) *Time {
|
||||
t.Time = t.Time.In(l)
|
||||
return t
|
||||
} else {
|
||||
panic(err)
|
||||
//panic(err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -49,10 +49,12 @@ var viewObj *View
|
||||
// 初始化默认的视图对象
|
||||
func checkAndInitDefaultView() {
|
||||
if viewObj == nil {
|
||||
if gfile.SelfDir() != gfile.TempDir() {
|
||||
// gfile.MainPkgPath() 用以判断是否开发环境
|
||||
mainPkgPath := gfile.MainPkgPath()
|
||||
if gfile.MainPkgPath() == "" {
|
||||
viewObj = New(gfile.SelfDir())
|
||||
} else {
|
||||
viewObj = New()
|
||||
viewObj = New(mainPkgPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -99,8 +99,11 @@ func Map(i interface{}, noTagCheck...bool) map[string]interface{} {
|
||||
rt := rv.Type()
|
||||
name := ""
|
||||
for i := 0; i < rv.NumField(); i++ {
|
||||
if name = rt.Field(i).Tag.Get("json"); name == "" {
|
||||
name = rt.Field(i).Name
|
||||
// 检查json tag
|
||||
if len(noTagCheck) == 0 || !noTagCheck[0] {
|
||||
if name = rt.Field(i).Tag.Get("json"); name == "" {
|
||||
name = rt.Field(i).Name
|
||||
}
|
||||
}
|
||||
m[name] = rv.Field(i).Interface()
|
||||
}
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
package gconv
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
@ -77,12 +76,11 @@ func Ints(i interface{}) []int {
|
||||
for _, v := range i.([]interface{}) {
|
||||
array = append(array, Int(v))
|
||||
}
|
||||
default:
|
||||
return []int{Int(i)}
|
||||
}
|
||||
if len(array) > 0 {
|
||||
return array
|
||||
}
|
||||
return array
|
||||
}
|
||||
return []int{Int(i)}
|
||||
}
|
||||
|
||||
// 任意类型转换为[]string类型
|
||||
@ -95,71 +93,70 @@ func Strings(i interface{}) []string {
|
||||
} else {
|
||||
array := make([]string, 0)
|
||||
switch i.(type) {
|
||||
case []int:
|
||||
for _, v := range i.([]int) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []int8:
|
||||
for _, v := range i.([]int8) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []int16:
|
||||
for _, v := range i.([]int16) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []int32:
|
||||
for _, v := range i.([]int32) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []int64:
|
||||
for _, v := range i.([]int64) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []uint:
|
||||
for _, v := range i.([]uint) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []uint8:
|
||||
for _, v := range i.([]uint8) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []uint16:
|
||||
for _, v := range i.([]uint16) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []uint32:
|
||||
for _, v := range i.([]uint32) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []uint64:
|
||||
for _, v := range i.([]uint64) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []bool:
|
||||
for _, v := range i.([]bool) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []float32:
|
||||
for _, v := range i.([]float32) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []float64:
|
||||
for _, v := range i.([]float64) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []interface{}:
|
||||
for _, v := range i.([]interface{}) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
}
|
||||
if len(array) > 0 {
|
||||
return array
|
||||
case []int:
|
||||
for _, v := range i.([]int) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []int8:
|
||||
for _, v := range i.([]int8) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []int16:
|
||||
for _, v := range i.([]int16) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []int32:
|
||||
for _, v := range i.([]int32) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []int64:
|
||||
for _, v := range i.([]int64) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []uint:
|
||||
for _, v := range i.([]uint) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []uint8:
|
||||
for _, v := range i.([]uint8) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []uint16:
|
||||
for _, v := range i.([]uint16) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []uint32:
|
||||
for _, v := range i.([]uint32) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []uint64:
|
||||
for _, v := range i.([]uint64) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []bool:
|
||||
for _, v := range i.([]bool) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []float32:
|
||||
for _, v := range i.([]float32) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []float64:
|
||||
for _, v := range i.([]float64) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
case []interface{}:
|
||||
for _, v := range i.([]interface{}) {
|
||||
array = append(array, String(v))
|
||||
}
|
||||
default:
|
||||
return []string{String(i)}
|
||||
}
|
||||
return array
|
||||
}
|
||||
return []string{fmt.Sprintf("%v", i)}
|
||||
}
|
||||
|
||||
// 任意类型转换为[]float64类型
|
||||
// 将类型转换为[]float64类型
|
||||
func Floats(i interface{}) []float64 {
|
||||
if i == nil {
|
||||
return nil
|
||||
@ -225,12 +222,11 @@ func Floats(i interface{}) []float64 {
|
||||
for _, v := range i.([]interface{}) {
|
||||
array = append(array, Float64(v))
|
||||
}
|
||||
default:
|
||||
return []float64{Float64(i)}
|
||||
}
|
||||
if len(array) > 0 {
|
||||
return array
|
||||
}
|
||||
return array
|
||||
}
|
||||
return []float64{Float64(i)}
|
||||
}
|
||||
|
||||
// 任意类型转换为[]interface{}类型
|
||||
@ -318,11 +314,10 @@ func Interfaces(i interface{}) []interface{} {
|
||||
for i := 0; i < rv.NumField(); i++ {
|
||||
array = append(array, rv.Field(i).Interface())
|
||||
}
|
||||
default:
|
||||
return []interface{}{i}
|
||||
}
|
||||
}
|
||||
if len(array) > 0 {
|
||||
return array
|
||||
}
|
||||
return array
|
||||
}
|
||||
return []interface{}{i}
|
||||
}
|
||||
@ -157,12 +157,10 @@ func bindVarToStruct(elem reflect.Value, name string, value interface{}) (err er
|
||||
structFieldValue := elem.FieldByName(name)
|
||||
// 键名与对象属性匹配检测,map中如果有struct不存在的属性,那么不做处理,直接return
|
||||
if !structFieldValue.IsValid() {
|
||||
//return errors.New(fmt.Sprintf(`invalid struct attribute of name "%s"`, name))
|
||||
return nil
|
||||
}
|
||||
// CanSet的属性必须为公开属性(首字母大写)
|
||||
if !structFieldValue.CanSet() {
|
||||
//return errors.New(fmt.Sprintf(`struct attribute of name "%s" cannot be set`, name))
|
||||
return nil
|
||||
}
|
||||
// 必须将value转换为struct属性的数据类型,这里必须用到gconv包
|
||||
@ -181,12 +179,10 @@ func bindVarToStructByIndex(elem reflect.Value, index int, value interface{}) (e
|
||||
structFieldValue := elem.FieldByIndex([]int{index})
|
||||
// 键名与对象属性匹配检测
|
||||
if !structFieldValue.IsValid() {
|
||||
//return errors.New(fmt.Sprintf("invalid struct attribute at index %d", index))
|
||||
return nil
|
||||
}
|
||||
// CanSet的属性必须为公开属性(首字母大写)
|
||||
if !structFieldValue.CanSet() {
|
||||
//return errors.New(fmt.Sprintf("struct attribute cannot be set at index %d", index))
|
||||
return nil
|
||||
}
|
||||
// 必须将value转换为struct属性的数据类型,这里必须用到gconv包
|
||||
|
||||
@ -14,21 +14,26 @@ import (
|
||||
|
||||
// 将变量i转换为time.Time类型
|
||||
func Time(i interface{}, format...string) time.Time {
|
||||
s := String(i)
|
||||
// 优先使用用户输入日期格式进行转换
|
||||
if len(format) > 0 {
|
||||
t, _ := gtime.StrToTimeFormat(s, format[0])
|
||||
return t.Time
|
||||
}
|
||||
if gstr.IsNumeric(s) {
|
||||
return gtime.NewFromTimeStamp(Int64(s)).Time
|
||||
} else {
|
||||
t, _ := gtime.StrToTime(s)
|
||||
return t.Time
|
||||
}
|
||||
return GTime(i, format...).Time
|
||||
}
|
||||
|
||||
// 将变量i转换为time.Time类型
|
||||
func TimeDuration(i interface{}) time.Duration {
|
||||
return time.Duration(Int64(i))
|
||||
}
|
||||
|
||||
// 将变量i转换为time.Time类型
|
||||
func GTime(i interface{}, format...string) *gtime.Time {
|
||||
s := String(i)
|
||||
// 优先使用用户输入日期格式进行转换
|
||||
if len(format) > 0 {
|
||||
t, _ := gtime.StrToTimeFormat(s, format[0])
|
||||
return t
|
||||
}
|
||||
if gstr.IsNumeric(s) {
|
||||
return gtime.NewFromTimeStamp(Int64(s))
|
||||
} else {
|
||||
t, _ := gtime.StrToTime(s)
|
||||
return t
|
||||
}
|
||||
}
|
||||
@ -9,7 +9,6 @@ package grand
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -34,8 +33,8 @@ func init() {
|
||||
bufferChan <- binary.LittleEndian.Uint64(buffer[i : i + 8])
|
||||
i ++
|
||||
}
|
||||
// 充分利用缓冲区数据,字节倒序生成,随机索引递增
|
||||
step = int(time.Now().UnixNano())%n
|
||||
// 充分利用缓冲区数据,随机索引递增
|
||||
step = int(buffer[0])%10
|
||||
for i := 0; i < n - 8; {
|
||||
bufferChan <- binary.BigEndian.Uint64(buffer[i : i + 8])
|
||||
i += step
|
||||
|
||||
@ -13,7 +13,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 字符串替换
|
||||
// 字符串替换(大小写敏感)
|
||||
func Replace(origin, search, replace string, count...int) string {
|
||||
n := -1
|
||||
if len(count) > 0 {
|
||||
@ -22,7 +22,7 @@ func Replace(origin, search, replace string, count...int) string {
|
||||
return strings.Replace(origin, search, replace, n)
|
||||
}
|
||||
|
||||
// 使用map进行字符串替换
|
||||
// 使用map进行字符串替换(大小写敏感)
|
||||
func ReplaceByMap(origin string, replaces map[string]string) string {
|
||||
result := origin
|
||||
for k, v := range replaces {
|
||||
|
||||
32
g/util/gtest/gtest.go
Normal file
32
g/util/gtest/gtest.go
Normal file
@ -0,0 +1,32 @@
|
||||
// Copyright 2018 gf Author(https://gitee.com/johng/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://gitee.com/johng/gf.
|
||||
|
||||
// Package gtest provides useful test utils.
|
||||
// 测试模块.
|
||||
package gtest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/os/glog"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"os"
|
||||
)
|
||||
|
||||
// 断言判断
|
||||
func Assert(value, expect interface{}) {
|
||||
if gconv.String(value) != gconv.String(expect) {
|
||||
glog.Printfln(`[ASSERT] VALUE: %v, EXPECT: %v`, value, expect)
|
||||
glog.Header(false).PrintBacktrace(1)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// 提示错误并退出
|
||||
func Fatal(message...interface{}) {
|
||||
glog.Println(`[FATAL]`, fmt.Sprint(message...))
|
||||
glog.Header(false).PrintBacktrace(1)
|
||||
os.Exit(1)
|
||||
}
|
||||
@ -4,13 +4,14 @@
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://gitee.com/johng/gf.
|
||||
|
||||
// 其他工具包
|
||||
// 工具包
|
||||
package gutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"reflect"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"runtime"
|
||||
@ -43,7 +44,7 @@ func Dump(i...interface{}) {
|
||||
if err := encoder.Encode(v); err == nil {
|
||||
fmt.Print(buffer.String())
|
||||
} else {
|
||||
fmt.Errorf("%s", err.Error())
|
||||
fmt.Fprintln(os.Stderr, err.Error())
|
||||
}
|
||||
}
|
||||
//fmt.Println()
|
||||
|
||||
@ -14,7 +14,12 @@ import (
|
||||
// 检测键值对参数Map,
|
||||
// rules参数支持 []string / map[string]string 类型,前面一种类型支持返回校验结果顺序(具体格式参考struct tag),后一种不支持;
|
||||
// rules参数中得 map[string]string 是一个2维的关联数组,第一维键名为参数键名,第二维为带有错误的校验规则名称,值为错误信息。
|
||||
func CheckMap(params map[string]interface{}, rules interface{}, msgs...CustomMsg) *Error {
|
||||
func CheckMap(params interface{}, rules interface{}, msgs...CustomMsg) *Error {
|
||||
// 将参数转换为 map[string]interface{}类型
|
||||
data := gconv.Map(params)
|
||||
if data == nil {
|
||||
return newErrorStr("invalid_params", "invalid params type: convert to map[string]interface{} failed")
|
||||
}
|
||||
// 真实校验规则数据结构
|
||||
checkRules := make(map[string]string)
|
||||
// 真实自定义错误信息数据结构
|
||||
@ -73,11 +78,15 @@ func CheckMap(params map[string]interface{}, rules interface{}, msgs...CustomMsg
|
||||
value := (interface{})(nil)
|
||||
// 这里的rule变量为多条校验规则,不包含名字或者错误信息定义
|
||||
for key, rule := range checkRules {
|
||||
// 如果规则为空,那么不执行校验
|
||||
if len(rule) == 0 {
|
||||
continue
|
||||
}
|
||||
value = nil
|
||||
if v, ok := params[key]; ok {
|
||||
if v, ok := data[key]; ok {
|
||||
value = v
|
||||
}
|
||||
if e := Check(value, rule, customMsgs[key], params); e != nil {
|
||||
if e := Check(value, rule, customMsgs[key], data); e != nil {
|
||||
_, item := e.FirstItem()
|
||||
// 如果值为nil|"",并且不需要require*验证时,其他验证失效
|
||||
if value == nil || gconv.String(value) == "" {
|
||||
|
||||
@ -22,7 +22,7 @@ type Error struct {
|
||||
type ErrorMap map[string]map[string]string
|
||||
|
||||
|
||||
// 创建一个校验错误对象指针
|
||||
// 创建一个校验错误对象指针(校验错误)
|
||||
func newError(rules []string, errors map[string]map[string]string) *Error {
|
||||
return &Error {
|
||||
rules : rules,
|
||||
@ -30,6 +30,18 @@ func newError(rules []string, errors map[string]map[string]string) *Error {
|
||||
}
|
||||
}
|
||||
|
||||
// 创建一个校验错误对象指针(内部错误)
|
||||
func newErrorStr(key, err string) *Error {
|
||||
return &Error {
|
||||
rules : nil,
|
||||
errors : map[string]map[string]string{
|
||||
"__gvalid__" : {
|
||||
key: err,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 获得规则与错误信息的map; 当校验结果为多条数据校验时,返回第一条错误map(此时类似FirstItem)
|
||||
func (e *Error) Map() map[string]string {
|
||||
_, m := e.FirstItem()
|
||||
|
||||
20
geg/container/garray/sorted_string_array1.go
Normal file
20
geg/container/garray/sorted_string_array1.go
Normal file
@ -0,0 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g"
|
||||
"gitee.com/johng/gf/g/container/garray"
|
||||
)
|
||||
|
||||
func main() {
|
||||
array := garray.NewSortedStringArray(0, false)
|
||||
array.Add("1")
|
||||
array.Add("2")
|
||||
array.Add("3")
|
||||
array.Add("4")
|
||||
array.Add("5")
|
||||
array.Add("6")
|
||||
array.Add("7")
|
||||
array.Add("8")
|
||||
array.Add("9")
|
||||
g.Dump(array.Slice())
|
||||
}
|
||||
26
geg/container/garray/sorted_string_array2.go
Normal file
26
geg/container/garray/sorted_string_array2.go
Normal file
@ -0,0 +1,26 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g"
|
||||
"gitee.com/johng/gf/g/container/garray"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {
|
||||
array := garray.NewSortedStringArray(0, false)
|
||||
array.Add("/api/ctl/show")
|
||||
array.Add("/api/ctl/post")
|
||||
array.Add("/api/obj/rest")
|
||||
array.Add("/api/handler")
|
||||
array.Add("/api/obj/delete")
|
||||
array.Add("/api/obj/show")
|
||||
array.Add("/api/obj/my-show")
|
||||
array.Add("/api/*")
|
||||
array.Add("/api/ctl/rest")
|
||||
array.Add("/api/ctl/my-show")
|
||||
g.Dump(array.Slice())
|
||||
|
||||
fmt.Println(strings.Compare("/api/ctl/post", "/api/*"))
|
||||
fmt.Println(strings.Compare("/api/*", "/api/ctl/my-show"))
|
||||
}
|
||||
@ -9,7 +9,7 @@ import (
|
||||
|
||||
// 本文件用于gf框架的mysql数据库操作示例,不作为单元测试使用
|
||||
|
||||
var db *gdb.Db
|
||||
var db gdb.DB
|
||||
|
||||
// 初始化配置及创建数据库
|
||||
func init () {
|
||||
@ -17,7 +17,7 @@ func init () {
|
||||
Host : "127.0.0.1",
|
||||
Port : "3306",
|
||||
User : "root",
|
||||
Pass : "8692651",
|
||||
Pass : "12345678",
|
||||
Name : "test",
|
||||
Type : "mysql",
|
||||
Role : "master",
|
||||
|
||||
@ -11,7 +11,7 @@ func main() {
|
||||
Host: "127.0.0.1",
|
||||
Port: "3306",
|
||||
User: "root",
|
||||
Pass: "123456",
|
||||
Pass: "12345678",
|
||||
Name: "test",
|
||||
Type: "mysql",
|
||||
Role: "master",
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g"
|
||||
"gitee.com/johng/gf/g/database/gdb"
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/encoding/gparser"
|
||||
@ -11,18 +12,14 @@ func main() {
|
||||
Host : "127.0.0.1",
|
||||
Port : "3306",
|
||||
User : "root",
|
||||
Pass : "123456",
|
||||
Pass : "12345678",
|
||||
Name : "test",
|
||||
Type : "mysql",
|
||||
Role : "master",
|
||||
Charset : "utf8",
|
||||
})
|
||||
db, err := gdb.New()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
one, err := db.Table("user").Where("uid=?", 1).One()
|
||||
db := g.DB()
|
||||
one, err := db.Table("user").Where("id=?", 1).One()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -33,7 +30,7 @@ func main() {
|
||||
|
||||
// 自定义方法方法转换为json/xml
|
||||
jsonContent, _ := gparser.VarToJson(one.ToMap())
|
||||
fmt.Println(jsonContent)
|
||||
fmt.Println(string(jsonContent))
|
||||
xmlContent, _ := gparser.VarToXml(one.ToMap())
|
||||
fmt.Println(xmlContent)
|
||||
fmt.Println(string(xmlContent))
|
||||
}
|
||||
@ -1,28 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g/database/gdb"
|
||||
"gitee.com/johng/gf/g"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gdb.AddDefaultConfigNode(gdb.ConfigNode {
|
||||
Host : "127.0.0.1",
|
||||
Port : "3306",
|
||||
User : "root",
|
||||
Pass : "12345678",
|
||||
Name : "test",
|
||||
Type : "mysql",
|
||||
Role : "master",
|
||||
Charset : "utf8",
|
||||
MaxIdleConnCount : 10,
|
||||
MaxOpenConnCount : 10,
|
||||
MaxConnLifetime : 10,
|
||||
})
|
||||
db, err := gdb.New()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
db := g.DB()
|
||||
db.SetMaxIdleConns(10)
|
||||
db.SetMaxOpenConns(10)
|
||||
db.SetConnMaxLifetime(10)
|
||||
|
||||
// 开启调试模式,以便于记录所有执行的SQL
|
||||
db.SetDebug(true)
|
||||
|
||||
|
||||
@ -2,28 +2,15 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/database/gdb"
|
||||
"gitee.com/johng/gf/g"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gdb.AddDefaultConfigNode(gdb.ConfigNode {
|
||||
Host : "192.168.1.11",
|
||||
Port : "3306",
|
||||
User : "root",
|
||||
Pass : "8692651",
|
||||
Name : "test",
|
||||
Type : "mysql",
|
||||
Role : "master",
|
||||
Charset : "utf8",
|
||||
})
|
||||
db, err := gdb.New()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
db := g.DB()
|
||||
// 开启调试模式,以便于记录所有执行的SQL
|
||||
db.SetDebug(true)
|
||||
|
||||
r, _ := db.Table("user").All()
|
||||
r, _ := db.Table("test").Where("id IN (?)", []interface{}{1, 2}).All()
|
||||
if r != nil {
|
||||
fmt.Println(r.ToList())
|
||||
}
|
||||
|
||||
78
geg/net/ghttp/server/router/group/group1.go
Normal file
78
geg/net/ghttp/server/router/group/group1.go
Normal file
@ -0,0 +1,78 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g"
|
||||
"gitee.com/johng/gf/g/frame/gmvc"
|
||||
"gitee.com/johng/gf/g/net/ghttp"
|
||||
)
|
||||
|
||||
type Object struct {}
|
||||
|
||||
func (o *Object) Show(r *ghttp.Request) {
|
||||
r.Response.Writeln("Object Show")
|
||||
}
|
||||
|
||||
func (o *Object) Delete(r *ghttp.Request) {
|
||||
r.Response.Writeln("Object REST Delete")
|
||||
}
|
||||
|
||||
func (o *Object) Shut(r *ghttp.Request) {
|
||||
r.Response.Writeln("Object Shut")
|
||||
}
|
||||
|
||||
type Controller struct {
|
||||
gmvc.Controller
|
||||
}
|
||||
|
||||
func (c *Controller) Show() {
|
||||
c.Response.Writeln("Controller Show")
|
||||
}
|
||||
|
||||
func (c *Controller) Post() {
|
||||
c.Response.Writeln("Controller REST Post")
|
||||
}
|
||||
|
||||
func (c *Controller) Shut() {
|
||||
c.Response.Writeln("Controller Shut")
|
||||
}
|
||||
|
||||
func Handler(r *ghttp.Request) {
|
||||
r.Response.Writeln("Handler")
|
||||
}
|
||||
|
||||
func HookHandler(r *ghttp.Request) {
|
||||
r.Response.Writeln("Hook Handler")
|
||||
}
|
||||
|
||||
func main() {
|
||||
s := g.Server()
|
||||
obj := new(Object)
|
||||
ctl := new(Controller)
|
||||
|
||||
// 分组路由方法注册
|
||||
//g := s.Group("/api")
|
||||
//g.ALL ("*", HookHandler, ghttp.HOOK_BEFORE_SERVE)
|
||||
//g.ALL ("/handler", Handler)
|
||||
//g.ALL ("/ctl", ctl)
|
||||
//g.GET ("/ctl/my-show", ctl, "Show")
|
||||
//g.REST("/ctl/rest", ctl)
|
||||
//g.ALL ("/obj", obj)
|
||||
//g.GET ("/obj/my-show", obj, "Show")
|
||||
//g.REST("/obj/rest", obj)
|
||||
|
||||
// 分组路由批量注册
|
||||
s.Group("/api").Bind("/api", []ghttp.GroupItem{
|
||||
|
||||
{"ALL", "/handler", Handler},
|
||||
{"ALL", "/ctl", ctl},
|
||||
{"GET", "/ctl/my-show", ctl, "Show"},
|
||||
{"REST", "/ctl/rest", ctl},
|
||||
{"ALL", "/obj", obj},
|
||||
{"GET", "/obj/my-show", obj, "Show"},
|
||||
{"REST", "/obj/rest", obj},
|
||||
{"ALL", "*", HookHandler, ghttp.HOOK_BEFORE_SERVE},
|
||||
})
|
||||
|
||||
s.SetPort(8199)
|
||||
s.Run()
|
||||
}
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"bytes"
|
||||
"gitee.com/johng/gf/g/net/gtcp"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@ -39,7 +40,7 @@ func main() {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Errorf("ERROR: %s\n", err.Error())
|
||||
fmt.Fprintf(os.Stderr, "ERROR: %s\n", err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"gitee.com/johng/gf/g/net/gtcp"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@ -11,6 +12,6 @@ func main() {
|
||||
fmt.Println(string(data))
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Errorf("ERROR: %s\n", err.Error())
|
||||
fmt.Fprintf(os.Stderr, "ERROR: %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
19
geg/os/glog/glog_pool.go
Normal file
19
geg/os/glog/glog_pool.go
Normal file
@ -0,0 +1,19 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g/os/glog"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 测试删除日志文件是否会重建日志文件
|
||||
func main() {
|
||||
path := "/Users/john/Temp/test"
|
||||
glog.SetPath(path)
|
||||
for {
|
||||
glog.Println(gtime.Now().String())
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,19 +1,33 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g"
|
||||
"gitee.com/johng/gf/g/util/gregex"
|
||||
)
|
||||
type Registry struct {
|
||||
Method string
|
||||
Uri string
|
||||
Handler interface{}
|
||||
Object interface{}
|
||||
}
|
||||
|
||||
func BindGroup(group string, routers [][]interface{}) {
|
||||
|
||||
}
|
||||
|
||||
type User struct { }
|
||||
type Order struct { }
|
||||
type Product struct { }
|
||||
|
||||
func HookFunc() {
|
||||
|
||||
}
|
||||
|
||||
func main() {
|
||||
//s := `1544180795 -- s_has_sess -- 41570504 -decryptSess- 41570504__iuVycRYg9qE3y7CsSgGZH1K2nxTdjPZN4fXot65zHIEmULO0Ow6LweJp5raWl8Ft -postSess- eyJpdiI6IkFwSWZ3eXFMcGxBZE5JcWF4aXh0M3c9PSIsInZhbHVlIjoiV3ZLeGduMnRoRkFZdmxHTzM5ZzdyU1JHWDMycmZlRERvNnFkaUR0SitlRjBrZnlYR1JvS2puTGZNUThSeFR0bWtlT3pza0l0elFqRk5mdXF6XC9FWWpWZnljVjdJbHd3dTRybEhldHZHTk5DQ015dlpYNHljNmxKMWJTRUVpY0E4IiwibWFjIjoiOTkxMzIxOTRhMGUxZWZiODM4NWZjNDZjYmVhNWY2NjhlZDZkNmVlNjY1MTE2N2VhZDAzYzY4NDJmZGFkMjY5YyJ9 -- 0 -- B8105CF2-1588-4753-9F86-9B8C36EB1842 -- iPhone 7 -- 12.1 -- 6.8.7 -- i -- 10.111.153.5 -- medlinker -- service -- unknown
|
||||
//`
|
||||
// s := `[08-Dec-2018 13:35:03 Asia/Shanghai] Medlinker\Services\Message\MessageService|updateUserInfo|用户头像 URI 不能为空 in /var/www/med-d2d/app/Services/Message/RongCloudService.php on line 851`
|
||||
s := `[2018-12-01 13:35:03 Asia/Shanghai] 1544180795 Medlinker\Services\Message\MessageService|updateUserInfo|用户头像 URI 不能为空 in /var/www/med-d2d/app/Services/Message/RongCloudService.php on line 851`
|
||||
//m, e := gregex.MatchString(`/var/log/medlinker/[\w\-\_]+/(.+?)/{0,1}[\d\-\_]*\.log`, `/var/log/medlinker/med-questionnaire/nginx/error/access-20181206.log`)
|
||||
//m, e := gregex.MatchString(`/var/log/medlinker/[\w\-\_]+/(.+?)/{0,1}[\d\-\_]*\.log`, `/var/log/medlinker/med-questionnaire/storagelogs/events/sqlLog/2018-12-06.log`)
|
||||
m, e := gregex.MatchString(`(.*?((\d{4}[-/\.]\d{2}[-/\.]\d{2}|\d{1,2}[-/\.][A-Za-z]{3,}[-/\.]\d{4})[:\sT-]*\d{0,2}:{0,1}\d{0,2}:{0,1}\d{0,2}\.{0,1}\d{0,9}[\sZ]{0,1}[\+-]{0,1}[:\d]*|\d{10}).+)`, s)
|
||||
fmt.Println(e)
|
||||
g.Dump(m)
|
||||
user := new(User)
|
||||
BindGroup("/api", [][]interface{} {
|
||||
{"ALL", "/*", HookFunc, "BeforeServe"},
|
||||
{"ALL", "/order", new(Order)},
|
||||
{"REST", "/product", new(Product)},
|
||||
{"GET", "/user/register", "Register", user},
|
||||
{"GET", "/user/reset-pass", "ResetPassword", user},
|
||||
{"POST", "/user/reset-pass", "ResetPassword", user},
|
||||
{"POST", "/user/login", "Login", user},
|
||||
})
|
||||
}
|
||||
|
||||
@ -1,41 +0,0 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var s = "/name/john///."
|
||||
var c = "./"
|
||||
|
||||
func t1(s string) string {
|
||||
if len(s) == 0 {
|
||||
return s
|
||||
}
|
||||
for _, cut := range c {
|
||||
for s[len(s) - 1] == uint8(cut) {
|
||||
s = s[:len(s) - 1]
|
||||
if len(s) == 0 {
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func t2(s string) string {
|
||||
return strings.TrimRight(s, c)
|
||||
}
|
||||
|
||||
func Benchmark_t1(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
t1(s)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_t2(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
t2(s)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,35 +1,21 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g"
|
||||
"gitee.com/johng/gf/g/util/gvalid"
|
||||
"gitee.com/johng/gf/g/database/gdb"
|
||||
"gitee.com/johng/gf/g/os/glog"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Uid int `gvalid:"uid @integer|min:1"`
|
||||
Name string `gvalid:"name @required|length:6,30#请输入用户名称|用户名称长度非法"`
|
||||
Pass1 string `gvalid:"password1@required|password3"`
|
||||
Pass2 string `gvalid:"password2@required|password3|same:password1#||两次密码不一致,请重新输入"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
user := &User{
|
||||
Name : "john",
|
||||
Pass1: "Abc123!@#",
|
||||
Pass2: "123",
|
||||
}
|
||||
gdb.AddDefaultConfigNode(gdb.ConfigNode{
|
||||
Type : "mysql",
|
||||
Linkinfo : "root:12345678@tcp(127.0.0.1:3306)/test",
|
||||
})
|
||||
|
||||
// 使用结构体定义的校验规则和错误提示进行校验
|
||||
g.Dump(gvalid.CheckStruct(user, nil).Maps())
|
||||
|
||||
// 自定义校验规则和错误提示,对定义的特定校验规则和错误提示进行覆盖
|
||||
rules := map[string]string {
|
||||
"Uid" : "required",
|
||||
if r, err := g.Database().GetOne("select now() as time"); err != nil {
|
||||
glog.Error("Mysql Init Select Now : %v", err)
|
||||
} else {
|
||||
fmt.Println(r.ToMap())
|
||||
}
|
||||
msgs := map[string]interface{} {
|
||||
"Pass2" : map[string]string {
|
||||
"password3" : "名称不能为空",
|
||||
},
|
||||
}
|
||||
g.Dump(gvalid.CheckStruct(user, rules, msgs).Maps())
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
for i := 0; i < 10; i++ {
|
||||
for i := 0; i < 100; i++ {
|
||||
fmt.Println(grand.Rand(0, 99999))
|
||||
}
|
||||
}
|
||||
|
||||
30
third/github.com/DataDog/zstd/.travis.yml
Normal file
30
third/github.com/DataDog/zstd/.travis.yml
Normal file
@ -0,0 +1,30 @@
|
||||
language: go
|
||||
|
||||
go:
|
||||
- 1.9.x
|
||||
- 1.10.x
|
||||
- 1.11.x
|
||||
|
||||
os:
|
||||
- linux
|
||||
- osx
|
||||
|
||||
matrix:
|
||||
include:
|
||||
name: "Go 1.11.x CentOS 32bits"
|
||||
language: go
|
||||
go: 1.11.x
|
||||
os: linux
|
||||
services:
|
||||
- docker
|
||||
script:
|
||||
# Please update Go version in travis_test_32 as needed
|
||||
- "docker run -i -v \"${PWD}:/zstd\" toopher/centos-i386:centos6 /bin/bash -c \"linux32 --32bit i386 /zstd/travis_test_32.sh\""
|
||||
|
||||
install:
|
||||
- "wget https://github.com/DataDog/zstd/files/2246767/mr.zip"
|
||||
- "unzip mr.zip"
|
||||
script:
|
||||
- "go build"
|
||||
- "PAYLOAD=`pwd`/mr go test -v"
|
||||
- "PAYLOAD=`pwd`/mr go test -bench ."
|
||||
27
third/github.com/DataDog/zstd/LICENSE
Normal file
27
third/github.com/DataDog/zstd/LICENSE
Normal file
@ -0,0 +1,27 @@
|
||||
Simplified BSD License
|
||||
|
||||
Copyright (c) 2016, Datadog <info@datadoghq.com>
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice,
|
||||
this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
* Neither the name of the copyright holder nor the names of its contributors
|
||||
may be used to endorse or promote products derived from this software
|
||||
without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
120
third/github.com/DataDog/zstd/README.md
Normal file
120
third/github.com/DataDog/zstd/README.md
Normal file
@ -0,0 +1,120 @@
|
||||
# Zstd Go Wrapper
|
||||
|
||||
[C Zstd Homepage](https://github.com/Cyan4973/zstd)
|
||||
|
||||
The current headers and C files are from *v1.3.4* (Commit
|
||||
[2555975](https://github.com/facebook/zstd/releases/tag/v1.3.4)).
|
||||
|
||||
## Usage
|
||||
|
||||
There are two main APIs:
|
||||
|
||||
* simple Compress/Decompress
|
||||
* streaming API (io.Reader/io.Writer)
|
||||
|
||||
The compress/decompress APIs mirror that of lz4, while the streaming API was
|
||||
designed to be a drop-in replacement for zlib.
|
||||
|
||||
### Simple `Compress/Decompress`
|
||||
|
||||
|
||||
```go
|
||||
// Compress compresses the byte array given in src and writes it to dst.
|
||||
// If you already have a buffer allocated, you can pass it to prevent allocation
|
||||
// If not, you can pass nil as dst.
|
||||
// If the buffer is too small, it will be reallocated, resized, and returned bu the function
|
||||
// If dst is nil, this will allocate the worst case size (CompressBound(src))
|
||||
Compress(dst, src []byte) ([]byte, error)
|
||||
```
|
||||
|
||||
```go
|
||||
// CompressLevel is the same as Compress but you can pass another compression level
|
||||
CompressLevel(dst, src []byte, level int) ([]byte, error)
|
||||
```
|
||||
|
||||
```go
|
||||
// Decompress will decompress your payload into dst.
|
||||
// If you already have a buffer allocated, you can pass it to prevent allocation
|
||||
// If not, you can pass nil as dst (allocates a 4*src size as default).
|
||||
// If the buffer is too small, it will retry 3 times by doubling the dst size
|
||||
// After max retries, it will switch to the slower stream API to be sure to be able
|
||||
// to decompress. Currently switches if compression ratio > 4*2**3=32.
|
||||
Decompress(dst, src []byte) ([]byte, error)
|
||||
```
|
||||
|
||||
### Stream API
|
||||
|
||||
```go
|
||||
// NewWriter creates a new object that can optionally be initialized with
|
||||
// a precomputed dictionary. If dict is nil, compress without a dictionary.
|
||||
// The dictionary array should not be changed during the use of this object.
|
||||
// You MUST CALL Close() to write the last bytes of a zstd stream and free C objects.
|
||||
NewWriter(w io.Writer) *Writer
|
||||
NewWriterLevel(w io.Writer, level int) *Writer
|
||||
NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer
|
||||
|
||||
// Write compresses the input data and write it to the underlying writer
|
||||
(w *Writer) Write(p []byte) (int, error)
|
||||
|
||||
// Close flushes the buffer and frees C zstd objects
|
||||
(w *Writer) Close() error
|
||||
```
|
||||
|
||||
```go
|
||||
// NewReader returns a new io.ReadCloser that will decompress data from the
|
||||
// underlying reader. If a dictionary is provided to NewReaderDict, it must
|
||||
// not be modified until Close is called. It is the caller's responsibility
|
||||
// to call Close, which frees up C objects.
|
||||
NewReader(r io.Reader) io.ReadCloser
|
||||
NewReaderDict(r io.Reader, dict []byte) io.ReadCloser
|
||||
```
|
||||
|
||||
### Benchmarks (benchmarked with v0.5.0)
|
||||
|
||||
The author of Zstd also wrote lz4. Zstd is intended to occupy a speed/ratio
|
||||
level similar to what zlib currently provides. In our tests, the can always
|
||||
be made to be better than zlib by chosing an appropriate level while still
|
||||
keeping compression and decompression time faster than zlib.
|
||||
|
||||
You can run the benchmarks against your own payloads by using the Go benchmarks tool.
|
||||
Just export your payload filepath as the `PAYLOAD` environment variable and run the benchmarks:
|
||||
|
||||
```go
|
||||
go test -bench .
|
||||
```
|
||||
|
||||
Compression of a 7Mb pdf zstd (this wrapper) vs [czlib](https://github.com/DataDog/czlib):
|
||||
```
|
||||
BenchmarkCompression 5 221056624 ns/op 67.34 MB/s
|
||||
BenchmarkDecompression 100 18370416 ns/op 810.32 MB/s
|
||||
|
||||
BenchmarkFzlibCompress 2 610156603 ns/op 24.40 MB/s
|
||||
BenchmarkFzlibDecompress 20 81195246 ns/op 183.33 MB/s
|
||||
```
|
||||
|
||||
Ratio is also better by a margin of ~20%.
|
||||
Compression speed is always better than zlib on all the payloads we tested;
|
||||
However, [czlib](https://github.com/DataDog/czlib) has optimisations that make it
|
||||
faster at decompressiong small payloads:
|
||||
|
||||
```
|
||||
Testing with size: 11... czlib: 8.97 MB/s, zstd: 3.26 MB/s
|
||||
Testing with size: 27... czlib: 23.3 MB/s, zstd: 8.22 MB/s
|
||||
Testing with size: 62... czlib: 31.6 MB/s, zstd: 19.49 MB/s
|
||||
Testing with size: 141... czlib: 74.54 MB/s, zstd: 42.55 MB/s
|
||||
Testing with size: 323... czlib: 155.14 MB/s, zstd: 99.39 MB/s
|
||||
Testing with size: 739... czlib: 235.9 MB/s, zstd: 216.45 MB/s
|
||||
Testing with size: 1689... czlib: 116.45 MB/s, zstd: 345.64 MB/s
|
||||
Testing with size: 3858... czlib: 176.39 MB/s, zstd: 617.56 MB/s
|
||||
Testing with size: 8811... czlib: 254.11 MB/s, zstd: 824.34 MB/s
|
||||
Testing with size: 20121... czlib: 197.43 MB/s, zstd: 1339.11 MB/s
|
||||
Testing with size: 45951... czlib: 201.62 MB/s, zstd: 1951.57 MB/s
|
||||
```
|
||||
|
||||
zstd starts to shine with payloads > 1KB
|
||||
|
||||
### Stability - Current state: STABLE
|
||||
|
||||
The C library seems to be pretty stable and according to the author has been tested and fuzzed.
|
||||
|
||||
For the Go wrapper, the test cover most usual cases and we have succesfully tested it on all staging and prod data.
|
||||
30
third/github.com/DataDog/zstd/ZSTD_LICENSE
Normal file
30
third/github.com/DataDog/zstd/ZSTD_LICENSE
Normal file
@ -0,0 +1,30 @@
|
||||
BSD License
|
||||
|
||||
For Zstandard software
|
||||
|
||||
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name Facebook nor the names of its contributors may be used to
|
||||
endorse or promote products derived from this software without specific
|
||||
prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
||||
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
||||
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
471
third/github.com/DataDog/zstd/bitstream.h
Normal file
471
third/github.com/DataDog/zstd/bitstream.h
Normal file
@ -0,0 +1,471 @@
|
||||
/* ******************************************************************
|
||||
bitstream
|
||||
Part of FSE library
|
||||
header file (to include)
|
||||
Copyright (C) 2013-2017, Yann Collet.
|
||||
|
||||
BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php)
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
You can contact the author at :
|
||||
- Source repository : https://github.com/Cyan4973/FiniteStateEntropy
|
||||
****************************************************************** */
|
||||
#ifndef BITSTREAM_H_MODULE
|
||||
#define BITSTREAM_H_MODULE
|
||||
|
||||
#if defined (__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*
|
||||
* This API consists of small unitary functions, which must be inlined for best performance.
|
||||
* Since link-time-optimization is not available for all compilers,
|
||||
* these functions are defined into a .h to be included.
|
||||
*/
|
||||
|
||||
/*-****************************************
|
||||
* Dependencies
|
||||
******************************************/
|
||||
#include "mem.h" /* unaligned access routines */
|
||||
#include "error_private.h" /* error codes and messages */
|
||||
|
||||
|
||||
/*-*************************************
|
||||
* Debug
|
||||
***************************************/
|
||||
#if defined(BIT_DEBUG) && (BIT_DEBUG>=1)
|
||||
# include <assert.h>
|
||||
#else
|
||||
# ifndef assert
|
||||
# define assert(condition) ((void)0)
|
||||
# endif
|
||||
#endif
|
||||
|
||||
|
||||
/*=========================================
|
||||
* Target specific
|
||||
=========================================*/
|
||||
#if defined(__BMI__) && defined(__GNUC__)
|
||||
# include <immintrin.h> /* support for bextr (experimental) */
|
||||
#endif
|
||||
|
||||
#define STREAM_ACCUMULATOR_MIN_32 25
|
||||
#define STREAM_ACCUMULATOR_MIN_64 57
|
||||
#define STREAM_ACCUMULATOR_MIN ((U32)(MEM_32bits() ? STREAM_ACCUMULATOR_MIN_32 : STREAM_ACCUMULATOR_MIN_64))
|
||||
|
||||
|
||||
/*-******************************************
|
||||
* bitStream encoding API (write forward)
|
||||
********************************************/
|
||||
/* bitStream can mix input from multiple sources.
|
||||
* A critical property of these streams is that they encode and decode in **reverse** direction.
|
||||
* So the first bit sequence you add will be the last to be read, like a LIFO stack.
|
||||
*/
|
||||
typedef struct
|
||||
{
|
||||
size_t bitContainer;
|
||||
unsigned bitPos;
|
||||
char* startPtr;
|
||||
char* ptr;
|
||||
char* endPtr;
|
||||
} BIT_CStream_t;
|
||||
|
||||
MEM_STATIC size_t BIT_initCStream(BIT_CStream_t* bitC, void* dstBuffer, size_t dstCapacity);
|
||||
MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC, size_t value, unsigned nbBits);
|
||||
MEM_STATIC void BIT_flushBits(BIT_CStream_t* bitC);
|
||||
MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC);
|
||||
|
||||
/* Start with initCStream, providing the size of buffer to write into.
|
||||
* bitStream will never write outside of this buffer.
|
||||
* `dstCapacity` must be >= sizeof(bitD->bitContainer), otherwise @return will be an error code.
|
||||
*
|
||||
* bits are first added to a local register.
|
||||
* Local register is size_t, hence 64-bits on 64-bits systems, or 32-bits on 32-bits systems.
|
||||
* Writing data into memory is an explicit operation, performed by the flushBits function.
|
||||
* Hence keep track how many bits are potentially stored into local register to avoid register overflow.
|
||||
* After a flushBits, a maximum of 7 bits might still be stored into local register.
|
||||
*
|
||||
* Avoid storing elements of more than 24 bits if you want compatibility with 32-bits bitstream readers.
|
||||
*
|
||||
* Last operation is to close the bitStream.
|
||||
* The function returns the final size of CStream in bytes.
|
||||
* If data couldn't fit into `dstBuffer`, it will return a 0 ( == not storable)
|
||||
*/
|
||||
|
||||
|
||||
/*-********************************************
|
||||
* bitStream decoding API (read backward)
|
||||
**********************************************/
|
||||
typedef struct
|
||||
{
|
||||
size_t bitContainer;
|
||||
unsigned bitsConsumed;
|
||||
const char* ptr;
|
||||
const char* start;
|
||||
const char* limitPtr;
|
||||
} BIT_DStream_t;
|
||||
|
||||
typedef enum { BIT_DStream_unfinished = 0,
|
||||
BIT_DStream_endOfBuffer = 1,
|
||||
BIT_DStream_completed = 2,
|
||||
BIT_DStream_overflow = 3 } BIT_DStream_status; /* result of BIT_reloadDStream() */
|
||||
/* 1,2,4,8 would be better for bitmap combinations, but slows down performance a bit ... :( */
|
||||
|
||||
MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, size_t srcSize);
|
||||
MEM_STATIC size_t BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits);
|
||||
MEM_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD);
|
||||
MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* bitD);
|
||||
|
||||
|
||||
/* Start by invoking BIT_initDStream().
|
||||
* A chunk of the bitStream is then stored into a local register.
|
||||
* Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (size_t).
|
||||
* You can then retrieve bitFields stored into the local register, **in reverse order**.
|
||||
* Local register is explicitly reloaded from memory by the BIT_reloadDStream() method.
|
||||
* A reload guarantee a minimum of ((8*sizeof(bitD->bitContainer))-7) bits when its result is BIT_DStream_unfinished.
|
||||
* Otherwise, it can be less than that, so proceed accordingly.
|
||||
* Checking if DStream has reached its end can be performed with BIT_endOfDStream().
|
||||
*/
|
||||
|
||||
|
||||
/*-****************************************
|
||||
* unsafe API
|
||||
******************************************/
|
||||
MEM_STATIC void BIT_addBitsFast(BIT_CStream_t* bitC, size_t value, unsigned nbBits);
|
||||
/* faster, but works only if value is "clean", meaning all high bits above nbBits are 0 */
|
||||
|
||||
MEM_STATIC void BIT_flushBitsFast(BIT_CStream_t* bitC);
|
||||
/* unsafe version; does not check buffer overflow */
|
||||
|
||||
MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits);
|
||||
/* faster, but works only if nbBits >= 1 */
|
||||
|
||||
|
||||
|
||||
/*-**************************************************************
|
||||
* Internal functions
|
||||
****************************************************************/
|
||||
MEM_STATIC unsigned BIT_highbit32 (U32 val)
|
||||
{
|
||||
assert(val != 0);
|
||||
{
|
||||
# if defined(_MSC_VER) /* Visual */
|
||||
unsigned long r=0;
|
||||
_BitScanReverse ( &r, val );
|
||||
return (unsigned) r;
|
||||
# elif defined(__GNUC__) && (__GNUC__ >= 3) /* Use GCC Intrinsic */
|
||||
return 31 - __builtin_clz (val);
|
||||
# else /* Software version */
|
||||
static const unsigned DeBruijnClz[32] = { 0, 9, 1, 10, 13, 21, 2, 29,
|
||||
11, 14, 16, 18, 22, 25, 3, 30,
|
||||
8, 12, 20, 28, 15, 17, 24, 7,
|
||||
19, 27, 23, 6, 26, 5, 4, 31 };
|
||||
U32 v = val;
|
||||
v |= v >> 1;
|
||||
v |= v >> 2;
|
||||
v |= v >> 4;
|
||||
v |= v >> 8;
|
||||
v |= v >> 16;
|
||||
return DeBruijnClz[ (U32) (v * 0x07C4ACDDU) >> 27];
|
||||
# endif
|
||||
}
|
||||
}
|
||||
|
||||
/*===== Local Constants =====*/
|
||||
static const unsigned BIT_mask[] = {
|
||||
0, 1, 3, 7, 0xF, 0x1F,
|
||||
0x3F, 0x7F, 0xFF, 0x1FF, 0x3FF, 0x7FF,
|
||||
0xFFF, 0x1FFF, 0x3FFF, 0x7FFF, 0xFFFF, 0x1FFFF,
|
||||
0x3FFFF, 0x7FFFF, 0xFFFFF, 0x1FFFFF, 0x3FFFFF, 0x7FFFFF,
|
||||
0xFFFFFF, 0x1FFFFFF, 0x3FFFFFF, 0x7FFFFFF, 0xFFFFFFF, 0x1FFFFFFF,
|
||||
0x3FFFFFFF, 0x7FFFFFFF}; /* up to 31 bits */
|
||||
#define BIT_MASK_SIZE (sizeof(BIT_mask) / sizeof(BIT_mask[0]))
|
||||
|
||||
/*-**************************************************************
|
||||
* bitStream encoding
|
||||
****************************************************************/
|
||||
/*! BIT_initCStream() :
|
||||
* `dstCapacity` must be > sizeof(size_t)
|
||||
* @return : 0 if success,
|
||||
* otherwise an error code (can be tested using ERR_isError()) */
|
||||
MEM_STATIC size_t BIT_initCStream(BIT_CStream_t* bitC,
|
||||
void* startPtr, size_t dstCapacity)
|
||||
{
|
||||
bitC->bitContainer = 0;
|
||||
bitC->bitPos = 0;
|
||||
bitC->startPtr = (char*)startPtr;
|
||||
bitC->ptr = bitC->startPtr;
|
||||
bitC->endPtr = bitC->startPtr + dstCapacity - sizeof(bitC->bitContainer);
|
||||
if (dstCapacity <= sizeof(bitC->bitContainer)) return ERROR(dstSize_tooSmall);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*! BIT_addBits() :
|
||||
* can add up to 31 bits into `bitC`.
|
||||
* Note : does not check for register overflow ! */
|
||||
MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC,
|
||||
size_t value, unsigned nbBits)
|
||||
{
|
||||
MEM_STATIC_ASSERT(BIT_MASK_SIZE == 32);
|
||||
assert(nbBits < BIT_MASK_SIZE);
|
||||
assert(nbBits + bitC->bitPos < sizeof(bitC->bitContainer) * 8);
|
||||
bitC->bitContainer |= (value & BIT_mask[nbBits]) << bitC->bitPos;
|
||||
bitC->bitPos += nbBits;
|
||||
}
|
||||
|
||||
/*! BIT_addBitsFast() :
|
||||
* works only if `value` is _clean_, meaning all high bits above nbBits are 0 */
|
||||
MEM_STATIC void BIT_addBitsFast(BIT_CStream_t* bitC,
|
||||
size_t value, unsigned nbBits)
|
||||
{
|
||||
assert((value>>nbBits) == 0);
|
||||
assert(nbBits + bitC->bitPos < sizeof(bitC->bitContainer) * 8);
|
||||
bitC->bitContainer |= value << bitC->bitPos;
|
||||
bitC->bitPos += nbBits;
|
||||
}
|
||||
|
||||
/*! BIT_flushBitsFast() :
|
||||
* assumption : bitContainer has not overflowed
|
||||
* unsafe version; does not check buffer overflow */
|
||||
MEM_STATIC void BIT_flushBitsFast(BIT_CStream_t* bitC)
|
||||
{
|
||||
size_t const nbBytes = bitC->bitPos >> 3;
|
||||
assert(bitC->bitPos < sizeof(bitC->bitContainer) * 8);
|
||||
MEM_writeLEST(bitC->ptr, bitC->bitContainer);
|
||||
bitC->ptr += nbBytes;
|
||||
assert(bitC->ptr <= bitC->endPtr);
|
||||
bitC->bitPos &= 7;
|
||||
bitC->bitContainer >>= nbBytes*8;
|
||||
}
|
||||
|
||||
/*! BIT_flushBits() :
|
||||
* assumption : bitContainer has not overflowed
|
||||
* safe version; check for buffer overflow, and prevents it.
|
||||
* note : does not signal buffer overflow.
|
||||
* overflow will be revealed later on using BIT_closeCStream() */
|
||||
MEM_STATIC void BIT_flushBits(BIT_CStream_t* bitC)
|
||||
{
|
||||
size_t const nbBytes = bitC->bitPos >> 3;
|
||||
assert(bitC->bitPos < sizeof(bitC->bitContainer) * 8);
|
||||
MEM_writeLEST(bitC->ptr, bitC->bitContainer);
|
||||
bitC->ptr += nbBytes;
|
||||
if (bitC->ptr > bitC->endPtr) bitC->ptr = bitC->endPtr;
|
||||
bitC->bitPos &= 7;
|
||||
bitC->bitContainer >>= nbBytes*8;
|
||||
}
|
||||
|
||||
/*! BIT_closeCStream() :
|
||||
* @return : size of CStream, in bytes,
|
||||
* or 0 if it could not fit into dstBuffer */
|
||||
MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC)
|
||||
{
|
||||
BIT_addBitsFast(bitC, 1, 1); /* endMark */
|
||||
BIT_flushBits(bitC);
|
||||
if (bitC->ptr >= bitC->endPtr) return 0; /* overflow detected */
|
||||
return (bitC->ptr - bitC->startPtr) + (bitC->bitPos > 0);
|
||||
}
|
||||
|
||||
|
||||
/*-********************************************************
|
||||
* bitStream decoding
|
||||
**********************************************************/
|
||||
/*! BIT_initDStream() :
|
||||
* Initialize a BIT_DStream_t.
|
||||
* `bitD` : a pointer to an already allocated BIT_DStream_t structure.
|
||||
* `srcSize` must be the *exact* size of the bitStream, in bytes.
|
||||
* @return : size of stream (== srcSize), or an errorCode if a problem is detected
|
||||
*/
|
||||
MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, size_t srcSize)
|
||||
{
|
||||
if (srcSize < 1) { memset(bitD, 0, sizeof(*bitD)); return ERROR(srcSize_wrong); }
|
||||
|
||||
bitD->start = (const char*)srcBuffer;
|
||||
bitD->limitPtr = bitD->start + sizeof(bitD->bitContainer);
|
||||
|
||||
if (srcSize >= sizeof(bitD->bitContainer)) { /* normal case */
|
||||
bitD->ptr = (const char*)srcBuffer + srcSize - sizeof(bitD->bitContainer);
|
||||
bitD->bitContainer = MEM_readLEST(bitD->ptr);
|
||||
{ BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1];
|
||||
bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; /* ensures bitsConsumed is always set */
|
||||
if (lastByte == 0) return ERROR(GENERIC); /* endMark not present */ }
|
||||
} else {
|
||||
bitD->ptr = bitD->start;
|
||||
bitD->bitContainer = *(const BYTE*)(bitD->start);
|
||||
switch(srcSize)
|
||||
{
|
||||
case 7: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[6]) << (sizeof(bitD->bitContainer)*8 - 16);
|
||||
/* fall-through */
|
||||
|
||||
case 6: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[5]) << (sizeof(bitD->bitContainer)*8 - 24);
|
||||
/* fall-through */
|
||||
|
||||
case 5: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[4]) << (sizeof(bitD->bitContainer)*8 - 32);
|
||||
/* fall-through */
|
||||
|
||||
case 4: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[3]) << 24;
|
||||
/* fall-through */
|
||||
|
||||
case 3: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[2]) << 16;
|
||||
/* fall-through */
|
||||
|
||||
case 2: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[1]) << 8;
|
||||
/* fall-through */
|
||||
|
||||
default: break;
|
||||
}
|
||||
{ BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1];
|
||||
bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0;
|
||||
if (lastByte == 0) return ERROR(corruption_detected); /* endMark not present */
|
||||
}
|
||||
bitD->bitsConsumed += (U32)(sizeof(bitD->bitContainer) - srcSize)*8;
|
||||
}
|
||||
|
||||
return srcSize;
|
||||
}
|
||||
|
||||
MEM_STATIC size_t BIT_getUpperBits(size_t bitContainer, U32 const start)
|
||||
{
|
||||
return bitContainer >> start;
|
||||
}
|
||||
|
||||
MEM_STATIC size_t BIT_getMiddleBits(size_t bitContainer, U32 const start, U32 const nbBits)
|
||||
{
|
||||
#if defined(__BMI__) && defined(__GNUC__) && __GNUC__*1000+__GNUC_MINOR__ >= 4008 /* experimental */
|
||||
# if defined(__x86_64__)
|
||||
if (sizeof(bitContainer)==8)
|
||||
return _bextr_u64(bitContainer, start, nbBits);
|
||||
else
|
||||
# endif
|
||||
return _bextr_u32(bitContainer, start, nbBits);
|
||||
#else
|
||||
assert(nbBits < BIT_MASK_SIZE);
|
||||
return (bitContainer >> start) & BIT_mask[nbBits];
|
||||
#endif
|
||||
}
|
||||
|
||||
MEM_STATIC size_t BIT_getLowerBits(size_t bitContainer, U32 const nbBits)
|
||||
{
|
||||
assert(nbBits < BIT_MASK_SIZE);
|
||||
return bitContainer & BIT_mask[nbBits];
|
||||
}
|
||||
|
||||
/*! BIT_lookBits() :
|
||||
* Provides next n bits from local register.
|
||||
* local register is not modified.
|
||||
* On 32-bits, maxNbBits==24.
|
||||
* On 64-bits, maxNbBits==56.
|
||||
* @return : value extracted */
|
||||
MEM_STATIC size_t BIT_lookBits(const BIT_DStream_t* bitD, U32 nbBits)
|
||||
{
|
||||
#if defined(__BMI__) && defined(__GNUC__) /* experimental; fails if bitD->bitsConsumed + nbBits > sizeof(bitD->bitContainer)*8 */
|
||||
return BIT_getMiddleBits(bitD->bitContainer, (sizeof(bitD->bitContainer)*8) - bitD->bitsConsumed - nbBits, nbBits);
|
||||
#else
|
||||
U32 const regMask = sizeof(bitD->bitContainer)*8 - 1;
|
||||
return ((bitD->bitContainer << (bitD->bitsConsumed & regMask)) >> 1) >> ((regMask-nbBits) & regMask);
|
||||
#endif
|
||||
}
|
||||
|
||||
/*! BIT_lookBitsFast() :
|
||||
* unsafe version; only works if nbBits >= 1 */
|
||||
MEM_STATIC size_t BIT_lookBitsFast(const BIT_DStream_t* bitD, U32 nbBits)
|
||||
{
|
||||
U32 const regMask = sizeof(bitD->bitContainer)*8 - 1;
|
||||
assert(nbBits >= 1);
|
||||
return (bitD->bitContainer << (bitD->bitsConsumed & regMask)) >> (((regMask+1)-nbBits) & regMask);
|
||||
}
|
||||
|
||||
MEM_STATIC void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits)
|
||||
{
|
||||
bitD->bitsConsumed += nbBits;
|
||||
}
|
||||
|
||||
/*! BIT_readBits() :
|
||||
* Read (consume) next n bits from local register and update.
|
||||
* Pay attention to not read more than nbBits contained into local register.
|
||||
* @return : extracted value. */
|
||||
MEM_STATIC size_t BIT_readBits(BIT_DStream_t* bitD, U32 nbBits)
|
||||
{
|
||||
size_t const value = BIT_lookBits(bitD, nbBits);
|
||||
BIT_skipBits(bitD, nbBits);
|
||||
return value;
|
||||
}
|
||||
|
||||
/*! BIT_readBitsFast() :
|
||||
* unsafe version; only works only if nbBits >= 1 */
|
||||
MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, U32 nbBits)
|
||||
{
|
||||
size_t const value = BIT_lookBitsFast(bitD, nbBits);
|
||||
assert(nbBits >= 1);
|
||||
BIT_skipBits(bitD, nbBits);
|
||||
return value;
|
||||
}
|
||||
|
||||
/*! BIT_reloadDStream() :
|
||||
* Refill `bitD` from buffer previously set in BIT_initDStream() .
|
||||
* This function is safe, it guarantees it will not read beyond src buffer.
|
||||
* @return : status of `BIT_DStream_t` internal register.
|
||||
* when status == BIT_DStream_unfinished, internal register is filled with at least 25 or 57 bits */
|
||||
MEM_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD)
|
||||
{
|
||||
if (bitD->bitsConsumed > (sizeof(bitD->bitContainer)*8)) /* overflow detected, like end of stream */
|
||||
return BIT_DStream_overflow;
|
||||
|
||||
if (bitD->ptr >= bitD->limitPtr) {
|
||||
bitD->ptr -= bitD->bitsConsumed >> 3;
|
||||
bitD->bitsConsumed &= 7;
|
||||
bitD->bitContainer = MEM_readLEST(bitD->ptr);
|
||||
return BIT_DStream_unfinished;
|
||||
}
|
||||
if (bitD->ptr == bitD->start) {
|
||||
if (bitD->bitsConsumed < sizeof(bitD->bitContainer)*8) return BIT_DStream_endOfBuffer;
|
||||
return BIT_DStream_completed;
|
||||
}
|
||||
/* start < ptr < limitPtr */
|
||||
{ U32 nbBytes = bitD->bitsConsumed >> 3;
|
||||
BIT_DStream_status result = BIT_DStream_unfinished;
|
||||
if (bitD->ptr - nbBytes < bitD->start) {
|
||||
nbBytes = (U32)(bitD->ptr - bitD->start); /* ptr > start */
|
||||
result = BIT_DStream_endOfBuffer;
|
||||
}
|
||||
bitD->ptr -= nbBytes;
|
||||
bitD->bitsConsumed -= nbBytes*8;
|
||||
bitD->bitContainer = MEM_readLEST(bitD->ptr); /* reminder : srcSize > sizeof(bitD->bitContainer), otherwise bitD->ptr == bitD->start */
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/*! BIT_endOfDStream() :
|
||||
* @return : 1 if DStream has _exactly_ reached its end (all bits consumed).
|
||||
*/
|
||||
MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* DStream)
|
||||
{
|
||||
return ((DStream->ptr == DStream->start) && (DStream->bitsConsumed == sizeof(DStream->bitContainer)*8));
|
||||
}
|
||||
|
||||
#if defined (__cplusplus)
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif /* BITSTREAM_H_MODULE */
|
||||
111
third/github.com/DataDog/zstd/compiler.h
Normal file
111
third/github.com/DataDog/zstd/compiler.h
Normal file
@ -0,0 +1,111 @@
|
||||
/*
|
||||
* Copyright (c) 2016-present, Yann Collet, Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under both the BSD-style license (found in the
|
||||
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
|
||||
* in the COPYING file in the root directory of this source tree).
|
||||
* You may select, at your option, one of the above-listed licenses.
|
||||
*/
|
||||
|
||||
#ifndef ZSTD_COMPILER_H
|
||||
#define ZSTD_COMPILER_H
|
||||
|
||||
/*-*******************************************************
|
||||
* Compiler specifics
|
||||
*********************************************************/
|
||||
/* force inlining */
|
||||
#if defined (__GNUC__) || defined(__cplusplus) || defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* C99 */
|
||||
# define INLINE_KEYWORD inline
|
||||
#else
|
||||
# define INLINE_KEYWORD
|
||||
#endif
|
||||
|
||||
#if defined(__GNUC__)
|
||||
# define FORCE_INLINE_ATTR __attribute__((always_inline))
|
||||
#elif defined(_MSC_VER)
|
||||
# define FORCE_INLINE_ATTR __forceinline
|
||||
#else
|
||||
# define FORCE_INLINE_ATTR
|
||||
#endif
|
||||
|
||||
/**
|
||||
* FORCE_INLINE_TEMPLATE is used to define C "templates", which take constant
|
||||
* parameters. They must be inlined for the compiler to elimininate the constant
|
||||
* branches.
|
||||
*/
|
||||
#define FORCE_INLINE_TEMPLATE static INLINE_KEYWORD FORCE_INLINE_ATTR
|
||||
/**
|
||||
* HINT_INLINE is used to help the compiler generate better code. It is *not*
|
||||
* used for "templates", so it can be tweaked based on the compilers
|
||||
* performance.
|
||||
*
|
||||
* gcc-4.8 and gcc-4.9 have been shown to benefit from leaving off the
|
||||
* always_inline attribute.
|
||||
*
|
||||
* clang up to 5.0.0 (trunk) benefit tremendously from the always_inline
|
||||
* attribute.
|
||||
*/
|
||||
#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 4 && __GNUC_MINOR__ >= 8 && __GNUC__ < 5
|
||||
# define HINT_INLINE static INLINE_KEYWORD
|
||||
#else
|
||||
# define HINT_INLINE static INLINE_KEYWORD FORCE_INLINE_ATTR
|
||||
#endif
|
||||
|
||||
/* force no inlining */
|
||||
#ifdef _MSC_VER
|
||||
# define FORCE_NOINLINE static __declspec(noinline)
|
||||
#else
|
||||
# ifdef __GNUC__
|
||||
# define FORCE_NOINLINE static __attribute__((__noinline__))
|
||||
# else
|
||||
# define FORCE_NOINLINE static
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* target attribute */
|
||||
#ifndef __has_attribute
|
||||
#define __has_attribute(x) 0 /* Compatibility with non-clang compilers. */
|
||||
#endif
|
||||
#if defined(__GNUC__)
|
||||
# define TARGET_ATTRIBUTE(target) __attribute__((__target__(target)))
|
||||
#else
|
||||
# define TARGET_ATTRIBUTE(target)
|
||||
#endif
|
||||
|
||||
/* Enable runtime BMI2 dispatch based on the CPU.
|
||||
* Enabled for clang & gcc >=4.8 on x86 when BMI2 isn't enabled by default.
|
||||
*/
|
||||
#ifndef DYNAMIC_BMI2
|
||||
#if (defined(__clang__) && __has_attribute(__target__)) \
|
||||
|| (defined(__GNUC__) \
|
||||
&& (__GNUC__ >= 5 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) \
|
||||
&& (defined(__x86_64__) || defined(_M_X86)) \
|
||||
&& !defined(__BMI2__)
|
||||
# define DYNAMIC_BMI2 1
|
||||
#else
|
||||
# define DYNAMIC_BMI2 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/* prefetch */
|
||||
#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_I86)) /* _mm_prefetch() is not defined outside of x86/x64 */
|
||||
# include <mmintrin.h> /* https://msdn.microsoft.com/fr-fr/library/84szxsww(v=vs.90).aspx */
|
||||
# define PREFETCH(ptr) _mm_prefetch((const char*)ptr, _MM_HINT_T0)
|
||||
#elif defined(__GNUC__)
|
||||
# define PREFETCH(ptr) __builtin_prefetch(ptr, 0, 0)
|
||||
#else
|
||||
# define PREFETCH(ptr) /* disabled */
|
||||
#endif
|
||||
|
||||
/* disable warnings */
|
||||
#ifdef _MSC_VER /* Visual Studio */
|
||||
# include <intrin.h> /* For Visual 2005 */
|
||||
# pragma warning(disable : 4100) /* disable: C4100: unreferenced formal parameter */
|
||||
# pragma warning(disable : 4127) /* disable: C4127: conditional expression is constant */
|
||||
# pragma warning(disable : 4204) /* disable: C4204: non-constant aggregate initializer */
|
||||
# pragma warning(disable : 4214) /* disable: C4214: non-int bitfields */
|
||||
# pragma warning(disable : 4324) /* disable: C4324: padded structure */
|
||||
#endif
|
||||
|
||||
#endif /* ZSTD_COMPILER_H */
|
||||
1048
third/github.com/DataDog/zstd/cover.c
Normal file
1048
third/github.com/DataDog/zstd/cover.c
Normal file
File diff suppressed because it is too large
Load Diff
216
third/github.com/DataDog/zstd/cpu.h
Executable file
216
third/github.com/DataDog/zstd/cpu.h
Executable file
@ -0,0 +1,216 @@
|
||||
/*
|
||||
* Copyright (c) 2018-present, Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under both the BSD-style license (found in the
|
||||
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
|
||||
* in the COPYING file in the root directory of this source tree).
|
||||
* You may select, at your option, one of the above-listed licenses.
|
||||
*/
|
||||
|
||||
#ifndef ZSTD_COMMON_CPU_H
|
||||
#define ZSTD_COMMON_CPU_H
|
||||
|
||||
/**
|
||||
* Implementation taken from folly/CpuId.h
|
||||
* https://github.com/facebook/folly/blob/master/folly/CpuId.h
|
||||
*/
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "mem.h"
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <intrin.h>
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
U32 f1c;
|
||||
U32 f1d;
|
||||
U32 f7b;
|
||||
U32 f7c;
|
||||
} ZSTD_cpuid_t;
|
||||
|
||||
MEM_STATIC ZSTD_cpuid_t ZSTD_cpuid(void) {
|
||||
U32 f1c = 0;
|
||||
U32 f1d = 0;
|
||||
U32 f7b = 0;
|
||||
U32 f7c = 0;
|
||||
#ifdef _MSC_VER
|
||||
int reg[4];
|
||||
__cpuid((int*)reg, 0);
|
||||
{
|
||||
int const n = reg[0];
|
||||
if (n >= 1) {
|
||||
__cpuid((int*)reg, 1);
|
||||
f1c = (U32)reg[2];
|
||||
f1d = (U32)reg[3];
|
||||
}
|
||||
if (n >= 7) {
|
||||
__cpuidex((int*)reg, 7, 0);
|
||||
f7b = (U32)reg[1];
|
||||
f7c = (U32)reg[2];
|
||||
}
|
||||
}
|
||||
#elif defined(__i386__) && defined(__PIC__) && !defined(__clang__) && defined(__GNUC__)
|
||||
/* The following block like the normal cpuid branch below, but gcc
|
||||
* reserves ebx for use of its pic register so we must specially
|
||||
* handle the save and restore to avoid clobbering the register
|
||||
*/
|
||||
U32 n;
|
||||
__asm__(
|
||||
"pushl %%ebx\n\t"
|
||||
"cpuid\n\t"
|
||||
"popl %%ebx\n\t"
|
||||
: "=a"(n)
|
||||
: "a"(0)
|
||||
: "ecx", "edx");
|
||||
if (n >= 1) {
|
||||
U32 f1a;
|
||||
__asm__(
|
||||
"pushl %%ebx\n\t"
|
||||
"cpuid\n\t"
|
||||
"popl %%ebx\n\t"
|
||||
: "=a"(f1a), "=c"(f1c), "=d"(f1d)
|
||||
: "a"(1)
|
||||
:);
|
||||
}
|
||||
if (n >= 7) {
|
||||
__asm__(
|
||||
"pushl %%ebx\n\t"
|
||||
"cpuid\n\t"
|
||||
"movl %%ebx, %%eax\n\r"
|
||||
"popl %%ebx"
|
||||
: "=a"(f7b), "=c"(f7c)
|
||||
: "a"(7), "c"(0)
|
||||
: "edx");
|
||||
}
|
||||
#elif defined(__x86_64__) || defined(_M_X64) || defined(__i386__)
|
||||
U32 n;
|
||||
__asm__("cpuid" : "=a"(n) : "a"(0) : "ebx", "ecx", "edx");
|
||||
if (n >= 1) {
|
||||
U32 f1a;
|
||||
__asm__("cpuid" : "=a"(f1a), "=c"(f1c), "=d"(f1d) : "a"(1) : "ebx");
|
||||
}
|
||||
if (n >= 7) {
|
||||
U32 f7a;
|
||||
__asm__("cpuid"
|
||||
: "=a"(f7a), "=b"(f7b), "=c"(f7c)
|
||||
: "a"(7), "c"(0)
|
||||
: "edx");
|
||||
}
|
||||
#endif
|
||||
{
|
||||
ZSTD_cpuid_t cpuid;
|
||||
cpuid.f1c = f1c;
|
||||
cpuid.f1d = f1d;
|
||||
cpuid.f7b = f7b;
|
||||
cpuid.f7c = f7c;
|
||||
return cpuid;
|
||||
}
|
||||
}
|
||||
|
||||
#define X(name, r, bit) \
|
||||
MEM_STATIC int ZSTD_cpuid_##name(ZSTD_cpuid_t const cpuid) { \
|
||||
return ((cpuid.r) & (1U << bit)) != 0; \
|
||||
}
|
||||
|
||||
/* cpuid(1): Processor Info and Feature Bits. */
|
||||
#define C(name, bit) X(name, f1c, bit)
|
||||
C(sse3, 0)
|
||||
C(pclmuldq, 1)
|
||||
C(dtes64, 2)
|
||||
C(monitor, 3)
|
||||
C(dscpl, 4)
|
||||
C(vmx, 5)
|
||||
C(smx, 6)
|
||||
C(eist, 7)
|
||||
C(tm2, 8)
|
||||
C(ssse3, 9)
|
||||
C(cnxtid, 10)
|
||||
C(fma, 12)
|
||||
C(cx16, 13)
|
||||
C(xtpr, 14)
|
||||
C(pdcm, 15)
|
||||
C(pcid, 17)
|
||||
C(dca, 18)
|
||||
C(sse41, 19)
|
||||
C(sse42, 20)
|
||||
C(x2apic, 21)
|
||||
C(movbe, 22)
|
||||
C(popcnt, 23)
|
||||
C(tscdeadline, 24)
|
||||
C(aes, 25)
|
||||
C(xsave, 26)
|
||||
C(osxsave, 27)
|
||||
C(avx, 28)
|
||||
C(f16c, 29)
|
||||
C(rdrand, 30)
|
||||
#undef C
|
||||
#define D(name, bit) X(name, f1d, bit)
|
||||
D(fpu, 0)
|
||||
D(vme, 1)
|
||||
D(de, 2)
|
||||
D(pse, 3)
|
||||
D(tsc, 4)
|
||||
D(msr, 5)
|
||||
D(pae, 6)
|
||||
D(mce, 7)
|
||||
D(cx8, 8)
|
||||
D(apic, 9)
|
||||
D(sep, 11)
|
||||
D(mtrr, 12)
|
||||
D(pge, 13)
|
||||
D(mca, 14)
|
||||
D(cmov, 15)
|
||||
D(pat, 16)
|
||||
D(pse36, 17)
|
||||
D(psn, 18)
|
||||
D(clfsh, 19)
|
||||
D(ds, 21)
|
||||
D(acpi, 22)
|
||||
D(mmx, 23)
|
||||
D(fxsr, 24)
|
||||
D(sse, 25)
|
||||
D(sse2, 26)
|
||||
D(ss, 27)
|
||||
D(htt, 28)
|
||||
D(tm, 29)
|
||||
D(pbe, 31)
|
||||
#undef D
|
||||
|
||||
/* cpuid(7): Extended Features. */
|
||||
#define B(name, bit) X(name, f7b, bit)
|
||||
B(bmi1, 3)
|
||||
B(hle, 4)
|
||||
B(avx2, 5)
|
||||
B(smep, 7)
|
||||
B(bmi2, 8)
|
||||
B(erms, 9)
|
||||
B(invpcid, 10)
|
||||
B(rtm, 11)
|
||||
B(mpx, 14)
|
||||
B(avx512f, 16)
|
||||
B(avx512dq, 17)
|
||||
B(rdseed, 18)
|
||||
B(adx, 19)
|
||||
B(smap, 20)
|
||||
B(avx512ifma, 21)
|
||||
B(pcommit, 22)
|
||||
B(clflushopt, 23)
|
||||
B(clwb, 24)
|
||||
B(avx512pf, 26)
|
||||
B(avx512er, 27)
|
||||
B(avx512cd, 28)
|
||||
B(sha, 29)
|
||||
B(avx512bw, 30)
|
||||
B(avx512vl, 31)
|
||||
#undef B
|
||||
#define C(name, bit) X(name, f7c, bit)
|
||||
C(prefetchwt1, 0)
|
||||
C(avx512vbmi, 1)
|
||||
#undef C
|
||||
|
||||
#undef X
|
||||
|
||||
#endif /* ZSTD_COMMON_CPU_H */
|
||||
1913
third/github.com/DataDog/zstd/divsufsort.c
Normal file
1913
third/github.com/DataDog/zstd/divsufsort.c
Normal file
File diff suppressed because it is too large
Load Diff
67
third/github.com/DataDog/zstd/divsufsort.h
Normal file
67
third/github.com/DataDog/zstd/divsufsort.h
Normal file
@ -0,0 +1,67 @@
|
||||
/*
|
||||
* divsufsort.h for libdivsufsort-lite
|
||||
* Copyright (c) 2003-2008 Yuta Mori All Rights Reserved.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person
|
||||
* obtaining a copy of this software and associated documentation
|
||||
* files (the "Software"), to deal in the Software without
|
||||
* restriction, including without limitation the rights to use,
|
||||
* copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the
|
||||
* Software is furnished to do so, subject to the following
|
||||
* conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
|
||||
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
||||
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
||||
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
* OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef _DIVSUFSORT_H
|
||||
#define _DIVSUFSORT_H 1
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif /* __cplusplus */
|
||||
|
||||
|
||||
/*- Prototypes -*/
|
||||
|
||||
/**
|
||||
* Constructs the suffix array of a given string.
|
||||
* @param T [0..n-1] The input string.
|
||||
* @param SA [0..n-1] The output array of suffixes.
|
||||
* @param n The length of the given string.
|
||||
* @param openMP enables OpenMP optimization.
|
||||
* @return 0 if no error occurred, -1 or -2 otherwise.
|
||||
*/
|
||||
int
|
||||
divsufsort(const unsigned char *T, int *SA, int n, int openMP);
|
||||
|
||||
/**
|
||||
* Constructs the burrows-wheeler transformed string of a given string.
|
||||
* @param T [0..n-1] The input string.
|
||||
* @param U [0..n-1] The output string. (can be T)
|
||||
* @param A [0..n-1] The temporary array. (can be NULL)
|
||||
* @param n The length of the given string.
|
||||
* @param num_indexes The length of secondary indexes array. (can be NULL)
|
||||
* @param indexes The secondary indexes array. (can be NULL)
|
||||
* @param openMP enables OpenMP optimization.
|
||||
* @return The primary index if no error occurred, -1 or -2 otherwise.
|
||||
*/
|
||||
int
|
||||
divbwt(const unsigned char *T, unsigned char *U, int *A, int n, unsigned char * num_indexes, int * indexes, int openMP);
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* extern "C" */
|
||||
#endif /* __cplusplus */
|
||||
|
||||
#endif /* _DIVSUFSORT_H */
|
||||
221
third/github.com/DataDog/zstd/entropy_common.c
Normal file
221
third/github.com/DataDog/zstd/entropy_common.c
Normal file
@ -0,0 +1,221 @@
|
||||
/*
|
||||
Common functions of New Generation Entropy library
|
||||
Copyright (C) 2016, Yann Collet.
|
||||
|
||||
BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php)
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
You can contact the author at :
|
||||
- FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy
|
||||
- Public forum : https://groups.google.com/forum/#!forum/lz4c
|
||||
*************************************************************************** */
|
||||
|
||||
/* *************************************
|
||||
* Dependencies
|
||||
***************************************/
|
||||
#include "mem.h"
|
||||
#include "error_private.h" /* ERR_*, ERROR */
|
||||
#define FSE_STATIC_LINKING_ONLY /* FSE_MIN_TABLELOG */
|
||||
#include "fse.h"
|
||||
#define HUF_STATIC_LINKING_ONLY /* HUF_TABLELOG_ABSOLUTEMAX */
|
||||
#include "huf.h"
|
||||
|
||||
|
||||
/*=== Version ===*/
|
||||
unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; }
|
||||
|
||||
|
||||
/*=== Error Management ===*/
|
||||
unsigned FSE_isError(size_t code) { return ERR_isError(code); }
|
||||
const char* FSE_getErrorName(size_t code) { return ERR_getErrorName(code); }
|
||||
|
||||
unsigned HUF_isError(size_t code) { return ERR_isError(code); }
|
||||
const char* HUF_getErrorName(size_t code) { return ERR_getErrorName(code); }
|
||||
|
||||
|
||||
/*-**************************************************************
|
||||
* FSE NCount encoding-decoding
|
||||
****************************************************************/
|
||||
size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
|
||||
const void* headerBuffer, size_t hbSize)
|
||||
{
|
||||
const BYTE* const istart = (const BYTE*) headerBuffer;
|
||||
const BYTE* const iend = istart + hbSize;
|
||||
const BYTE* ip = istart;
|
||||
int nbBits;
|
||||
int remaining;
|
||||
int threshold;
|
||||
U32 bitStream;
|
||||
int bitCount;
|
||||
unsigned charnum = 0;
|
||||
int previous0 = 0;
|
||||
|
||||
if (hbSize < 4) return ERROR(srcSize_wrong);
|
||||
bitStream = MEM_readLE32(ip);
|
||||
nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG; /* extract tableLog */
|
||||
if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge);
|
||||
bitStream >>= 4;
|
||||
bitCount = 4;
|
||||
*tableLogPtr = nbBits;
|
||||
remaining = (1<<nbBits)+1;
|
||||
threshold = 1<<nbBits;
|
||||
nbBits++;
|
||||
|
||||
while ((remaining>1) & (charnum<=*maxSVPtr)) {
|
||||
if (previous0) {
|
||||
unsigned n0 = charnum;
|
||||
while ((bitStream & 0xFFFF) == 0xFFFF) {
|
||||
n0 += 24;
|
||||
if (ip < iend-5) {
|
||||
ip += 2;
|
||||
bitStream = MEM_readLE32(ip) >> bitCount;
|
||||
} else {
|
||||
bitStream >>= 16;
|
||||
bitCount += 16;
|
||||
} }
|
||||
while ((bitStream & 3) == 3) {
|
||||
n0 += 3;
|
||||
bitStream >>= 2;
|
||||
bitCount += 2;
|
||||
}
|
||||
n0 += bitStream & 3;
|
||||
bitCount += 2;
|
||||
if (n0 > *maxSVPtr) return ERROR(maxSymbolValue_tooSmall);
|
||||
while (charnum < n0) normalizedCounter[charnum++] = 0;
|
||||
if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
|
||||
ip += bitCount>>3;
|
||||
bitCount &= 7;
|
||||
bitStream = MEM_readLE32(ip) >> bitCount;
|
||||
} else {
|
||||
bitStream >>= 2;
|
||||
} }
|
||||
{ int const max = (2*threshold-1) - remaining;
|
||||
int count;
|
||||
|
||||
if ((bitStream & (threshold-1)) < (U32)max) {
|
||||
count = bitStream & (threshold-1);
|
||||
bitCount += nbBits-1;
|
||||
} else {
|
||||
count = bitStream & (2*threshold-1);
|
||||
if (count >= threshold) count -= max;
|
||||
bitCount += nbBits;
|
||||
}
|
||||
|
||||
count--; /* extra accuracy */
|
||||
remaining -= count < 0 ? -count : count; /* -1 means +1 */
|
||||
normalizedCounter[charnum++] = (short)count;
|
||||
previous0 = !count;
|
||||
while (remaining < threshold) {
|
||||
nbBits--;
|
||||
threshold >>= 1;
|
||||
}
|
||||
|
||||
if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
|
||||
ip += bitCount>>3;
|
||||
bitCount &= 7;
|
||||
} else {
|
||||
bitCount -= (int)(8 * (iend - 4 - ip));
|
||||
ip = iend - 4;
|
||||
}
|
||||
bitStream = MEM_readLE32(ip) >> (bitCount & 31);
|
||||
} } /* while ((remaining>1) & (charnum<=*maxSVPtr)) */
|
||||
if (remaining != 1) return ERROR(corruption_detected);
|
||||
if (bitCount > 32) return ERROR(corruption_detected);
|
||||
*maxSVPtr = charnum-1;
|
||||
|
||||
ip += (bitCount+7)>>3;
|
||||
return ip-istart;
|
||||
}
|
||||
|
||||
|
||||
/*! HUF_readStats() :
|
||||
Read compact Huffman tree, saved by HUF_writeCTable().
|
||||
`huffWeight` is destination buffer.
|
||||
`rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32.
|
||||
@return : size read from `src` , or an error Code .
|
||||
Note : Needed by HUF_readCTable() and HUF_readDTableX?() .
|
||||
*/
|
||||
size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats,
|
||||
U32* nbSymbolsPtr, U32* tableLogPtr,
|
||||
const void* src, size_t srcSize)
|
||||
{
|
||||
U32 weightTotal;
|
||||
const BYTE* ip = (const BYTE*) src;
|
||||
size_t iSize;
|
||||
size_t oSize;
|
||||
|
||||
if (!srcSize) return ERROR(srcSize_wrong);
|
||||
iSize = ip[0];
|
||||
/* memset(huffWeight, 0, hwSize); *//* is not necessary, even though some analyzer complain ... */
|
||||
|
||||
if (iSize >= 128) { /* special header */
|
||||
oSize = iSize - 127;
|
||||
iSize = ((oSize+1)/2);
|
||||
if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
|
||||
if (oSize >= hwSize) return ERROR(corruption_detected);
|
||||
ip += 1;
|
||||
{ U32 n;
|
||||
for (n=0; n<oSize; n+=2) {
|
||||
huffWeight[n] = ip[n/2] >> 4;
|
||||
huffWeight[n+1] = ip[n/2] & 15;
|
||||
} } }
|
||||
else { /* header compressed with FSE (normal case) */
|
||||
FSE_DTable fseWorkspace[FSE_DTABLE_SIZE_U32(6)]; /* 6 is max possible tableLog for HUF header (maybe even 5, to be tested) */
|
||||
if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
|
||||
oSize = FSE_decompress_wksp(huffWeight, hwSize-1, ip+1, iSize, fseWorkspace, 6); /* max (hwSize-1) values decoded, as last one is implied */
|
||||
if (FSE_isError(oSize)) return oSize;
|
||||
}
|
||||
|
||||
/* collect weight stats */
|
||||
memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32));
|
||||
weightTotal = 0;
|
||||
{ U32 n; for (n=0; n<oSize; n++) {
|
||||
if (huffWeight[n] >= HUF_TABLELOG_MAX) return ERROR(corruption_detected);
|
||||
rankStats[huffWeight[n]]++;
|
||||
weightTotal += (1 << huffWeight[n]) >> 1;
|
||||
} }
|
||||
if (weightTotal == 0) return ERROR(corruption_detected);
|
||||
|
||||
/* get last non-null symbol weight (implied, total must be 2^n) */
|
||||
{ U32 const tableLog = BIT_highbit32(weightTotal) + 1;
|
||||
if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
|
||||
*tableLogPtr = tableLog;
|
||||
/* determine last weight */
|
||||
{ U32 const total = 1 << tableLog;
|
||||
U32 const rest = total - weightTotal;
|
||||
U32 const verif = 1 << BIT_highbit32(rest);
|
||||
U32 const lastWeight = BIT_highbit32(rest) + 1;
|
||||
if (verif != rest) return ERROR(corruption_detected); /* last value must be a clean power of 2 */
|
||||
huffWeight[oSize] = (BYTE)lastWeight;
|
||||
rankStats[lastWeight]++;
|
||||
} }
|
||||
|
||||
/* check tree construction validity */
|
||||
if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected); /* by construction : at least 2 elts of rank 1, must be even */
|
||||
|
||||
/* results */
|
||||
*nbSymbolsPtr = (U32)(oSize+1);
|
||||
return iSize+1;
|
||||
}
|
||||
48
third/github.com/DataDog/zstd/error_private.c
Normal file
48
third/github.com/DataDog/zstd/error_private.c
Normal file
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Copyright (c) 2016-present, Yann Collet, Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under both the BSD-style license (found in the
|
||||
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
|
||||
* in the COPYING file in the root directory of this source tree).
|
||||
* You may select, at your option, one of the above-listed licenses.
|
||||
*/
|
||||
|
||||
/* The purpose of this file is to have a single list of error strings embedded in binary */
|
||||
|
||||
#include "error_private.h"
|
||||
|
||||
const char* ERR_getErrorString(ERR_enum code)
|
||||
{
|
||||
static const char* const notErrorCode = "Unspecified error code";
|
||||
switch( code )
|
||||
{
|
||||
case PREFIX(no_error): return "No error detected";
|
||||
case PREFIX(GENERIC): return "Error (generic)";
|
||||
case PREFIX(prefix_unknown): return "Unknown frame descriptor";
|
||||
case PREFIX(version_unsupported): return "Version not supported";
|
||||
case PREFIX(frameParameter_unsupported): return "Unsupported frame parameter";
|
||||
case PREFIX(frameParameter_windowTooLarge): return "Frame requires too much memory for decoding";
|
||||
case PREFIX(corruption_detected): return "Corrupted block detected";
|
||||
case PREFIX(checksum_wrong): return "Restored data doesn't match checksum";
|
||||
case PREFIX(parameter_unsupported): return "Unsupported parameter";
|
||||
case PREFIX(parameter_outOfBound): return "Parameter is out of bound";
|
||||
case PREFIX(init_missing): return "Context should be init first";
|
||||
case PREFIX(memory_allocation): return "Allocation error : not enough memory";
|
||||
case PREFIX(workSpace_tooSmall): return "workSpace buffer is not large enough";
|
||||
case PREFIX(stage_wrong): return "Operation not authorized at current processing stage";
|
||||
case PREFIX(tableLog_tooLarge): return "tableLog requires too much memory : unsupported";
|
||||
case PREFIX(maxSymbolValue_tooLarge): return "Unsupported max Symbol Value : too large";
|
||||
case PREFIX(maxSymbolValue_tooSmall): return "Specified maxSymbolValue is too small";
|
||||
case PREFIX(dictionary_corrupted): return "Dictionary is corrupted";
|
||||
case PREFIX(dictionary_wrong): return "Dictionary mismatch";
|
||||
case PREFIX(dictionaryCreation_failed): return "Cannot create Dictionary from provided samples";
|
||||
case PREFIX(dstSize_tooSmall): return "Destination buffer is too small";
|
||||
case PREFIX(srcSize_wrong): return "Src size is incorrect";
|
||||
/* following error codes are not stable and may be removed or changed in a future version */
|
||||
case PREFIX(frameIndex_tooLarge): return "Frame index is too large";
|
||||
case PREFIX(seekableIO): return "An I/O error occurred when reading/seeking";
|
||||
case PREFIX(maxCode):
|
||||
default: return notErrorCode;
|
||||
}
|
||||
}
|
||||
76
third/github.com/DataDog/zstd/error_private.h
Normal file
76
third/github.com/DataDog/zstd/error_private.h
Normal file
@ -0,0 +1,76 @@
|
||||
/*
|
||||
* Copyright (c) 2016-present, Yann Collet, Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under both the BSD-style license (found in the
|
||||
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
|
||||
* in the COPYING file in the root directory of this source tree).
|
||||
* You may select, at your option, one of the above-listed licenses.
|
||||
*/
|
||||
|
||||
/* Note : this module is expected to remain private, do not expose it */
|
||||
|
||||
#ifndef ERROR_H_MODULE
|
||||
#define ERROR_H_MODULE
|
||||
|
||||
#if defined (__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
|
||||
/* ****************************************
|
||||
* Dependencies
|
||||
******************************************/
|
||||
#include <stddef.h> /* size_t */
|
||||
#include "zstd_errors.h" /* enum list */
|
||||
|
||||
|
||||
/* ****************************************
|
||||
* Compiler-specific
|
||||
******************************************/
|
||||
#if defined(__GNUC__)
|
||||
# define ERR_STATIC static __attribute__((unused))
|
||||
#elif defined (__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */)
|
||||
# define ERR_STATIC static inline
|
||||
#elif defined(_MSC_VER)
|
||||
# define ERR_STATIC static __inline
|
||||
#else
|
||||
# define ERR_STATIC static /* this version may generate warnings for unused static functions; disable the relevant warning */
|
||||
#endif
|
||||
|
||||
|
||||
/*-****************************************
|
||||
* Customization (error_public.h)
|
||||
******************************************/
|
||||
typedef ZSTD_ErrorCode ERR_enum;
|
||||
#define PREFIX(name) ZSTD_error_##name
|
||||
|
||||
|
||||
/*-****************************************
|
||||
* Error codes handling
|
||||
******************************************/
|
||||
#undef ERROR /* reported already defined on VS 2015 (Rich Geldreich) */
|
||||
#define ERROR(name) ZSTD_ERROR(name)
|
||||
#define ZSTD_ERROR(name) ((size_t)-PREFIX(name))
|
||||
|
||||
ERR_STATIC unsigned ERR_isError(size_t code) { return (code > ERROR(maxCode)); }
|
||||
|
||||
ERR_STATIC ERR_enum ERR_getErrorCode(size_t code) { if (!ERR_isError(code)) return (ERR_enum)0; return (ERR_enum) (0-code); }
|
||||
|
||||
|
||||
/*-****************************************
|
||||
* Error Strings
|
||||
******************************************/
|
||||
|
||||
const char* ERR_getErrorString(ERR_enum code); /* error_private.c */
|
||||
|
||||
ERR_STATIC const char* ERR_getErrorName(size_t code)
|
||||
{
|
||||
return ERR_getErrorString(ERR_getErrorCode(code));
|
||||
}
|
||||
|
||||
#if defined (__cplusplus)
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif /* ERROR_H_MODULE */
|
||||
35
third/github.com/DataDog/zstd/errors.go
Normal file
35
third/github.com/DataDog/zstd/errors.go
Normal file
@ -0,0 +1,35 @@
|
||||
package zstd
|
||||
|
||||
/*
|
||||
#define ZSTD_STATIC_LINKING_ONLY
|
||||
#include "zstd.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
// ErrorCode is an error returned by the zstd library.
|
||||
type ErrorCode int
|
||||
|
||||
// Error returns the error string given by zstd
|
||||
func (e ErrorCode) Error() string {
|
||||
return C.GoString(C.ZSTD_getErrorName(C.size_t(e)))
|
||||
}
|
||||
|
||||
func cIsError(code int) bool {
|
||||
return int(C.ZSTD_isError(C.size_t(code))) != 0
|
||||
}
|
||||
|
||||
// getError returns an error for the return code, or nil if it's not an error
|
||||
func getError(code int) error {
|
||||
if code < 0 && cIsError(code) {
|
||||
return ErrorCode(code)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsDstSizeTooSmallError returns whether the error correspond to zstd standard sDstSizeTooSmall error
|
||||
func IsDstSizeTooSmallError(e error) bool {
|
||||
if e != nil && e.Error() == "Destination buffer is too small" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
29
third/github.com/DataDog/zstd/errors_test.go
Normal file
29
third/github.com/DataDog/zstd/errors_test.go
Normal file
@ -0,0 +1,29 @@
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
// ErrorUpperBound is the upper bound to error number, currently only used in test
|
||||
// If this needs to be updated, check in zstd_errors.h what the max is
|
||||
ErrorUpperBound = 1000
|
||||
)
|
||||
|
||||
// TestFindIsDstSizeTooSmallError tests that there is at least one error code that
|
||||
// corresponds to dst size too small
|
||||
func TestFindIsDstSizeTooSmallError(t *testing.T) {
|
||||
found := 0
|
||||
for i := -1; i > -ErrorUpperBound; i-- {
|
||||
e := ErrorCode(i)
|
||||
if IsDstSizeTooSmallError(e) {
|
||||
found++
|
||||
}
|
||||
}
|
||||
|
||||
if found == 0 {
|
||||
t.Fatal("Couldn't find an error code for DstSizeTooSmall error, please make sure we didn't change the error string")
|
||||
} else if found > 1 {
|
||||
t.Fatal("IsDstSizeTooSmallError found multiple error codes matching, this shouldn't be the case")
|
||||
}
|
||||
}
|
||||
704
third/github.com/DataDog/zstd/fse.h
Normal file
704
third/github.com/DataDog/zstd/fse.h
Normal file
@ -0,0 +1,704 @@
|
||||
/* ******************************************************************
|
||||
FSE : Finite State Entropy codec
|
||||
Public Prototypes declaration
|
||||
Copyright (C) 2013-2016, Yann Collet.
|
||||
|
||||
BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php)
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
You can contact the author at :
|
||||
- Source repository : https://github.com/Cyan4973/FiniteStateEntropy
|
||||
****************************************************************** */
|
||||
|
||||
#if defined (__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#ifndef FSE_H
|
||||
#define FSE_H
|
||||
|
||||
|
||||
/*-*****************************************
|
||||
* Dependencies
|
||||
******************************************/
|
||||
#include <stddef.h> /* size_t, ptrdiff_t */
|
||||
|
||||
|
||||
/*-*****************************************
|
||||
* FSE_PUBLIC_API : control library symbols visibility
|
||||
******************************************/
|
||||
#if defined(FSE_DLL_EXPORT) && (FSE_DLL_EXPORT==1) && defined(__GNUC__) && (__GNUC__ >= 4)
|
||||
# define FSE_PUBLIC_API __attribute__ ((visibility ("default")))
|
||||
#elif defined(FSE_DLL_EXPORT) && (FSE_DLL_EXPORT==1) /* Visual expected */
|
||||
# define FSE_PUBLIC_API __declspec(dllexport)
|
||||
#elif defined(FSE_DLL_IMPORT) && (FSE_DLL_IMPORT==1)
|
||||
# define FSE_PUBLIC_API __declspec(dllimport) /* It isn't required but allows to generate better code, saving a function pointer load from the IAT and an indirect jump.*/
|
||||
#else
|
||||
# define FSE_PUBLIC_API
|
||||
#endif
|
||||
|
||||
/*------ Version ------*/
|
||||
#define FSE_VERSION_MAJOR 0
|
||||
#define FSE_VERSION_MINOR 9
|
||||
#define FSE_VERSION_RELEASE 0
|
||||
|
||||
#define FSE_LIB_VERSION FSE_VERSION_MAJOR.FSE_VERSION_MINOR.FSE_VERSION_RELEASE
|
||||
#define FSE_QUOTE(str) #str
|
||||
#define FSE_EXPAND_AND_QUOTE(str) FSE_QUOTE(str)
|
||||
#define FSE_VERSION_STRING FSE_EXPAND_AND_QUOTE(FSE_LIB_VERSION)
|
||||
|
||||
#define FSE_VERSION_NUMBER (FSE_VERSION_MAJOR *100*100 + FSE_VERSION_MINOR *100 + FSE_VERSION_RELEASE)
|
||||
FSE_PUBLIC_API unsigned FSE_versionNumber(void); /**< library version number; to be used when checking dll version */
|
||||
|
||||
/*-****************************************
|
||||
* FSE simple functions
|
||||
******************************************/
|
||||
/*! FSE_compress() :
|
||||
Compress content of buffer 'src', of size 'srcSize', into destination buffer 'dst'.
|
||||
'dst' buffer must be already allocated. Compression runs faster is dstCapacity >= FSE_compressBound(srcSize).
|
||||
@return : size of compressed data (<= dstCapacity).
|
||||
Special values : if return == 0, srcData is not compressible => Nothing is stored within dst !!!
|
||||
if return == 1, srcData is a single byte symbol * srcSize times. Use RLE compression instead.
|
||||
if FSE_isError(return), compression failed (more details using FSE_getErrorName())
|
||||
*/
|
||||
FSE_PUBLIC_API size_t FSE_compress(void* dst, size_t dstCapacity,
|
||||
const void* src, size_t srcSize);
|
||||
|
||||
/*! FSE_decompress():
|
||||
Decompress FSE data from buffer 'cSrc', of size 'cSrcSize',
|
||||
into already allocated destination buffer 'dst', of size 'dstCapacity'.
|
||||
@return : size of regenerated data (<= maxDstSize),
|
||||
or an error code, which can be tested using FSE_isError() .
|
||||
|
||||
** Important ** : FSE_decompress() does not decompress non-compressible nor RLE data !!!
|
||||
Why ? : making this distinction requires a header.
|
||||
Header management is intentionally delegated to the user layer, which can better manage special cases.
|
||||
*/
|
||||
FSE_PUBLIC_API size_t FSE_decompress(void* dst, size_t dstCapacity,
|
||||
const void* cSrc, size_t cSrcSize);
|
||||
|
||||
|
||||
/*-*****************************************
|
||||
* Tool functions
|
||||
******************************************/
|
||||
FSE_PUBLIC_API size_t FSE_compressBound(size_t size); /* maximum compressed size */
|
||||
|
||||
/* Error Management */
|
||||
FSE_PUBLIC_API unsigned FSE_isError(size_t code); /* tells if a return value is an error code */
|
||||
FSE_PUBLIC_API const char* FSE_getErrorName(size_t code); /* provides error code string (useful for debugging) */
|
||||
|
||||
|
||||
/*-*****************************************
|
||||
* FSE advanced functions
|
||||
******************************************/
|
||||
/*! FSE_compress2() :
|
||||
Same as FSE_compress(), but allows the selection of 'maxSymbolValue' and 'tableLog'
|
||||
Both parameters can be defined as '0' to mean : use default value
|
||||
@return : size of compressed data
|
||||
Special values : if return == 0, srcData is not compressible => Nothing is stored within cSrc !!!
|
||||
if return == 1, srcData is a single byte symbol * srcSize times. Use RLE compression.
|
||||
if FSE_isError(return), it's an error code.
|
||||
*/
|
||||
FSE_PUBLIC_API size_t FSE_compress2 (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog);
|
||||
|
||||
|
||||
/*-*****************************************
|
||||
* FSE detailed API
|
||||
******************************************/
|
||||
/*!
|
||||
FSE_compress() does the following:
|
||||
1. count symbol occurrence from source[] into table count[]
|
||||
2. normalize counters so that sum(count[]) == Power_of_2 (2^tableLog)
|
||||
3. save normalized counters to memory buffer using writeNCount()
|
||||
4. build encoding table 'CTable' from normalized counters
|
||||
5. encode the data stream using encoding table 'CTable'
|
||||
|
||||
FSE_decompress() does the following:
|
||||
1. read normalized counters with readNCount()
|
||||
2. build decoding table 'DTable' from normalized counters
|
||||
3. decode the data stream using decoding table 'DTable'
|
||||
|
||||
The following API allows targeting specific sub-functions for advanced tasks.
|
||||
For example, it's possible to compress several blocks using the same 'CTable',
|
||||
or to save and provide normalized distribution using external method.
|
||||
*/
|
||||
|
||||
/* *** COMPRESSION *** */
|
||||
|
||||
/*! FSE_count():
|
||||
Provides the precise count of each byte within a table 'count'.
|
||||
'count' is a table of unsigned int, of minimum size (*maxSymbolValuePtr+1).
|
||||
*maxSymbolValuePtr will be updated if detected smaller than initial value.
|
||||
@return : the count of the most frequent symbol (which is not identified).
|
||||
if return == srcSize, there is only one symbol.
|
||||
Can also return an error code, which can be tested with FSE_isError(). */
|
||||
FSE_PUBLIC_API size_t FSE_count(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize);
|
||||
|
||||
/*! FSE_optimalTableLog():
|
||||
dynamically downsize 'tableLog' when conditions are met.
|
||||
It saves CPU time, by using smaller tables, while preserving or even improving compression ratio.
|
||||
@return : recommended tableLog (necessarily <= 'maxTableLog') */
|
||||
FSE_PUBLIC_API unsigned FSE_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue);
|
||||
|
||||
/*! FSE_normalizeCount():
|
||||
normalize counts so that sum(count[]) == Power_of_2 (2^tableLog)
|
||||
'normalizedCounter' is a table of short, of minimum size (maxSymbolValue+1).
|
||||
@return : tableLog,
|
||||
or an errorCode, which can be tested using FSE_isError() */
|
||||
FSE_PUBLIC_API size_t FSE_normalizeCount(short* normalizedCounter, unsigned tableLog, const unsigned* count, size_t srcSize, unsigned maxSymbolValue);
|
||||
|
||||
/*! FSE_NCountWriteBound():
|
||||
Provides the maximum possible size of an FSE normalized table, given 'maxSymbolValue' and 'tableLog'.
|
||||
Typically useful for allocation purpose. */
|
||||
FSE_PUBLIC_API size_t FSE_NCountWriteBound(unsigned maxSymbolValue, unsigned tableLog);
|
||||
|
||||
/*! FSE_writeNCount():
|
||||
Compactly save 'normalizedCounter' into 'buffer'.
|
||||
@return : size of the compressed table,
|
||||
or an errorCode, which can be tested using FSE_isError(). */
|
||||
FSE_PUBLIC_API size_t FSE_writeNCount (void* buffer, size_t bufferSize, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog);
|
||||
|
||||
|
||||
/*! Constructor and Destructor of FSE_CTable.
|
||||
Note that FSE_CTable size depends on 'tableLog' and 'maxSymbolValue' */
|
||||
typedef unsigned FSE_CTable; /* don't allocate that. It's only meant to be more restrictive than void* */
|
||||
FSE_PUBLIC_API FSE_CTable* FSE_createCTable (unsigned maxSymbolValue, unsigned tableLog);
|
||||
FSE_PUBLIC_API void FSE_freeCTable (FSE_CTable* ct);
|
||||
|
||||
/*! FSE_buildCTable():
|
||||
Builds `ct`, which must be already allocated, using FSE_createCTable().
|
||||
@return : 0, or an errorCode, which can be tested using FSE_isError() */
|
||||
FSE_PUBLIC_API size_t FSE_buildCTable(FSE_CTable* ct, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog);
|
||||
|
||||
/*! FSE_compress_usingCTable():
|
||||
Compress `src` using `ct` into `dst` which must be already allocated.
|
||||
@return : size of compressed data (<= `dstCapacity`),
|
||||
or 0 if compressed data could not fit into `dst`,
|
||||
or an errorCode, which can be tested using FSE_isError() */
|
||||
FSE_PUBLIC_API size_t FSE_compress_usingCTable (void* dst, size_t dstCapacity, const void* src, size_t srcSize, const FSE_CTable* ct);
|
||||
|
||||
/*!
|
||||
Tutorial :
|
||||
----------
|
||||
The first step is to count all symbols. FSE_count() does this job very fast.
|
||||
Result will be saved into 'count', a table of unsigned int, which must be already allocated, and have 'maxSymbolValuePtr[0]+1' cells.
|
||||
'src' is a table of bytes of size 'srcSize'. All values within 'src' MUST be <= maxSymbolValuePtr[0]
|
||||
maxSymbolValuePtr[0] will be updated, with its real value (necessarily <= original value)
|
||||
FSE_count() will return the number of occurrence of the most frequent symbol.
|
||||
This can be used to know if there is a single symbol within 'src', and to quickly evaluate its compressibility.
|
||||
If there is an error, the function will return an ErrorCode (which can be tested using FSE_isError()).
|
||||
|
||||
The next step is to normalize the frequencies.
|
||||
FSE_normalizeCount() will ensure that sum of frequencies is == 2 ^'tableLog'.
|
||||
It also guarantees a minimum of 1 to any Symbol with frequency >= 1.
|
||||
You can use 'tableLog'==0 to mean "use default tableLog value".
|
||||
If you are unsure of which tableLog value to use, you can ask FSE_optimalTableLog(),
|
||||
which will provide the optimal valid tableLog given sourceSize, maxSymbolValue, and a user-defined maximum (0 means "default").
|
||||
|
||||
The result of FSE_normalizeCount() will be saved into a table,
|
||||
called 'normalizedCounter', which is a table of signed short.
|
||||
'normalizedCounter' must be already allocated, and have at least 'maxSymbolValue+1' cells.
|
||||
The return value is tableLog if everything proceeded as expected.
|
||||
It is 0 if there is a single symbol within distribution.
|
||||
If there is an error (ex: invalid tableLog value), the function will return an ErrorCode (which can be tested using FSE_isError()).
|
||||
|
||||
'normalizedCounter' can be saved in a compact manner to a memory area using FSE_writeNCount().
|
||||
'buffer' must be already allocated.
|
||||
For guaranteed success, buffer size must be at least FSE_headerBound().
|
||||
The result of the function is the number of bytes written into 'buffer'.
|
||||
If there is an error, the function will return an ErrorCode (which can be tested using FSE_isError(); ex : buffer size too small).
|
||||
|
||||
'normalizedCounter' can then be used to create the compression table 'CTable'.
|
||||
The space required by 'CTable' must be already allocated, using FSE_createCTable().
|
||||
You can then use FSE_buildCTable() to fill 'CTable'.
|
||||
If there is an error, both functions will return an ErrorCode (which can be tested using FSE_isError()).
|
||||
|
||||
'CTable' can then be used to compress 'src', with FSE_compress_usingCTable().
|
||||
Similar to FSE_count(), the convention is that 'src' is assumed to be a table of char of size 'srcSize'
|
||||
The function returns the size of compressed data (without header), necessarily <= `dstCapacity`.
|
||||
If it returns '0', compressed data could not fit into 'dst'.
|
||||
If there is an error, the function will return an ErrorCode (which can be tested using FSE_isError()).
|
||||
*/
|
||||
|
||||
|
||||
/* *** DECOMPRESSION *** */
|
||||
|
||||
/*! FSE_readNCount():
|
||||
Read compactly saved 'normalizedCounter' from 'rBuffer'.
|
||||
@return : size read from 'rBuffer',
|
||||
or an errorCode, which can be tested using FSE_isError().
|
||||
maxSymbolValuePtr[0] and tableLogPtr[0] will also be updated with their respective values */
|
||||
FSE_PUBLIC_API size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSymbolValuePtr, unsigned* tableLogPtr, const void* rBuffer, size_t rBuffSize);
|
||||
|
||||
/*! Constructor and Destructor of FSE_DTable.
|
||||
Note that its size depends on 'tableLog' */
|
||||
typedef unsigned FSE_DTable; /* don't allocate that. It's just a way to be more restrictive than void* */
|
||||
FSE_PUBLIC_API FSE_DTable* FSE_createDTable(unsigned tableLog);
|
||||
FSE_PUBLIC_API void FSE_freeDTable(FSE_DTable* dt);
|
||||
|
||||
/*! FSE_buildDTable():
|
||||
Builds 'dt', which must be already allocated, using FSE_createDTable().
|
||||
return : 0, or an errorCode, which can be tested using FSE_isError() */
|
||||
FSE_PUBLIC_API size_t FSE_buildDTable (FSE_DTable* dt, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog);
|
||||
|
||||
/*! FSE_decompress_usingDTable():
|
||||
Decompress compressed source `cSrc` of size `cSrcSize` using `dt`
|
||||
into `dst` which must be already allocated.
|
||||
@return : size of regenerated data (necessarily <= `dstCapacity`),
|
||||
or an errorCode, which can be tested using FSE_isError() */
|
||||
FSE_PUBLIC_API size_t FSE_decompress_usingDTable(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, const FSE_DTable* dt);
|
||||
|
||||
/*!
|
||||
Tutorial :
|
||||
----------
|
||||
(Note : these functions only decompress FSE-compressed blocks.
|
||||
If block is uncompressed, use memcpy() instead
|
||||
If block is a single repeated byte, use memset() instead )
|
||||
|
||||
The first step is to obtain the normalized frequencies of symbols.
|
||||
This can be performed by FSE_readNCount() if it was saved using FSE_writeNCount().
|
||||
'normalizedCounter' must be already allocated, and have at least 'maxSymbolValuePtr[0]+1' cells of signed short.
|
||||
In practice, that means it's necessary to know 'maxSymbolValue' beforehand,
|
||||
or size the table to handle worst case situations (typically 256).
|
||||
FSE_readNCount() will provide 'tableLog' and 'maxSymbolValue'.
|
||||
The result of FSE_readNCount() is the number of bytes read from 'rBuffer'.
|
||||
Note that 'rBufferSize' must be at least 4 bytes, even if useful information is less than that.
|
||||
If there is an error, the function will return an error code, which can be tested using FSE_isError().
|
||||
|
||||
The next step is to build the decompression tables 'FSE_DTable' from 'normalizedCounter'.
|
||||
This is performed by the function FSE_buildDTable().
|
||||
The space required by 'FSE_DTable' must be already allocated using FSE_createDTable().
|
||||
If there is an error, the function will return an error code, which can be tested using FSE_isError().
|
||||
|
||||
`FSE_DTable` can then be used to decompress `cSrc`, with FSE_decompress_usingDTable().
|
||||
`cSrcSize` must be strictly correct, otherwise decompression will fail.
|
||||
FSE_decompress_usingDTable() result will tell how many bytes were regenerated (<=`dstCapacity`).
|
||||
If there is an error, the function will return an error code, which can be tested using FSE_isError(). (ex: dst buffer too small)
|
||||
*/
|
||||
|
||||
#endif /* FSE_H */
|
||||
|
||||
#if defined(FSE_STATIC_LINKING_ONLY) && !defined(FSE_H_FSE_STATIC_LINKING_ONLY)
|
||||
#define FSE_H_FSE_STATIC_LINKING_ONLY
|
||||
|
||||
/* *** Dependency *** */
|
||||
#include "bitstream.h"
|
||||
|
||||
|
||||
/* *****************************************
|
||||
* Static allocation
|
||||
*******************************************/
|
||||
/* FSE buffer bounds */
|
||||
#define FSE_NCOUNTBOUND 512
|
||||
#define FSE_BLOCKBOUND(size) (size + (size>>7))
|
||||
#define FSE_COMPRESSBOUND(size) (FSE_NCOUNTBOUND + FSE_BLOCKBOUND(size)) /* Macro version, useful for static allocation */
|
||||
|
||||
/* It is possible to statically allocate FSE CTable/DTable as a table of FSE_CTable/FSE_DTable using below macros */
|
||||
#define FSE_CTABLE_SIZE_U32(maxTableLog, maxSymbolValue) (1 + (1<<(maxTableLog-1)) + ((maxSymbolValue+1)*2))
|
||||
#define FSE_DTABLE_SIZE_U32(maxTableLog) (1 + (1<<maxTableLog))
|
||||
|
||||
/* or use the size to malloc() space directly. Pay attention to alignment restrictions though */
|
||||
#define FSE_CTABLE_SIZE(maxTableLog, maxSymbolValue) (FSE_CTABLE_SIZE_U32(maxTableLog, maxSymbolValue) * sizeof(FSE_CTable))
|
||||
#define FSE_DTABLE_SIZE(maxTableLog) (FSE_DTABLE_SIZE_U32(maxTableLog) * sizeof(FSE_DTable))
|
||||
|
||||
|
||||
/* *****************************************
|
||||
* FSE advanced API
|
||||
*******************************************/
|
||||
/* FSE_count_wksp() :
|
||||
* Same as FSE_count(), but using an externally provided scratch buffer.
|
||||
* `workSpace` size must be table of >= `1024` unsigned
|
||||
*/
|
||||
size_t FSE_count_wksp(unsigned* count, unsigned* maxSymbolValuePtr,
|
||||
const void* source, size_t sourceSize, unsigned* workSpace);
|
||||
|
||||
/** FSE_countFast() :
|
||||
* same as FSE_count(), but blindly trusts that all byte values within src are <= *maxSymbolValuePtr
|
||||
*/
|
||||
size_t FSE_countFast(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize);
|
||||
|
||||
/* FSE_countFast_wksp() :
|
||||
* Same as FSE_countFast(), but using an externally provided scratch buffer.
|
||||
* `workSpace` must be a table of minimum `1024` unsigned
|
||||
*/
|
||||
size_t FSE_countFast_wksp(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize, unsigned* workSpace);
|
||||
|
||||
/*! FSE_count_simple() :
|
||||
* Same as FSE_countFast(), but does not use any additional memory (not even on stack).
|
||||
* This function is unsafe, and will segfault if any value within `src` is `> *maxSymbolValuePtr` (presuming it's also the size of `count`).
|
||||
*/
|
||||
size_t FSE_count_simple(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize);
|
||||
|
||||
|
||||
|
||||
unsigned FSE_optimalTableLog_internal(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, unsigned minus);
|
||||
/**< same as FSE_optimalTableLog(), which used `minus==2` */
|
||||
|
||||
/* FSE_compress_wksp() :
|
||||
* Same as FSE_compress2(), but using an externally allocated scratch buffer (`workSpace`).
|
||||
* FSE_WKSP_SIZE_U32() provides the minimum size required for `workSpace` as a table of FSE_CTable.
|
||||
*/
|
||||
#define FSE_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) ( FSE_CTABLE_SIZE_U32(maxTableLog, maxSymbolValue) + ((maxTableLog > 12) ? (1 << (maxTableLog - 2)) : 1024) )
|
||||
size_t FSE_compress_wksp (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize);
|
||||
|
||||
size_t FSE_buildCTable_raw (FSE_CTable* ct, unsigned nbBits);
|
||||
/**< build a fake FSE_CTable, designed for a flat distribution, where each symbol uses nbBits */
|
||||
|
||||
size_t FSE_buildCTable_rle (FSE_CTable* ct, unsigned char symbolValue);
|
||||
/**< build a fake FSE_CTable, designed to compress always the same symbolValue */
|
||||
|
||||
/* FSE_buildCTable_wksp() :
|
||||
* Same as FSE_buildCTable(), but using an externally allocated scratch buffer (`workSpace`).
|
||||
* `wkspSize` must be >= `(1<<tableLog)`.
|
||||
*/
|
||||
size_t FSE_buildCTable_wksp(FSE_CTable* ct, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize);
|
||||
|
||||
size_t FSE_buildDTable_raw (FSE_DTable* dt, unsigned nbBits);
|
||||
/**< build a fake FSE_DTable, designed to read a flat distribution where each symbol uses nbBits */
|
||||
|
||||
size_t FSE_buildDTable_rle (FSE_DTable* dt, unsigned char symbolValue);
|
||||
/**< build a fake FSE_DTable, designed to always generate the same symbolValue */
|
||||
|
||||
size_t FSE_decompress_wksp(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, FSE_DTable* workSpace, unsigned maxLog);
|
||||
/**< same as FSE_decompress(), using an externally allocated `workSpace` produced with `FSE_DTABLE_SIZE_U32(maxLog)` */
|
||||
|
||||
typedef enum {
|
||||
FSE_repeat_none, /**< Cannot use the previous table */
|
||||
FSE_repeat_check, /**< Can use the previous table but it must be checked */
|
||||
FSE_repeat_valid /**< Can use the previous table and it is asumed to be valid */
|
||||
} FSE_repeat;
|
||||
|
||||
/* *****************************************
|
||||
* FSE symbol compression API
|
||||
*******************************************/
|
||||
/*!
|
||||
This API consists of small unitary functions, which highly benefit from being inlined.
|
||||
Hence their body are included in next section.
|
||||
*/
|
||||
typedef struct {
|
||||
ptrdiff_t value;
|
||||
const void* stateTable;
|
||||
const void* symbolTT;
|
||||
unsigned stateLog;
|
||||
} FSE_CState_t;
|
||||
|
||||
static void FSE_initCState(FSE_CState_t* CStatePtr, const FSE_CTable* ct);
|
||||
|
||||
static void FSE_encodeSymbol(BIT_CStream_t* bitC, FSE_CState_t* CStatePtr, unsigned symbol);
|
||||
|
||||
static void FSE_flushCState(BIT_CStream_t* bitC, const FSE_CState_t* CStatePtr);
|
||||
|
||||
/**<
|
||||
These functions are inner components of FSE_compress_usingCTable().
|
||||
They allow the creation of custom streams, mixing multiple tables and bit sources.
|
||||
|
||||
A key property to keep in mind is that encoding and decoding are done **in reverse direction**.
|
||||
So the first symbol you will encode is the last you will decode, like a LIFO stack.
|
||||
|
||||
You will need a few variables to track your CStream. They are :
|
||||
|
||||
FSE_CTable ct; // Provided by FSE_buildCTable()
|
||||
BIT_CStream_t bitStream; // bitStream tracking structure
|
||||
FSE_CState_t state; // State tracking structure (can have several)
|
||||
|
||||
|
||||
The first thing to do is to init bitStream and state.
|
||||
size_t errorCode = BIT_initCStream(&bitStream, dstBuffer, maxDstSize);
|
||||
FSE_initCState(&state, ct);
|
||||
|
||||
Note that BIT_initCStream() can produce an error code, so its result should be tested, using FSE_isError();
|
||||
You can then encode your input data, byte after byte.
|
||||
FSE_encodeSymbol() outputs a maximum of 'tableLog' bits at a time.
|
||||
Remember decoding will be done in reverse direction.
|
||||
FSE_encodeByte(&bitStream, &state, symbol);
|
||||
|
||||
At any time, you can also add any bit sequence.
|
||||
Note : maximum allowed nbBits is 25, for compatibility with 32-bits decoders
|
||||
BIT_addBits(&bitStream, bitField, nbBits);
|
||||
|
||||
The above methods don't commit data to memory, they just store it into local register, for speed.
|
||||
Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (size_t).
|
||||
Writing data to memory is a manual operation, performed by the flushBits function.
|
||||
BIT_flushBits(&bitStream);
|
||||
|
||||
Your last FSE encoding operation shall be to flush your last state value(s).
|
||||
FSE_flushState(&bitStream, &state);
|
||||
|
||||
Finally, you must close the bitStream.
|
||||
The function returns the size of CStream in bytes.
|
||||
If data couldn't fit into dstBuffer, it will return a 0 ( == not compressible)
|
||||
If there is an error, it returns an errorCode (which can be tested using FSE_isError()).
|
||||
size_t size = BIT_closeCStream(&bitStream);
|
||||
*/
|
||||
|
||||
|
||||
/* *****************************************
|
||||
* FSE symbol decompression API
|
||||
*******************************************/
|
||||
typedef struct {
|
||||
size_t state;
|
||||
const void* table; /* precise table may vary, depending on U16 */
|
||||
} FSE_DState_t;
|
||||
|
||||
|
||||
static void FSE_initDState(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD, const FSE_DTable* dt);
|
||||
|
||||
static unsigned char FSE_decodeSymbol(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD);
|
||||
|
||||
static unsigned FSE_endOfDState(const FSE_DState_t* DStatePtr);
|
||||
|
||||
/**<
|
||||
Let's now decompose FSE_decompress_usingDTable() into its unitary components.
|
||||
You will decode FSE-encoded symbols from the bitStream,
|
||||
and also any other bitFields you put in, **in reverse order**.
|
||||
|
||||
You will need a few variables to track your bitStream. They are :
|
||||
|
||||
BIT_DStream_t DStream; // Stream context
|
||||
FSE_DState_t DState; // State context. Multiple ones are possible
|
||||
FSE_DTable* DTablePtr; // Decoding table, provided by FSE_buildDTable()
|
||||
|
||||
The first thing to do is to init the bitStream.
|
||||
errorCode = BIT_initDStream(&DStream, srcBuffer, srcSize);
|
||||
|
||||
You should then retrieve your initial state(s)
|
||||
(in reverse flushing order if you have several ones) :
|
||||
errorCode = FSE_initDState(&DState, &DStream, DTablePtr);
|
||||
|
||||
You can then decode your data, symbol after symbol.
|
||||
For information the maximum number of bits read by FSE_decodeSymbol() is 'tableLog'.
|
||||
Keep in mind that symbols are decoded in reverse order, like a LIFO stack (last in, first out).
|
||||
unsigned char symbol = FSE_decodeSymbol(&DState, &DStream);
|
||||
|
||||
You can retrieve any bitfield you eventually stored into the bitStream (in reverse order)
|
||||
Note : maximum allowed nbBits is 25, for 32-bits compatibility
|
||||
size_t bitField = BIT_readBits(&DStream, nbBits);
|
||||
|
||||
All above operations only read from local register (which size depends on size_t).
|
||||
Refueling the register from memory is manually performed by the reload method.
|
||||
endSignal = FSE_reloadDStream(&DStream);
|
||||
|
||||
BIT_reloadDStream() result tells if there is still some more data to read from DStream.
|
||||
BIT_DStream_unfinished : there is still some data left into the DStream.
|
||||
BIT_DStream_endOfBuffer : Dstream reached end of buffer. Its container may no longer be completely filled.
|
||||
BIT_DStream_completed : Dstream reached its exact end, corresponding in general to decompression completed.
|
||||
BIT_DStream_tooFar : Dstream went too far. Decompression result is corrupted.
|
||||
|
||||
When reaching end of buffer (BIT_DStream_endOfBuffer), progress slowly, notably if you decode multiple symbols per loop,
|
||||
to properly detect the exact end of stream.
|
||||
After each decoded symbol, check if DStream is fully consumed using this simple test :
|
||||
BIT_reloadDStream(&DStream) >= BIT_DStream_completed
|
||||
|
||||
When it's done, verify decompression is fully completed, by checking both DStream and the relevant states.
|
||||
Checking if DStream has reached its end is performed by :
|
||||
BIT_endOfDStream(&DStream);
|
||||
Check also the states. There might be some symbols left there, if some high probability ones (>50%) are possible.
|
||||
FSE_endOfDState(&DState);
|
||||
*/
|
||||
|
||||
|
||||
/* *****************************************
|
||||
* FSE unsafe API
|
||||
*******************************************/
|
||||
static unsigned char FSE_decodeSymbolFast(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD);
|
||||
/* faster, but works only if nbBits is always >= 1 (otherwise, result will be corrupted) */
|
||||
|
||||
|
||||
/* *****************************************
|
||||
* Implementation of inlined functions
|
||||
*******************************************/
|
||||
typedef struct {
|
||||
int deltaFindState;
|
||||
U32 deltaNbBits;
|
||||
} FSE_symbolCompressionTransform; /* total 8 bytes */
|
||||
|
||||
MEM_STATIC void FSE_initCState(FSE_CState_t* statePtr, const FSE_CTable* ct)
|
||||
{
|
||||
const void* ptr = ct;
|
||||
const U16* u16ptr = (const U16*) ptr;
|
||||
const U32 tableLog = MEM_read16(ptr);
|
||||
statePtr->value = (ptrdiff_t)1<<tableLog;
|
||||
statePtr->stateTable = u16ptr+2;
|
||||
statePtr->symbolTT = ((const U32*)ct + 1 + (tableLog ? (1<<(tableLog-1)) : 1));
|
||||
statePtr->stateLog = tableLog;
|
||||
}
|
||||
|
||||
|
||||
/*! FSE_initCState2() :
|
||||
* Same as FSE_initCState(), but the first symbol to include (which will be the last to be read)
|
||||
* uses the smallest state value possible, saving the cost of this symbol */
|
||||
MEM_STATIC void FSE_initCState2(FSE_CState_t* statePtr, const FSE_CTable* ct, U32 symbol)
|
||||
{
|
||||
FSE_initCState(statePtr, ct);
|
||||
{ const FSE_symbolCompressionTransform symbolTT = ((const FSE_symbolCompressionTransform*)(statePtr->symbolTT))[symbol];
|
||||
const U16* stateTable = (const U16*)(statePtr->stateTable);
|
||||
U32 nbBitsOut = (U32)((symbolTT.deltaNbBits + (1<<15)) >> 16);
|
||||
statePtr->value = (nbBitsOut << 16) - symbolTT.deltaNbBits;
|
||||
statePtr->value = stateTable[(statePtr->value >> nbBitsOut) + symbolTT.deltaFindState];
|
||||
}
|
||||
}
|
||||
|
||||
MEM_STATIC void FSE_encodeSymbol(BIT_CStream_t* bitC, FSE_CState_t* statePtr, U32 symbol)
|
||||
{
|
||||
FSE_symbolCompressionTransform const symbolTT = ((const FSE_symbolCompressionTransform*)(statePtr->symbolTT))[symbol];
|
||||
const U16* const stateTable = (const U16*)(statePtr->stateTable);
|
||||
U32 const nbBitsOut = (U32)((statePtr->value + symbolTT.deltaNbBits) >> 16);
|
||||
BIT_addBits(bitC, statePtr->value, nbBitsOut);
|
||||
statePtr->value = stateTable[ (statePtr->value >> nbBitsOut) + symbolTT.deltaFindState];
|
||||
}
|
||||
|
||||
MEM_STATIC void FSE_flushCState(BIT_CStream_t* bitC, const FSE_CState_t* statePtr)
|
||||
{
|
||||
BIT_addBits(bitC, statePtr->value, statePtr->stateLog);
|
||||
BIT_flushBits(bitC);
|
||||
}
|
||||
|
||||
|
||||
/* ====== Decompression ====== */
|
||||
|
||||
typedef struct {
|
||||
U16 tableLog;
|
||||
U16 fastMode;
|
||||
} FSE_DTableHeader; /* sizeof U32 */
|
||||
|
||||
typedef struct
|
||||
{
|
||||
unsigned short newState;
|
||||
unsigned char symbol;
|
||||
unsigned char nbBits;
|
||||
} FSE_decode_t; /* size == U32 */
|
||||
|
||||
MEM_STATIC void FSE_initDState(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD, const FSE_DTable* dt)
|
||||
{
|
||||
const void* ptr = dt;
|
||||
const FSE_DTableHeader* const DTableH = (const FSE_DTableHeader*)ptr;
|
||||
DStatePtr->state = BIT_readBits(bitD, DTableH->tableLog);
|
||||
BIT_reloadDStream(bitD);
|
||||
DStatePtr->table = dt + 1;
|
||||
}
|
||||
|
||||
MEM_STATIC BYTE FSE_peekSymbol(const FSE_DState_t* DStatePtr)
|
||||
{
|
||||
FSE_decode_t const DInfo = ((const FSE_decode_t*)(DStatePtr->table))[DStatePtr->state];
|
||||
return DInfo.symbol;
|
||||
}
|
||||
|
||||
MEM_STATIC void FSE_updateState(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD)
|
||||
{
|
||||
FSE_decode_t const DInfo = ((const FSE_decode_t*)(DStatePtr->table))[DStatePtr->state];
|
||||
U32 const nbBits = DInfo.nbBits;
|
||||
size_t const lowBits = BIT_readBits(bitD, nbBits);
|
||||
DStatePtr->state = DInfo.newState + lowBits;
|
||||
}
|
||||
|
||||
MEM_STATIC BYTE FSE_decodeSymbol(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD)
|
||||
{
|
||||
FSE_decode_t const DInfo = ((const FSE_decode_t*)(DStatePtr->table))[DStatePtr->state];
|
||||
U32 const nbBits = DInfo.nbBits;
|
||||
BYTE const symbol = DInfo.symbol;
|
||||
size_t const lowBits = BIT_readBits(bitD, nbBits);
|
||||
|
||||
DStatePtr->state = DInfo.newState + lowBits;
|
||||
return symbol;
|
||||
}
|
||||
|
||||
/*! FSE_decodeSymbolFast() :
|
||||
unsafe, only works if no symbol has a probability > 50% */
|
||||
MEM_STATIC BYTE FSE_decodeSymbolFast(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD)
|
||||
{
|
||||
FSE_decode_t const DInfo = ((const FSE_decode_t*)(DStatePtr->table))[DStatePtr->state];
|
||||
U32 const nbBits = DInfo.nbBits;
|
||||
BYTE const symbol = DInfo.symbol;
|
||||
size_t const lowBits = BIT_readBitsFast(bitD, nbBits);
|
||||
|
||||
DStatePtr->state = DInfo.newState + lowBits;
|
||||
return symbol;
|
||||
}
|
||||
|
||||
MEM_STATIC unsigned FSE_endOfDState(const FSE_DState_t* DStatePtr)
|
||||
{
|
||||
return DStatePtr->state == 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
#ifndef FSE_COMMONDEFS_ONLY
|
||||
|
||||
/* **************************************************************
|
||||
* Tuning parameters
|
||||
****************************************************************/
|
||||
/*!MEMORY_USAGE :
|
||||
* Memory usage formula : N->2^N Bytes (examples : 10 -> 1KB; 12 -> 4KB ; 16 -> 64KB; 20 -> 1MB; etc.)
|
||||
* Increasing memory usage improves compression ratio
|
||||
* Reduced memory usage can improve speed, due to cache effect
|
||||
* Recommended max value is 14, for 16KB, which nicely fits into Intel x86 L1 cache */
|
||||
#ifndef FSE_MAX_MEMORY_USAGE
|
||||
# define FSE_MAX_MEMORY_USAGE 14
|
||||
#endif
|
||||
#ifndef FSE_DEFAULT_MEMORY_USAGE
|
||||
# define FSE_DEFAULT_MEMORY_USAGE 13
|
||||
#endif
|
||||
|
||||
/*!FSE_MAX_SYMBOL_VALUE :
|
||||
* Maximum symbol value authorized.
|
||||
* Required for proper stack allocation */
|
||||
#ifndef FSE_MAX_SYMBOL_VALUE
|
||||
# define FSE_MAX_SYMBOL_VALUE 255
|
||||
#endif
|
||||
|
||||
/* **************************************************************
|
||||
* template functions type & suffix
|
||||
****************************************************************/
|
||||
#define FSE_FUNCTION_TYPE BYTE
|
||||
#define FSE_FUNCTION_EXTENSION
|
||||
#define FSE_DECODE_TYPE FSE_decode_t
|
||||
|
||||
|
||||
#endif /* !FSE_COMMONDEFS_ONLY */
|
||||
|
||||
|
||||
/* ***************************************************************
|
||||
* Constants
|
||||
*****************************************************************/
|
||||
#define FSE_MAX_TABLELOG (FSE_MAX_MEMORY_USAGE-2)
|
||||
#define FSE_MAX_TABLESIZE (1U<<FSE_MAX_TABLELOG)
|
||||
#define FSE_MAXTABLESIZE_MASK (FSE_MAX_TABLESIZE-1)
|
||||
#define FSE_DEFAULT_TABLELOG (FSE_DEFAULT_MEMORY_USAGE-2)
|
||||
#define FSE_MIN_TABLELOG 5
|
||||
|
||||
#define FSE_TABLELOG_ABSOLUTE_MAX 15
|
||||
#if FSE_MAX_TABLELOG > FSE_TABLELOG_ABSOLUTE_MAX
|
||||
# error "FSE_MAX_TABLELOG > FSE_TABLELOG_ABSOLUTE_MAX is not supported"
|
||||
#endif
|
||||
|
||||
#define FSE_TABLESTEP(tableSize) ((tableSize>>1) + (tableSize>>3) + 3)
|
||||
|
||||
|
||||
#endif /* FSE_STATIC_LINKING_ONLY */
|
||||
|
||||
|
||||
#if defined (__cplusplus)
|
||||
}
|
||||
#endif
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user