Compare commits

...

72 Commits

Author SHA1 Message Date
d604d198ab hot fix 2019-01-04 15:43:56 +08:00
36791d2f48 README updates 2019-01-02 21:43:17 +08:00
08f9cffed9 TODO++ 2019-01-02 21:38:26 +08:00
783c0ba846 updates kafka dependences 2019-01-02 21:35:48 +08:00
7ad4f61564 revert hot fix codes, waiting for next release to fix 2019-01-02 11:41:22 +08:00
adf06a2b0d thirdparty package kafka updated to date 2019-01-02 11:02:03 +08:00
d6aa2b2512 hot fix for gcache 2019-01-02 10:30:27 +08:00
0a8af94610 !15 打开文件没关闭
Merge pull request !15 from hello/master
2019-01-01 19:45:46 +08:00
wgb
2c27c0f58a close file 2019-01-01 15:35:25 +08:00
4172eae87e update default ConnMaxLifeTime to 30 seconds in gdb package 2018-12-28 22:02:21 +08:00
26f2c61068 update default ConnMaxLifeTime to 30 seconds in gdb package 2018-12-28 22:00:49 +08:00
f97bed2607 update default ConnMaxLifeTime to 10 seconds in gdb package 2018-12-28 21:56:27 +08:00
8ef7155c70 hot fix 2018-12-28 21:46:01 +08:00
2c6e8f88fb README updates, TODO++ 2018-12-27 20:47:13 +08:00
25068b1e83 README updates 2018-12-27 13:27:57 +08:00
1f36eb3a9a README updates 2018-12-27 13:27:04 +08:00
a9ed577d05 README updates 2018-12-27 09:59:56 +08:00
782d614082 README updates 2018-12-27 09:57:54 +08:00
0629c00b07 README updates 2018-12-27 09:49:50 +08:00
b90d5bb205 README updates 2018-12-27 09:46:53 +08:00
cbc824c80a README updates 2018-12-27 09:46:35 +08:00
0c9be40b86 README updates 2018-12-27 09:46:18 +08:00
c96abd706d README updates 2018-12-27 09:45:04 +08:00
0ae5872783 README updates 2018-12-27 09:22:07 +08:00
2cff10e0d2 fix issue in controller interface definition 2018-12-26 10:17:24 +08:00
cab78f557d fix issue in controller detection for object parameter, in router group of web server 2018-12-25 23:20:43 +08:00
04353aa1a5 RELEASE updates 2018-12-25 13:54:36 +08:00
35121a66e9 README updates 2018-12-22 21:50:47 +08:00
e726ed2c19 gdb.Model updates 2018-12-22 21:03:03 +08:00
503446afc7 fix issue of ghttp.Request.GetVar 2018-12-22 11:52:12 +08:00
2063f662d3 fix silly issue in binary search of garray package, and add unit-test file for garray 2018-12-20 21:55:05 +08:00
d7381399aa fix issue of grand.intn in x86 arch; add router group feature for WebServer 2018-12-20 21:04:43 +08:00
d05b497cdb Merge branch 'master' into qiangg_router_group 2018-12-19 18:58:47 +08:00
ef919be587 g.DB can use gdb's configurations, not to force using config.toml 2018-12-19 18:35:44 +08:00
fff31e0f4f add Charset support for mysql of gdb package; fix issue for glog for log writing failed when the folder path wa deleted 2018-12-19 18:15:22 +08:00
cdd6fc7c1e extend pid length from 16bit to 24bit in process communication of gproc package 2018-12-19 16:17:54 +08:00
74bc36a2dc remove gfile.MainPkgPath check in gcfg/gview default path 2018-12-19 14:51:09 +08:00
48328ae52c router group developing 2018-12-19 14:45:39 +08:00
a86f4f8e23 disable auto adding temp directory to gview/gcfg search path; disable backtrace feature in normal log print with glog; fix issue caused by fmt.Fprintf in gfsnotify 2018-12-18 20:03:23 +08:00
0a1e048268 add Model.Clone support for gdb package 2018-12-18 10:10:14 +08:00
6fc5efd6ba README updates 2018-12-17 20:51:49 +08:00
2d795b593d README updates 2018-12-17 20:44:38 +08:00
20628ec75c README updates 2018-12-17 19:50:35 +08:00
10d1ccb009 README updates 2018-12-17 19:41:08 +08:00
fcc37c9581 CI updates 2018-12-17 19:36:34 +08:00
43cd391543 CI updates 2018-12-17 19:29:03 +08:00
18d2df33f7 CI updates 2018-12-17 19:26:59 +08:00
a85daa5617 CI updates 2018-12-17 18:35:29 +08:00
48dc4ce3e2 travis updates 2018-12-17 14:01:43 +08:00
d07bac89a0 travis updates 2018-12-17 13:59:00 +08:00
5d32ad6bc4 travis updates 2018-12-17 13:57:15 +08:00
397b0a3e7e travis updates 2018-12-17 13:50:17 +08:00
259961632d travis updates 2018-12-17 13:46:06 +08:00
cb1d6382ec travis updates 2018-12-17 13:38:35 +08:00
8714a69a13 travis updates 2018-12-17 13:36:38 +08:00
3ae0ea2de7 travis updates 2018-12-17 13:32:35 +08:00
1879a9f4c7 README updates 2018-12-17 13:28:19 +08:00
3938717b04 travis updates 2018-12-17 13:24:53 +08:00
1208b688f1 add code helper 2018-12-17 13:08:32 +08:00
0ad7ee5a32 add code helper 2018-12-17 13:07:01 +08:00
7a4e68e6b9 add code helper 2018-12-17 13:06:58 +08:00
71222b247f add travis/goreport/golint/govet 2018-12-17 13:02:55 +08:00
95db811943 add travis/goreport/golint/govet 2018-12-17 13:02:18 +08:00
2dbc817132 VERSION updates 2018-12-17 11:24:58 +08:00
7a8bd96edc gdb: add support for slice argument in where statement 2018-12-17 10:52:44 +08:00
c5e9686a95 gdb updates, make priority=1 when no priority set 2018-12-16 23:11:15 +08:00
c914edf616 gdb comment updates 2018-12-16 22:27:04 +08:00
656bfcb6bd Merge branch 'qiangg_db2' into develop 2018-12-16 22:22:33 +08:00
7434dfe6fa done refacting gdb package 2018-12-16 22:22:07 +08:00
e67aa63a50 refract gdb package, add complete unit test cases, almost there 2018-12-15 15:50:39 +08:00
d5e46f2b42 refracting gdb package 2018-12-14 18:35:51 +08:00
09e6f10b60 new version of gdb developing 2018-12-14 10:09:45 +08:00
266 changed files with 61700 additions and 3258 deletions

33
.travis.yml Normal file
View File

@ -0,0 +1,33 @@
language: go
go:
- "1.11.x"
branches:
only:
- master
- develop
env:
- GITEE_GF=$GOPATH/src/gitee.com/johng/gf GO111MODULE=on
services:
- mysql
before_install:
- pwd
install:
- pwd
- mkdir -p $GITEE_GF
- cp * $GITEE_GF -R
- cd $GITEE_GF
script:
- cd g && go test -v ./... -race -coverprofile=coverage.txt -covermode=atomic
after_success:
- bash <(curl -s https://codecov.io/bash)

View File

@ -1,6 +1,18 @@
<div align=center>
<img src="https://gfer.me/cover.png" width="150"/>
</div>
# GoFrame
<img align="right" height="150px" src="https://gfer.me/cover.png">
[![Go Doc](https://godoc.org/github.com/johng-cn/gf?status.svg)](https://godoc.org/github.com/johng-cn/gf)
[![Build Status](https://travis-ci.org/johng-cn/gf.svg?branch=master)](https://travis-ci.org/johng-cn/gf)
[![Go Report](https://goreportcard.com/badge/github.com/johng-cn/gf)](https://goreportcard.com/report/github.com/johng-cn/gf)
[![Documents](https://img.shields.io/badge/docs-100%25-green.svg)](https://gfer.me)
[![License](https://img.shields.io/github/license/johng-cn/gf.svg?style=flat)](https://github.com/johng-cn/gf)
[![Language](https://img.shields.io/badge/language-go-blue.svg)](https://github.com/johng-cn/gf)
[![Release](https://img.shields.io/github/release/johng-cn/gf.svg?style=flat)](https://github.com/johng-cn/gf/releases)
<!--
[![Code Coverage](https://codecov.io/gh/johng-cn/gf/branch/master/graph/badge.svg)](https://codecov.io/gh/johng-cn/gf)
[![Code Helper](https://www.codetriage.com/johng-cn/gf/badges/users.svg)](https://www.codetriage.com/johng-cn/gf)
-->
`GF(GoFrame)` is a modular, lightweight, loosely coupled, high performance application development framework written in Go. Supporting graceful server, hot updates, multi-domain, multi-port, multi-service, HTTP/HTTPS, dynamic/hook routing and many more features. Providing a series of core components and dozens of practical modules.
@ -8,6 +20,11 @@
```
go get -u gitee.com/johng/gf
```
or use `go.mod`
```
require gitee.com/johng/gf latest
```
# Limitation
```
golang version >= 1.9.2
@ -41,6 +58,47 @@ func main() {
}
```
[View More..](https://gfer.me/start/index)
# License
GF is licensed under the [MIT License](LICENSE), 100% free and open-source.
`GF` is licensed under the [MIT License](LICENSE), 100% free and open-source, forever.
# Contributors(TOP 10)
<a href="https://gitee.com/johng" target="_blank" title="John"><img src="https://gitee.com/uploads/27/1309327_johng.png?1530630243" width="60" align="left"></a>
<a href="https://gitee.com/wenzi1" target="_blank" title="蚊子"><img src="https://images.gitee.com/uploads/22/1923122_wenzi1.png" width="60" align="left"></a>
<a href="https://gitee.com/zseeker" target="_blank" title="zseeker"><img src="https://gfer.me/images/contributors/zseeker.png" width="60" align="left"></a>
<a href="https://gitee.com/ymrjqyy" target="_blank" title="一墨染尽青衣颜"><img src="https://images.gitee.com/uploads/27/876827_ymrjqyy.png" width="60" align="left"></a>
<a href="https://github.com/chenyang351" target="_blank" title="chenyang351"><img src="https://avatars1.githubusercontent.com/u/30063958?s=60&v=4" width="60" align="left"></a>
<a href="https://gitee.com/wxkj" target="_blank" title="wxkj"><img src="https://gitee.com/uploads/56/91356_wxkj.png" width="60" align="left"></a>
<a href="https://github.com/wxkj001" target="_blank" title="3wxkj001
"><img src="https://avatars0.githubusercontent.com/u/7794279?s=60&v=4" width="60" align="left"></a>
<a href="https://gitee.com/zhangjinfu" target="_blank" title="张金富"><img src="https://images.gitee.com/uploads/63/356163_zhangjinfu.png" width="60" align="left"></a>
<a href="https://gitee.com/garfieldkwong" target="_blank" title="GarfieldKwong"><img src="https://gfer.me/images/contributors/garfieldkwong.png" width="60" align="left"></a>
<a href="https://gitee.com/qq1054000800" target="_blank" title="hello"><img src="https://gitee.com/uploads/9/2209_qq1054000800.jpg" width="60" align="left"></a>
<br /><br /><br />
# Donators
<a href="https://gitee.com/zfan_codes" target="_blank" title="范钟"><img src="https://images.gitee.com/uploads/32/2044832_zfan_codes.png" width="60" align="left"></a>
<a href="https://gitee.com/hailaz" target="_blank" title="HaiLaz"><img src="https://gitee.com/uploads/87/1273187_hailaz.png" width="60" align="left"></a>
<a href="https://gitee.com/mg91" target="_blank" title="mg91"><img src="https://images.gitee.com/uploads/30/1410930_mg91.png" width="60" align="left"></a>

View File

@ -1,18 +1,43 @@
<div align=center>
<img src="https://gfer.me/cover.png" width="150"/>
</div>
# GoFrame
<img align="right" height="150px" src="https://gfer.me/cover.png">
[![Go Doc](https://godoc.org/github.com/johng-cn/gf?status.svg)](https://godoc.org/github.com/johng-cn/gf)
[![Build Status](https://travis-ci.org/johng-cn/gf.svg?branch=master)](https://travis-ci.org/johng-cn/gf)
[![Go Report](https://goreportcard.com/badge/github.com/johng-cn/gf)](https://goreportcard.com/report/github.com/johng-cn/gf)
[![Documents](https://img.shields.io/badge/docs-100%25-green.svg)](https://gfer.me)
[![License](https://img.shields.io/github/license/johng-cn/gf.svg?style=flat)](https://github.com/johng-cn/gf)
[![Language](https://img.shields.io/badge/language-go-blue.svg)](https://github.com/johng-cn/gf)
[![Release](https://img.shields.io/github/release/johng-cn/gf.svg?style=flat)](https://github.com/johng-cn/gf/releases)
<!--
[![Code Coverage](https://codecov.io/gh/johng-cn/gf/branch/master/graph/badge.svg)](https://codecov.io/gh/johng-cn/gf)
[![Code Helper](https://www.codetriage.com/johng-cn/gf/badges/users.svg)](https://www.codetriage.com/johng-cn/gf)
-->
`GF(Go Frame)`是一款模块化、松耦合、轻量级、高性能的Go应用开发框架。支持热重启、热更新、多域名、多端口、多服务、HTTP/HTTPS、动态路由等特性
并提供了Web服务开发的系列核心组件Router、Cookie、Session、服务注册、配置管理、模板引擎、数据校验、分页管理、数据库ORM等等等等
并且提供了数十个内置核心开发模块集缓存、日志、时间、命令行、二进制、文件锁、内存锁、对象池、连接池、数据编码、进程管理、进程通信、文件监控、定时任务、TCP/UDP组件、
并发安全容器等等等等等等。
# 特点
* 模块化、松耦合设计;
* 丰富实用的开发模块;
* 详尽的开发文档及示例;
* 完善的本地中文化支持;
* 致力于项目的通用方案;
* 更适合企业及团队使用;
* 更多请查阅文档及源码;
# 安装
```html
go get -u gitee.com/johng/gf
```
或者
`go.mod`
```
require gitee.com/johng/gf latest
```
# 限制
```shell
golang版本 >= 1.9.2
@ -24,6 +49,7 @@ golang版本 >= 1.9.2
</div>
# 文档
[https://gfer.me](https://gfer.me)
@ -44,4 +70,35 @@ func main() {
})
s.Run()
}
```
```
[更多..](https://gfer.me/start/index)
# 协议
`GF` 使用非常友好的 [MIT](LICENSE) 开源协议进行发布,永久`100%`开源免费。
# 贡献者(TOP 10)
<a href="https://gitee.com/johng" target="_blank" title="John"><img src="https://gitee.com/uploads/27/1309327_johng.png" width="60" align="left"></a>
<a href="https://gitee.com/wenzi1" target="_blank" title="蚊子"><img src="https://images.gitee.com/uploads/22/1923122_wenzi1.png" width="60" align="left"></a>
<a href="https://gitee.com/zseeker" target="_blank" title="zseeker"><img src="https://gfer.me/images/contributors/zseeker.png" width="60" align="left"></a>
<a href="https://gitee.com/ymrjqyy" target="_blank" title="一墨染尽青衣颜"><img src="https://images.gitee.com/uploads/27/876827_ymrjqyy.png" width="60" align="left"></a>
<a href="https://github.com/chenyang351" target="_blank" title="chenyang351"><img src="https://avatars1.githubusercontent.com/u/30063958?s=60&v=4" width="60" align="left"></a>
<a href="https://gitee.com/wxkj" target="_blank" title="wxkj"><img src="https://gitee.com/uploads/56/91356_wxkj.png" width="60" align="left"></a>
<a href="https://github.com/wxkj001" target="_blank" title="3wxkj001
"><img src="https://avatars0.githubusercontent.com/u/7794279?s=60&v=4" width="60" align="left"></a>
<a href="https://gitee.com/zhangjinfu" target="_blank" title="张金富"><img src="https://images.gitee.com/uploads/63/356163_zhangjinfu.png" width="60" align="left"></a>
<a href="https://gitee.com/garfieldkwong" target="_blank" title="GarfieldKwong"><img src="https://gfer.me/images/contributors/garfieldkwong.png" width="60" align="left"></a>
<a href="https://gitee.com/qq1054000800" target="_blank" title="hello"><img src="https://gitee.com/uploads/9/2209_qq1054000800.jpg" width="60" align="left"></a>
<br /><br /><br />
# 捐赠者
<a href="https://gitee.com/zfan_codes" target="_blank" title="范钟"><img src="https://images.gitee.com/uploads/32/2044832_zfan_codes.png" width="60" align="left"></a>
<a href="https://gitee.com/hailaz" target="_blank" title="HaiLaz"><img src="https://gitee.com/uploads/87/1273187_hailaz.png" width="60" align="left"></a>
<a href="https://gitee.com/mg91" target="_blank" title="mg91"><img src="https://images.gitee.com/uploads/30/1410930_mg91.png" width="60" align="left"></a>

View File

@ -1,3 +1,40 @@
# `v1.3.8` (2018-12-26)
## 新特性
1. 对`gform`完成重构,以提高扩展性,并修复部分细节问题、完善单元测试用例([https://gfer.me/database/orm/index](https://gfer.me/database/orm/index))
1. `WebServer`路由注册新增分组路由特性([https://gfer.me/net/ghttp/group](https://gfer.me/net/ghttp/group));
1. `WebServer`新增`Rewrite`路由重写特性([https://gfer.me/net/ghttp/static](https://gfer.me/net/ghttp/static));
1. 增加框架运行时对开发环境的自动识别;
1. 增加了`Travis CI`自动化构建/测试;
## 新功能
1. 改进`WebServer`静态文件服务功能,增加`SetStaticPath`/`AddStaticPath`方法([https://gfer.me/net/ghttp/static](https://gfer.me/net/ghttp/static))
1. `gform`新增`Filter`链式操作方法,用于过滤参数中的非表字段键值对([https://gfer.me/database/orm/linkop](https://gfer.me/database/orm/linkop)
1. `gcache`新增`Data`方法,用以获取所有的缓存数据项;
1. `gredis`增加`GetConn`方法获取原生redis连接对象
## 功能改进
1. 改进`gform`的`Where`方法,支持`slice`类型的参数,并更方便地支持`in`操作查询([https://gfer.me/database/orm/linkop](https://gfer.me/database/orm/linkop)
1. 改进`gproc`进程间通信数据结构,将`pid`字段从`16bit`扩展为`24bit`
1. 改进`gconv`/`gmap`/`garray`,增加若干操作方法;
1. 改进`gview`模板引擎中的`date`内置函数,当给定的时间戳为空时打印当前的系统时间;
1. 改进`gview`模板引擎中,当打印的变量不存在时,显示为空(标准库默认显示为`<no value>`
1. 改进`WebServer`,去掉`HANGUP`的信号监听,避免程序通过`nohup`运行时产生异常退出问题;
1. 改进`gcache`性能,并完善基准测试;
## Bug Fix
1. 修复`gcache`在非LRU特性开启时的缓存关闭资源竞争问题并修复`doSetWithLockCheck`内部方法的返回值问题;
1. 修复`grand.intn`内部方法在`x86`架构下的随机数位溢出问题;
1. 修复`gbinary`中`Int`方法针对`[]byte`参数长度自动匹配造成的字节长度溢出问题;
1. 修复`gjson`由于官方标准库`json`不支持`map[interface{}]*`类型造成的Go变量编码问题
1. 修复`garray`中部分方法的数据竞争问题,修复二分查找排序问题;
1. 修复`ghttp.Request.GetVar`方法获取参数问题;
1. 修复`gform`的数据库连接池不起作用的问题;
# `v1.2.11` (2018-11-26)
## 新特性
1. `ORM`新增对`SQLServer`及`Oracle`的支持([https://gfer.me/database/orm/database](https://gfer.me/database/orm/database))

View File

@ -42,6 +42,14 @@
1. gtcp提供简便的包发送/接收方法(SendPkg/RecvPkg)以解决常见的TCP通信粘包问题并完善文档参考https://www.cnblogs.com/kex1n/p/6502002.html
1. gfile对于文件的读写强行使用了gfpool在某些场景下不合适需要考虑剥离开并为开发者提供单独的指针池文件操作特性
1. 路由增加不区分大小写得匹配方式;
1. str_ireplace: http://php.net/manual/en/function.str-ireplace.php
1. strpos/stripos/strrpos/strripos: http://php.net/manual/en/function.stripos.php
1. 改进WebServer获取POST参数处理逻辑当提交非form数据时例如json数据针对某些方法可以直接解析
1. WebServer增加可选择的路由覆盖配置默认情况下不覆盖
1. gkafka这个包比较重未来从框架中剥离出来

View File

@ -153,16 +153,14 @@ func (a *SortedIntArray) binSearch(value int, lock bool) (index int, result int)
max := len(a.array) - 1
mid := 0
cmp := -2
for {
for min <= max {
mid = int((min + max) / 2)
cmp = a.compareFunc(value, a.array[mid])
switch cmp {
case -1 : max = mid - 1
case 0 :
case 1 : min = mid + 1
}
if cmp == 0 || min >= max {
break
case 0 :
return mid, cmp
}
}
return mid, cmp

View File

@ -146,16 +146,14 @@ func (a *SortedArray) binSearch(value interface{}, lock bool)(index int, result
max := len(a.array) - 1
mid := 0
cmp := -2
for {
for min <= max {
mid = int((min + max) / 2)
cmp = a.compareFunc(value, a.array[mid])
switch cmp {
case -1 : max = mid - 1
case 0 :
case 1 : min = mid + 1
}
if cmp == 0 || min >= max {
break
case 0 :
return mid, cmp
}
}
return mid, cmp

View File

@ -147,16 +147,14 @@ func (a *SortedStringArray) binSearch(value string, lock bool) (index int, resul
max := len(a.array) - 1
mid := 0
cmp := -2
for {
for min <= max {
mid = int((min + max) / 2)
cmp = a.compareFunc(value, a.array[mid])
switch cmp {
case -1 : max = mid - 1
case 0 :
case 1 : min = mid + 1
}
if cmp == 0 || min >= max {
break
case 0 :
return mid, cmp
}
}
return mid, cmp

View File

@ -9,18 +9,76 @@
package garray_test
import (
"fmt"
"gitee.com/johng/gf/g/container/garray"
"gitee.com/johng/gf/g/util/gconv"
"gitee.com/johng/gf/g/util/gtest"
"strings"
"testing"
)
func TestArray_Unique(t *testing.T) {
func Test_IntArray_Unique(t *testing.T) {
expect := []int{1, 2, 3, 4, 5, 6}
array := garray.NewIntArray(0, 0)
array.Append(1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6)
array.Unique()
if fmt.Sprint(array.Slice()) != fmt.Sprint(expect) {
t.Errorf("get: %v, expect: %v\n", array.Slice(), expect)
}
gtest.Assert(array.Slice(), expect)
}
func Test_SortedIntArray1(t *testing.T) {
expect := []int{0,1,2,3,4,5,6,7,8,9,10}
array := garray.NewSortedIntArray(0)
for i := 10; i > -1; i-- {
array.Add(i)
}
gtest.Assert(array.Slice(), expect)
}
func Test_SortedIntArray2(t *testing.T) {
expect := []int{0,1,2,3,4,5,6,7,8,9,10}
array := garray.NewSortedIntArray(0)
for i := 0; i <= 10; i++ {
array.Add(i)
}
gtest.Assert(array.Slice(), expect)
}
func Test_SortedStringArray1(t *testing.T) {
expect := []string{"0","1","10","2","3","4","5","6","7","8","9"}
array := garray.NewSortedStringArray(0)
for i := 10; i > -1; i-- {
array.Add(gconv.String(i))
}
gtest.Assert(array.Slice(), expect)
}
func Test_SortedStringArray2(t *testing.T) {
expect := []string{"0","1","10","2","3","4","5","6","7","8","9"}
array := garray.NewSortedStringArray(0)
for i := 0; i <= 10; i++ {
array.Add(gconv.String(i))
}
gtest.Assert(array.Slice(), expect)
}
func Test_SortedArray1(t *testing.T) {
expect := []string{"0","1","10","2","3","4","5","6","7","8","9"}
array := garray.NewSortedArray(0, func(v1, v2 interface{}) int {
return strings.Compare(gconv.String(v1), gconv.String(v2))
})
for i := 10; i > -1; i-- {
array.Add(gconv.String(i))
}
gtest.Assert(array.Slice(), expect)
}
func Test_SortedArray2(t *testing.T) {
expect := []string{"0","1","10","2","3","4","5","6","7","8","9"}
array := garray.NewSortedArray(0, func(v1, v2 interface{}) int {
return strings.Compare(gconv.String(v1), gconv.String(v2))
})
for i := 0; i <= 10; i++ {
array.Add(gconv.String(i))
}
gtest.Assert(array.Slice(), expect)
}

View File

@ -10,6 +10,7 @@ package gvar
import (
"gitee.com/johng/gf/g/container/gtype"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gconv"
"time"
)
@ -92,6 +93,8 @@ func (v *Var) Interfaces() []interface{} { return gconv.Interfaces(v.Val()
func (v *Var) Time(format...string) time.Time { return gconv.Time(v.Val(), format...) }
func (v *Var) TimeDuration() time.Duration { return gconv.TimeDuration(v.Val()) }
func (v *Var) GTime(format...string) *gtime.Time { return gconv.GTime(v.Val(), format...) }
// 将变量转换为对象,注意 objPointer 参数必须为struct指针
func (v *Var) Struct(objPointer interface{}, attrMapping...map[string]string) error {
return gconv.Struct(v.Val(), objPointer, attrMapping...)

View File

@ -6,7 +6,10 @@
package gvar
import "time"
import (
"gitee.com/johng/gf/g/os/gtime"
"time"
)
// 只读变量接口
type VarRead interface {
@ -34,5 +37,6 @@ type VarRead interface {
Interfaces() []interface{}
Time(format ...string) time.Time
TimeDuration() time.Duration
GTime(format...string) *gtime.Time
Struct(objPointer interface{}, attrMapping ...map[string]string) error
}

View File

@ -1,11 +1,11 @@
package gdes
package gdes_test
import (
"testing"
"bytes"
"encoding/hex"
"fmt"
"gitee.com/johng/gf/g/encoding/gdes"
"gitee.com/johng/gf/g/crypto/gdes"
)
func TestDesECB(t *testing.T){

View File

@ -35,6 +35,7 @@ func EncryptFile(path string) string {
if e != nil {
return ""
}
defer f.Close()
h := md5.New()
_, e = io.Copy(h, f)
if e != nil {

View File

@ -4,7 +4,7 @@
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://gitee.com/johng/gf.
// 数据库ORM.
// Package gdb provides ORM features for popular relationship databases/数据库ORM.
// 默认内置支持MySQL, 其他数据库需要手动import对应的数据库引擎第三方包.
package gdb
@ -12,7 +12,6 @@ import (
"database/sql"
"errors"
"fmt"
"gitee.com/johng/gf/g/container/gmap"
"gitee.com/johng/gf/g/container/gring"
"gitee.com/johng/gf/g/container/gtype"
"gitee.com/johng/gf/g/container/gvar"
@ -22,39 +21,42 @@ import (
"time"
)
const (
OPTION_INSERT = 0
OPTION_REPLACE = 1
OPTION_SAVE = 2
OPTION_IGNORE = 3
)
// 数据库操作接口
type Link interface {
// 打开数据库连接,建立数据库操作对象
Open(c *ConfigNode) (*sql.DB, error)
type DB interface {
// 建立数据库连接方法(开发者一般不需要直接调用)
Open(config *ConfigNode) (*sql.DB, error)
// SQL操作方法
Query(q string, args ...interface{}) (*sql.Rows, error)
Exec(q string, args ...interface{}) (sql.Result, error)
Prepare(q string) (*sql.Stmt, error)
// SQL操作方法 API
Query(query string, args ...interface{}) (*sql.Rows, error)
Exec(sql string, args ...interface{}) (sql.Result, error)
Prepare(sql string, execOnMaster...bool) (*sql.Stmt, error)
// 内部实现API的方法(不同数据库可覆盖这些方法实现自定义的操作)
doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error)
doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error)
doPrepare(link dbLink, query string) (*sql.Stmt, error)
doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error)
doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error)
doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error)
doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error)
// 数据库查询
GetAll(q string, args ...interface{}) (Result, error)
GetOne(q string, args ...interface{}) (Record, error)
GetValue(q string, args ...interface{}) (Value, error)
GetAll(query string, args ...interface{}) (Result, error)
GetOne(query string, args ...interface{}) (Record, error)
GetValue(query string, args ...interface{}) (Value, error)
GetCount(query string, args ...interface{}) (int, error)
GetStruct(obj interface{}, query string, args ...interface{}) error
// Ping
// 创建底层数据库master/slave链接对象
Master() (*sql.DB, error)
Slave() (*sql.DB, error)
// Ping
PingMaster() error
PingSlave() error
// 连接属性设置
SetMaxIdleConns(n int)
SetMaxOpenConns(n int)
SetConnMaxLifetime(n int)
// 开启事务操作
Begin() (*Tx, error)
Begin() (*TX, error)
// 数据表插入/更新/保存操作
Insert(table string, data Map) (sql.Result, error)
@ -74,25 +76,40 @@ type Link interface {
Table(tables string) *Model
From(tables string) *Model
// 内部方法
insert(table string, data Map, option uint8) (sql.Result, error)
batchInsert(table string, list List, batch int, option uint8) (sql.Result, error)
// 设置管理
SetDebug(debug bool)
SetSchema(schema string)
GetQueriedSqls() []*Sql
PrintQueriedSqls()
SetMaxIdleConns(n int)
SetMaxOpenConns(n int)
SetConnMaxLifetime(n int)
getQuoteCharLeft() string
getQuoteCharRight() string
handleSqlBeforeExec(q *string) *string
// 内部方法接口
getCache() (*gcache.Cache)
getChars() (charLeft string, charRight string)
getDebug() bool
filterFields(table string, data map[string]interface{}) map[string]interface{}
getTableFields(table string) (map[string]string, error)
handleSqlBeforeExec(sql string) string
}
// 执行底层数据库操作的核心接口
type dbLink interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
Exec(sql string, args ...interface{}) (sql.Result, error)
Prepare(sql string) (*sql.Stmt, error)
}
// 数据库链接对象
type Db struct {
link Link // 底层数据库类型管理对象
type dbBase struct {
db DB // 数据库对象
group string // 配置分组名称
charl string // SQL安全符号(左)
charr string // SQL安全符号(右)
debug *gtype.Bool // (默认关闭)是否开启调试模式,当开启时会启用一些调试特性
sqls *gring.Ring // (debug=true时有效)已执行的SQL列表
cache *gcache.Cache // 数据库缓存,包括底层连接池对象缓存及查询缓存;需要注意的是,事务查询不支持查询缓存
maxIdleConnCount *gtype.Int // 连接池最大限制的连接数
schema *gtype.String // 手动切换的数据库名称
maxIdleConnCount *gtype.Int // 连接池最大限制的连接数
maxOpenConnCount *gtype.Int // 连接池最大打开的连接数
maxConnLifetime *gtype.Int // (单位秒)连接对象可重复使用的时间长度
}
@ -104,7 +121,7 @@ type Sql struct {
Error error // 执行结果(nil为成功)
Start int64 // 执行开始时间(毫秒)
End int64 // 执行结束时间(毫秒)
Func string // 执行方法名称
Func string // 执行方法
}
// 返回数据表记录值
@ -117,27 +134,22 @@ type Record map[string]Value
type Result []Record
// 关联数组,绑定一条数据表记录(使用别名)
type Map = map[string]interface{}
type Map = map[string]interface{}
// 关联数组列表(索引从0开始的数组),绑定多条记录(使用别名)
type List = []Map
var (
// 支持的数据库类型map
driverMap = make(map[string]interface{})
// 数据库查询缓存对象map使用数据库连接名称作为键名键值为查询缓存对象
dbCaches = gmap.NewStringInterfaceMap()
const (
OPTION_INSERT = 0
OPTION_REPLACE = 1
OPTION_SAVE = 2
OPTION_IGNORE = 3
// 默认的连接池连接存活时间(秒)
gDEFAULT_CONN_MAX_LIFE_TIME = 30
)
func init() {
driverMap["mysql"] = linkMysql
driverMap["oracle"] = linkOracle
driverMap["sqlite"] = linkSqlite
driverMap["pgsql"] = linkPgsql
driverMap["mssql"] = linkMssql
}
// 使用默认/指定分组配置进行连接数据库集群配置项default
func New(groupName ...string) (*Db, error) {
func New(groupName ...string) (db DB, err error) {
group := config.d
if len(groupName) > 0 {
group = groupName[0]
@ -150,24 +162,30 @@ func New(groupName ...string) (*Db, error) {
}
if _, ok := config.c[group]; ok {
if node, err := getConfigNodeByGroup(group, true); err == nil {
link, err := getLinkByType(node.Type)
if err != nil {
return nil, err
}
db := &Db {
link : link,
base := &dbBase {
group : group,
charl : link.getQuoteCharLeft(),
charr : link.getQuoteCharRight(),
debug : gtype.NewBool(),
cache : gcache.New(),
schema : gtype.NewString(),
maxIdleConnCount : gtype.NewInt(),
maxOpenConnCount : gtype.NewInt(),
maxConnLifetime : gtype.NewInt(),
maxConnLifetime : gtype.NewInt(gDEFAULT_CONN_MAX_LIFE_TIME),
}
db.cache = dbCaches.GetOrSetFuncLock(group, func() interface{} {
return gcache.New()
}).(*gcache.Cache)
return db, nil
switch node.Type {
case "mysql":
base.db = &dbMysql{dbBase : base}
case "pgsql":
base.db = &dbPgsql{dbBase : base}
case "mssql":
base.db = &dbMssql{dbBase : base}
case "sqlite":
base.db = &dbSqlite{dbBase : base}
case "oracle":
base.db = &dbOracle{dbBase : base}
default:
return nil, errors.New(fmt.Sprintf(`unsupported database type "%s"`, node.Type))
}
return base.db, nil
} else {
return nil, err
}
@ -219,6 +237,13 @@ func getConfigNodeByPriority(cg ConfigGroup) *ConfigNode {
for i := 0; i < len(cg); i++ {
total += cg[i].Priority * 100
}
// 如果total为0表示所有连接都没有配置priority属性那么默认都是1
if total == 0 {
for i := 0; i < len(cg); i++ {
cg[i].Priority = 1
total += cg[i].Priority * 100
}
}
// 不能取到末尾的边界点
r := grand.Rand(0, total)
if r > 0 {
@ -238,51 +263,36 @@ func getConfigNodeByPriority(cg ConfigGroup) *ConfigNode {
return nil
}
// 根据配置的数据库类型获得Link接口对象
func getLinkByType(dbType string) (Link, error) {
if dblink, ok := driverMap[dbType]; ok == false {
return nil, errors.New(fmt.Sprintf("unsupported db type '%s'", dbType))
} else {
return dblink.(Link), nil
}
}
// 获得底层数据库链接对象
func (db *Db) getSqlDb(master bool) (sqlDb *sql.DB, err error) {
func (bs *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) {
// 负载均衡
node, err := getConfigNodeByGroup(db.group, master)
node, err := getConfigNodeByGroup(bs.group, master)
if err != nil {
return nil, err
}
// 类型对象
link, err := getLinkByType(node.Type)
if err != nil {
return nil, err
// 默认值设定
if node.Charset == "" {
node.Charset = "utf8"
}
// 检查缓存连接池对象
cacheKey := node.String()
if v := db.cache.Get(cacheKey); v != nil {
return v.(*sql.DB), nil
}
v := db.cache.GetOrSetFuncLock(node.String(), func() interface{} {
sqlDb, err = link.Open(node)
v := bs.cache.GetOrSetFuncLock(node.String(), func() interface{} {
sqlDb, err = bs.db.Open(node)
if err != nil {
return nil
}
if n := db.maxIdleConnCount.Val(); n > 0 {
if n := bs.maxIdleConnCount.Val(); n > 0 {
sqlDb.SetMaxIdleConns(n)
} else if node.MaxIdleConnCount > 0 {
sqlDb.SetMaxIdleConns(node.MaxIdleConnCount)
}
if n := db.maxOpenConnCount.Val(); n > 0 {
if n := bs.maxOpenConnCount.Val(); n > 0 {
sqlDb.SetMaxOpenConns(n)
} else if node.MaxOpenConnCount > 0 {
sqlDb.SetMaxOpenConns(node.MaxOpenConnCount)
}
if n := db.maxConnLifetime.Val(); n > 0 {
if n := bs.maxConnLifetime.Val(); n > 0 {
sqlDb.SetConnMaxLifetime(time.Duration(n) * time.Second)
} else if node.MaxConnLifetime > 0 {
sqlDb.SetConnMaxLifetime(time.Duration(node.MaxConnLifetime) * time.Second)
@ -292,15 +302,24 @@ func (db *Db) getSqlDb(master bool) (sqlDb *sql.DB, err error) {
if v != nil && sqlDb == nil {
sqlDb = v.(*sql.DB)
}
// 是否手动选择数据库
if v := bs.schema.Val(); v != "" {
sqlDb.Exec("USE " + v)
}
return
}
// 切换操作的数据库(注意该切换是全局的)
func (bs *dbBase) SetSchema(schema string) {
bs.schema.Set(schema)
}
// 创建底层数据库master链接对象
func (db *Db) Master() (*sql.DB, error) {
return db.getSqlDb(true)
func (bs *dbBase) Master() (*sql.DB, error) {
return bs.getSqlDb(true)
}
// 创建底层数据库slave链接对象
func (db *Db) Slave() (*sql.DB, error) {
return db.getSqlDb(false)
func (bs *dbBase) Slave() (*sql.DB, error) {
return bs.getSqlDb(false)
}

View File

@ -8,39 +8,29 @@
package gdb
import (
"fmt"
"errors"
"strings"
"reflect"
"database/sql"
"gitee.com/johng/gf/g/util/gstr"
"gitee.com/johng/gf/g/util/gconv"
"gitee.com/johng/gf/g/container/gring"
"errors"
"fmt"
"gitee.com/johng/gf/g/os/gcache"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/os/glog"
"gitee.com/johng/gf/g/container/gvar"
"gitee.com/johng/gf/g/util/gconv"
"gitee.com/johng/gf/g/util/gregex"
"reflect"
"strings"
)
const (
gDEFAULT_DEBUG_SQL_LENGTH = 1000 // 默认调试模式下记录的SQL条数
)
// 是否开启调试服务
func (db *Db) SetDebug(debug bool) {
db.debug.Set(debug)
if debug && db.sqls == nil {
db.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH)
}
}
// 获取已经执行的SQL列表(仅在debug=true时有效)
func (db *Db) GetQueriedSqls() []*Sql {
if db.sqls == nil {
func (bs *dbBase) GetQueriedSqls() []*Sql {
if bs.sqls == nil {
return nil
}
sqls := make([]*Sql, 0)
db.sqls.Prev()
db.sqls.RLockIteratorPrev(func(value interface{}) bool {
bs.sqls.Prev()
bs.sqls.RLockIteratorPrev(func(value interface{}) bool {
if value == nil {
return false
}
@ -51,8 +41,8 @@ func (db *Db) GetQueriedSqls() []*Sql {
}
// 打印已经执行的SQL列表(仅在debug=true时有效)
func (db *Db) PrintQueriedSqls() {
sqls := db.GetQueriedSqls()
func (bs *dbBase) PrintQueriedSqls() {
sqls := bs.GetQueriedSqls()
for k, v := range sqls {
fmt.Println(len(sqls) - k, ":")
fmt.Println(" Sql :", v.Sql)
@ -61,143 +51,110 @@ func (db *Db) PrintQueriedSqls() {
fmt.Println(" Start:", gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u"))
fmt.Println(" End :", gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u"))
fmt.Println(" Cost :", v.End - v.Start, "ms")
fmt.Println(" Func :", v.Func)
}
}
// 打印SQL对象(仅在debug=true时有效)
func (db *Db) printSql(v *Sql) {
s := fmt.Sprintf("%s, %v, %s, %s, %d ms, %s", v.Sql, v.Args,
gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u"),
gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u"),
v.End - v.Start, v.Func,
)
if v.Error != nil {
s += "\nError: " + v.Error.Error()
glog.Backtrace(true, 2).Error(s)
} else {
glog.Debug(s)
}
}
// 数据库sql查询操作主要执行查询
func (db *Db) Query(query string, args ...interface{}) (*sql.Rows, error) {
var err error
var rows *sql.Rows
var slave *sql.DB
slave, err = db.Slave();
func (bs *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
link, err := bs.db.Slave()
if err != nil {
return nil,err
}
p := db.link.handleSqlBeforeExec(&query)
if db.debug.Val() {
militime1 := gtime.Millisecond()
rows, err = slave.Query(*p, args ...)
militime2 := gtime.Millisecond()
s := &Sql{
Sql : *p,
return bs.db.doQuery(link, query, args...)
}
// 数据库sql查询操作主要执行查询
func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) {
query = bs.db.handleSqlBeforeExec(query)
if bs.db.getDebug() {
mTime1 := gtime.Millisecond()
rows, err = link.Query(query, args...)
mTime2 := gtime.Millisecond()
s := &Sql {
Sql : query,
Args : args,
Error : err,
Start : militime1,
End : militime2,
Func : "DB:Query",
Start : mTime1,
End : mTime2,
}
db.sqls.Put(s)
db.printSql(s)
bs.sqls.Put(s)
printSql(s)
} else {
rows, err = slave.Query(*p, args ...)
rows, err = link.Query(query, args ...)
}
if err == nil {
return rows, nil
} else {
err = db.formatError(err, p, args...)
err = formatError(err, query, args...)
}
return nil, err
}
// 执行一条sql并返回执行情况主要用于非查询操作
func (db *Db) Exec(query string, args ...interface{}) (sql.Result, error) {
var err error
var result sql.Result
var master *sql.DB
master, err = db.Master();
func (bs *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) {
link, err := bs.db.Master()
if err != nil {
return nil,err
}
p := db.link.handleSqlBeforeExec(&query)
if db.debug.Val() {
militime1 := gtime.Millisecond()
result, err = master.Exec(*p, args ...)
militime2 := gtime.Millisecond()
s := &Sql{
Sql : *p,
return bs.db.doExec(link, query, args...)
}
// 执行一条sql并返回执行情况主要用于非查询操作
func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) {
query = bs.db.handleSqlBeforeExec(query)
if bs.db.getDebug() {
mTime1 := gtime.Millisecond()
result, err = link.Exec(query, args ...)
mTime2 := gtime.Millisecond()
s := &Sql{
Sql : query,
Args : args,
Error : err,
Start : militime1,
End : militime2,
Func : "DB:Exec",
Start : mTime1,
End : mTime2,
}
db.sqls.Put(s)
db.printSql(s)
bs.sqls.Put(s)
printSql(s)
} else {
result, err = master.Exec(*p, args ...)
result, err = link.Exec(query, args ...)
}
return result, db.formatError(err, p, args...)
return result, formatError(err, query, args...)
}
// 格式化错误信息
func (db *Db) formatError(err error, query *string, args ...interface{}) error {
if err != nil {
errstr := fmt.Sprintf("DB ERROR: %s\n", err.Error())
errstr += fmt.Sprintf("DB QUERY: %s\n", *query)
if len(args) > 0 {
errstr += fmt.Sprintf("DB PARAM: %v\n", args)
// SQL预处理执行完成后调用返回值sql.Stmt.Exec完成sql操作; 默认执行在Slave上, 通过第二个参数指定执行在Master上
func (bs *dbBase) Prepare(query string, execOnMaster...bool) (*sql.Stmt, error) {
err := (error)(nil)
link := (dbLink)(nil)
if len(execOnMaster) > 0 && execOnMaster[0] {
if link, err = bs.db.Master(); err != nil {
return nil, err
}
} else {
if link, err = bs.db.Slave(); err != nil {
return nil, err
}
err = errors.New(errstr)
}
return err
return bs.db.doPrepare(link, query)
}
// SQL预处理执行完成后调用返回值sql.Stmt.Exec完成sql操作
func (bs *dbBase) doPrepare(link dbLink, query string) (*sql.Stmt, error) {
return link.Prepare(query)
}
// 数据库查询,获取查询结果集,以列表结构返回
func (db *Db) GetAll(query string, args ...interface{}) (Result, error) {
// 执行sql
rows, err := db.Query(query, args ...)
func (bs *dbBase) GetAll(query string, args ...interface{}) (Result, error) {
rows, err := bs.Query(query, args ...)
if err != nil || rows == nil {
return nil, err
}
// 列名称列表
columns, err := rows.Columns()
if err != nil {
return nil, err
}
// 返回结构组装
values := make([]sql.RawBytes, len(columns))
scanArgs := make([]interface{}, len(values))
records := make(Result, 0)
for i := range values {
scanArgs[i] = &values[i]
}
for rows.Next() {
err = rows.Scan(scanArgs...)
if err != nil {
return records, err
}
row := make(Record)
// 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址
for i, col := range values {
v := make([]byte, len(col))
copy(v, col)
row[columns[i]] = gvar.New(v, false)
}
records = append(records, row)
}
return records, nil
defer rows.Close()
return rowsToResult(rows)
}
// 数据库查询,获取查询结果记录,以关联数组结构返回
func (db *Db) GetOne(query string, args ...interface{}) (Record, error) {
list, err := db.GetAll(query, args ...)
func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
list, err := bs.GetAll(query, args ...)
if err != nil {
return nil, err
}
@ -208,18 +165,17 @@ func (db *Db) GetOne(query string, args ...interface{}) (Record, error) {
}
// 数据库查询获取查询结果记录自动映射数据到给定的struct对象中
func (db *Db) GetStruct(obj interface{}, query string, args ...interface{}) error {
one, err := db.GetOne(query, args...)
func (bs *dbBase) GetStruct(obj interface{}, query string, args ...interface{}) error {
one, err := bs.GetOne(query, args...)
if err != nil {
return err
}
return one.ToStruct(obj)
}
// 数据库查询,获取查询字段值
func (db *Db) GetValue(query string, args ...interface{}) (Value, error) {
one, err := db.GetOne(query, args ...)
func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
one, err := bs.GetOne(query, args ...)
if err != nil {
return nil, err
}
@ -230,44 +186,20 @@ func (db *Db) GetValue(query string, args ...interface{}) (Value, error) {
}
// 数据库查询,获取查询数量
func (db *Db) GetCount(query string, args ...interface{}) (int, error) {
val, err := db.GetValue(query, args ...)
func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) {
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
}
value, err := bs.GetValue(query, args ...)
if err != nil {
return 0, err
}
return gconv.Int(val), nil
}
// 数据表查询其中tables可以是多个联表查询语句这种查询方式较复杂建议使用链式操作
func (db *Db) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) {
s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables)
if condition != nil {
s += fmt.Sprintf("WHERE %s ", db.formatCondition(condition))
}
if len(groupBy) > 0 {
s += fmt.Sprintf("GROUP BY %s ", groupBy)
}
if len(orderBy) > 0 {
s += fmt.Sprintf("ORDER BY %s ", orderBy)
}
if limit > 0 {
s += fmt.Sprintf("LIMIT %d,%d ", first, limit)
}
return db.GetAll(s, args ... )
}
// sql预处理执行完成后调用返回值sql.Stmt.Exec完成sql操作
func (db *Db) Prepare(query string) (*sql.Stmt, error) {
if master, err := db.Master(); err != nil {
return nil, err
} else {
return master.Prepare(query)
}
return value.Int(), nil
}
// ping一下判断或保持数据库链接(master)
func (db *Db) PingMaster() error {
if master, err := db.Master(); err != nil {
func (bs *dbBase) PingMaster() error {
if master, err := bs.db.Master(); err != nil {
return err
} else {
return master.Ping()
@ -275,8 +207,8 @@ func (db *Db) PingMaster() error {
}
// ping一下判断或保持数据库链接(slave)
func (db *Db) PingSlave() error {
if slave, err := db.Slave(); err != nil {
func (bs *dbBase) PingSlave() error {
if slave, err := bs.db.Slave(); err != nil {
return err
} else {
return slave.Ping()
@ -285,13 +217,13 @@ func (db *Db) PingSlave() error {
// 事务操作,开启,会返回一个底层的事务操作对象链接如需要嵌套事务,那么可以使用该对象,否则请忽略
// 只有在tx.Commit/tx.Rollback时链接会自动Close
func (db *Db) Begin() (*Tx, error) {
if master, err := db.Master(); err != nil {
func (bs *dbBase) Begin() (*TX, error) {
if master, err := bs.db.Master(); err != nil {
return nil, err
} else {
if tx, err := master.Begin(); err == nil {
return &Tx {
db : db,
return &TX {
db : bs.db,
tx : tx,
master : master,
}, nil
@ -301,17 +233,19 @@ func (db *Db) Begin() (*Tx, error) {
}
}
// 根据insert选项获得操作名称
func (db *Db) getInsertOperationByOption(option uint8) string {
oper := "INSERT"
switch option {
case OPTION_REPLACE:
oper = "REPLACE"
case OPTION_SAVE:
case OPTION_IGNORE:
oper = "INSERT IGNORE"
}
return oper
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
func (bs *dbBase) Insert(table string, data Map) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, OPTION_INSERT)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (bs *dbBase) Replace(table string, data Map) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, OPTION_REPLACE)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (bs *dbBase) Save(table string, data Map) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, OPTION_SAVE)
}
// insert、replace, save ignore操作
@ -319,95 +253,102 @@ func (db *Db) getInsertOperationByOption(option uint8) string {
// 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
// 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
// 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做
func (db *Db) insert(table string, data Map, option uint8) (sql.Result, error) {
func (bs *dbBase) doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error) {
var fields []string
var values []string
var params []interface{}
charl, charr := bs.db.getChars()
for k, v := range data {
fields = append(fields, db.charl + k + db.charr)
fields = append(fields, charl + k + charr)
values = append(values, "?")
params = append(params, v)
}
operation := db.getInsertOperationByOption(option)
operation := getInsertOperationByOption(option)
updatestr := ""
if option == OPTION_SAVE {
var updates []string
for k, _ := range data {
updates = append(updates,
fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
db.charl, k, db.charr,
db.charl, k, db.charr,
charl, k, charr,
charl, k, charr,
),
)
}
updatestr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
}
return db.Exec(
fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
operation, table, strings.Join(fields, ","),
strings.Join(values, ","),
updatestr),
params...
)
if link == nil {
if link, err = bs.db.Master(); err != nil {
return nil, err
}
}
return bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
operation, table, strings.Join(fields, ","),
strings.Join(values, ","), updatestr),
params...)
}
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
func (db *Db) Insert(table string, data Map) (sql.Result, error) {
return db.insert(table, data, OPTION_INSERT)
// CURD操作:批量数据指定批次量写入
func (bs *dbBase) BatchInsert(table string, list List, batch int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_INSERT)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (db *Db) Replace(table string, data Map) (sql.Result, error) {
return db.insert(table, data, OPTION_REPLACE)
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (bs *dbBase) BatchReplace(table string, list List, batch int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_REPLACE)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (db *Db) Save(table string, data Map) (sql.Result, error) {
return db.insert(table, data, OPTION_SAVE)
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (bs *dbBase) BatchSave(table string, list List, batch int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_SAVE)
}
// 批量写入数据
func (db *Db) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) {
func (bs *dbBase) doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error) {
var keys []string
var values []string
var bvalues []string
var params []interface{}
var result sql.Result
var size = len(list)
// 判断长度
if size < 1 {
if len(list) < 1 {
return result, errors.New("empty data list")
}
if link == nil {
if link, err = bs.db.Master(); err != nil {
return
}
}
// 首先获取字段名称及记录长度
for k, _ := range list[0] {
keys = append(keys, k)
values = append(values, "?")
}
keyStr := db.charl + strings.Join(keys, db.charl + "," + db.charr) + db.charr
charl, charr := bs.db.getChars()
keyStr := charl + strings.Join(keys, charl + "," + charr) + charr
valueHolderStr := "(" + strings.Join(values, ",") + ")"
// 操作判断
operation := db.getInsertOperationByOption(option)
operation := getInsertOperationByOption(option)
updatestr := ""
if option == OPTION_SAVE {
var updates []string
for _, k := range keys {
updates = append(updates,
fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
db.charl, k, db.charr,
db.charl, k, db.charr,
charl, k, charr,
charl, k, charr,
),
)
}
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
}
// 构造批量写入数据格式(注意map的遍历是无序的)
for i := 0; i < size; i++ {
for i := 0; i < len(list); i++ {
for _, k := range keys {
params = append(params, list[i][k])
}
bvalues = append(bvalues, valueHolderStr)
if len(bvalues) == batch {
r, err := db.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
operation, table, keyStr, strings.Join(bvalues, ","),
updatestr),
params...)
@ -421,7 +362,7 @@ func (db *Db) batchInsert(table string, list List, batch int, option uint8) (sql
}
// 处理最后不构成指定批量的数据
if len(bvalues) > 0 {
r, err := db.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
operation, table, keyStr, strings.Join(bvalues, ","),
updatestr),
params...)
@ -433,32 +374,28 @@ func (db *Db) batchInsert(table string, list List, batch int, option uint8) (sql
return result, nil
}
// CURD操作:批量数据指定批次量写入
func (db *Db) BatchInsert(table string, list List, batch int) (sql.Result, error) {
return db.batchInsert(table, list, batch, OPTION_INSERT)
}
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (db *Db) BatchReplace(table string, list List, batch int) (sql.Result, error) {
return db.batchInsert(table, list, batch, OPTION_REPLACE)
}
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (db *Db) BatchSave(table string, list List, batch int) (sql.Result, error) {
return db.batchInsert(table, list, batch, OPTION_SAVE)
// CURD操作:数据更新统一采用sql预处理
// data参数支持字符串或者关联数组类型内部会自行做判断处理
func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
link, err := bs.db.Master()
if err != nil {
return nil, err
}
return bs.db.doUpdate(link, table, data, condition, args ...)
}
// CURD操作:数据更新统一采用sql预处理
// data参数支持字符串或者关联数组类型内部会自行做判断处理
func (db *Db) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
var params []interface{}
var updates string
refValue := reflect.ValueOf(data)
func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error) {
params := ([]interface{})(nil)
updates := ""
charl, charr := bs.db.getChars()
refValue := reflect.ValueOf(data)
if refValue.Kind() == reflect.Map {
var fields []string
keys := refValue.MapKeys()
for _, k := range keys {
fields = append(fields, fmt.Sprintf("%s%s%s=?", db.charl, k, db.charr))
fields = append(fields, fmt.Sprintf("%s%s%s=?", charl, k, charr))
params = append(params, gconv.String(refValue.MapIndex(k).Interface()))
}
updates = strings.Join(fields, ",")
@ -468,34 +405,65 @@ func (db *Db) Update(table string, data interface{}, condition interface{}, args
for _, v := range args {
params = append(params, gconv.String(v))
}
return db.Exec(fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, db.formatCondition(condition)), params...)
if link == nil {
if link, err = bs.db.Master(); err != nil {
return nil, err
}
}
newWhere, newArgs := formatCondition(condition, params)
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, newWhere), newArgs...)
}
// CURD操作:删除数据
func (db *Db) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
return db.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, db.formatCondition(condition)), args...)
func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
link, err := bs.db.Master()
if err != nil {
return nil, err
}
return bs.db.doDelete(link, table, condition, args ...)
}
// 格式化SQL查询条件
func (db *Db) formatCondition(condition interface{}) (where string) {
if reflect.ValueOf(condition).Kind() == reflect.Map {
ks := reflect.ValueOf(condition).MapKeys()
vs := reflect.ValueOf(condition)
for _, k := range ks {
key := gconv.String(k.Interface())
value := gconv.String(vs.MapIndex(k).Interface())
isNum := gstr.IsNumeric(value)
if len(where) > 0 {
where += " AND "
}
if isNum || value == "?" {
where += key + "=" + value
} else {
where += key + "='" + value + "'"
// CURD操作:删除数据
func (bs *dbBase) doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
newWhere, newArgs := formatCondition(condition, args)
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, newWhere), newArgs...)
}
// 获得缓存对象
func (bs *dbBase) getCache() *gcache.Cache {
return bs.cache
}
// 将map的数据按照fields进行过滤只保留与表字段同名的数据
func (bs *dbBase) filterFields(table string, data map[string]interface{}) map[string]interface{} {
if fields, err := bs.db.getTableFields(table); err == nil {
for k, _ := range data {
if _, ok := fields[k]; !ok {
delete(data, k)
}
}
} else {
where += gconv.String(condition)
}
return data
}
// 获得指定表表的数据结构构造成map哈希表返回其中键名为表字段名称键值暂无用途(默认为字段数据类型).
func (bs *dbBase) getTableFields(table string) (fields map[string]string, err error) {
// 缓存不存在时会查询数据表结构,缓存后不过期,直至程序重启(重新部署)
v := bs.cache.GetOrSetFunc("table_fields_" + table, func() interface{} {
result := (Result)(nil)
charl, charr := bs.db.getChars()
result, err = bs.GetAll(fmt.Sprintf(`SHOW COLUMNS FROM %s%s%s`, charl, table, charr))
if err != nil {
return nil
}
fields = make(map[string]string)
for _, m := range result {
fields[m["Field"].String()] = m["Type"].String()
}
return fields
}, 0)
if err == nil {
fields = v.(map[string]string)
}
return
}

View File

@ -9,6 +9,7 @@ package gdb
import (
"fmt"
"gitee.com/johng/gf/g/container/gring"
"sync"
)
@ -114,6 +115,13 @@ func AddDefaultConfigGroup (nodes ConfigGroup) {
AddConfigGroup(DEFAULT_GROUP_NAME, nodes)
}
// 添加一台数据库服务器配置
func GetConfig (group string) ConfigGroup {
config.RLock()
defer config.RUnlock()
return config.c[group]
}
// 设置默认链接的数据库链接配置项(默认是 default)
func SetDefaultGroup (groupName string) {
config.Lock()
@ -122,19 +130,19 @@ func SetDefaultGroup (groupName string) {
}
// 设置数据库连接池中空闲链接的大小
func (db *Db) SetMaxIdleConns(n int) {
db.maxIdleConnCount.Set(n)
func (bs *dbBase) SetMaxIdleConns(n int) {
bs.maxIdleConnCount.Set(n)
}
// 设置数据库连接池最大打开的链接数量
func (db *Db) SetMaxOpenConns(n int) {
db.maxOpenConnCount.Set(n)
func (bs *dbBase) SetMaxOpenConns(n int) {
bs.maxOpenConnCount.Set(n)
}
// 设置数据库连接可重复利用的时间,超过该时间则被关闭废弃
// 如果 d <= 0 表示该链接会一直重复利用
func (db *Db) SetConnMaxLifetime(n int) {
db.maxConnLifetime.Set(n)
func (bs *dbBase) SetConnMaxLifetime(n int) {
bs.maxConnLifetime.Set(n)
}
// 节点配置转换为字符串
@ -146,4 +154,17 @@ func (node *ConfigNode) String() string {
node.Name, node.Type, node.Role, node.Charset,
node.MaxIdleConnCount, node.MaxOpenConnCount, node.MaxConnLifetime,
)
}
// 是否开启调试服务
func (bs *dbBase) SetDebug(debug bool) {
bs.debug.Set(debug)
if debug && bs.sqls == nil {
bs.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH)
}
}
// 获取是否开启调试服务
func (bs *dbBase) getDebug() bool {
return bs.debug.Val()
}

158
g/database/gdb/gdb_func.go Normal file
View File

@ -0,0 +1,158 @@
// Copyright 2017-2018 gf Author(https://gitee.com/johng/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://gitee.com/johng/gf.
package gdb
import (
"bytes"
"database/sql"
"errors"
"fmt"
"gitee.com/johng/gf/g/container/gvar"
"gitee.com/johng/gf/g/os/glog"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gconv"
"gitee.com/johng/gf/g/util/gregex"
"gitee.com/johng/gf/g/util/gstr"
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
"reflect"
"strings"
)
// 将数据查询的列表数据*sql.Rows转换为Result类型
func rowsToResult(rows *sql.Rows) (Result, error) {
// 列名称列表
columns, err := rows.Columns()
if err != nil {
return nil, err
}
// 返回结构组装
values := make([]sql.RawBytes, len(columns))
scanArgs := make([]interface{}, len(values))
records := make(Result, 0)
for i := range values {
scanArgs[i] = &values[i]
}
for rows.Next() {
err = rows.Scan(scanArgs...)
if err != nil {
return records, err
}
row := make(Record)
// 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址
for i, col := range values {
if col == nil {
row[columns[i]] = gvar.New(nil, false)
} else {
v := make([]byte, len(col))
copy(v, col)
row[columns[i]] = gvar.New(v, false)
}
}
records = append(records, row)
}
return records, nil
}
// 格式化SQL查询条件
func formatCondition(where interface{}, args []interface{}) (string, []interface{}) {
// 条件字符串处理
buffer := bytes.NewBuffer(nil)
if reflect.ValueOf(where).Kind() == reflect.Map {
ks := reflect.ValueOf(where).MapKeys()
vs := reflect.ValueOf(where)
for _, k := range ks {
key := gconv.String(k.Interface())
value := gconv.String(vs.MapIndex(k).Interface())
if buffer.Len() > 0 {
buffer.WriteString(" AND ")
}
if gstr.IsNumeric(value) || value == "?" {
buffer.WriteString(key + "=" + value)
} else {
buffer.WriteString(key + "='" + value + "'")
}
}
} else {
buffer.Write(gconv.Bytes(where))
}
if buffer.Len() == 0 {
buffer.WriteString("1")
}
// 查询条件处理
newWhere := buffer.String()
newArgs := make([]interface{}, 0)
if len(args) > 0 {
for index, arg := range args {
rv := reflect.ValueOf(arg)
kind := rv.Kind()
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
}
switch kind {
case reflect.Slice: fallthrough
case reflect.Array:
for i := 0; i < rv.Len(); i++ {
newArgs = append(newArgs, rv.Index(i).Interface())
}
counter := 0
newWhere, _ = gregex.ReplaceStringFunc(`\?`, newWhere, func(s string) string {
counter++
if counter == index + 1 {
return "?" + strings.Repeat(",?", rv.Len() - 1)
}
return s
})
default:
newArgs = append(newArgs, arg)
}
}
}
return newWhere, newArgs
}
// 打印SQL对象(仅在debug=true时有效)
func printSql(v *Sql) {
s := fmt.Sprintf("%s, %v, %s, %s, %d ms, %s", v.Sql, v.Args,
gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u"),
gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u"),
v.End - v.Start,
v.Func,
)
if v.Error != nil {
s += "\nError: " + v.Error.Error()
glog.Backtrace(true, 2).Error(s)
} else {
glog.Debug(s)
}
}
// 格式化错误信息
func formatError(err error, query string, args ...interface{}) error {
if err != nil {
errstr := fmt.Sprintf("DB ERROR: %s\n", err.Error())
errstr += fmt.Sprintf("DB QUERY: %s\n", query)
if len(args) > 0 {
errstr += fmt.Sprintf("DB PARAM: %v\n", args)
}
err = errors.New(errstr)
}
return err
}
// 根据insert选项获得操作名称
func getInsertOperationByOption(option int) string {
oper := "INSERT"
switch option {
case OPTION_REPLACE:
oper = "REPLACE"
case OPTION_SAVE:
case OPTION_IGNORE:
oper = "INSERT IGNORE"
}
return oper
}

View File

@ -12,12 +12,15 @@ import (
"database/sql"
"gitee.com/johng/gf/g/util/gconv"
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
"reflect"
"strings"
)
// 数据库链式操作模型对象
type Model struct {
tx *Tx // 数据库事务对象
db *Db // 数据库操作对象
db DB // 数据库操作对象
tx *TX // 数据库事务对象
tablesInit string // 初始化Model时的表名称(可以是多个)
tables string // 数据库操作表
fields string // 操作字段
where string // 操作条件
@ -28,123 +31,213 @@ type Model struct {
limit int // 分页条数
data interface{} // 操作记录(支持Map/List/string类型)
batch int // 批量操作条数
filter bool // 是否按照表字段过滤data参数
cacheEnabled bool // 当前SQL操作是否开启查询缓存功能
cacheTime int // 查询缓存时间
cacheName string // 查询缓存名称
}
// 链式操作,数据表字段,可支持多个表,以半角逗号连接
func (db *Db) Table(tables string) (*Model) {
return &Model{
db: db,
tables: tables,
fields: "*",
func (bs *dbBase) Table(tables string) (*Model) {
return &Model {
db : bs.db,
tablesInit : tables,
tables : tables,
fields : "*",
}
}
// 链式操作,数据表字段,可支持多个表,以半角逗号连接
func (db *Db) From(tables string) (*Model) {
return db.Table(tables)
func (bs *dbBase) From(tables string) (*Model) {
return bs.db.Table(tables)
}
// (事务)链式操作,数据表字段,可支持多个表,以半角逗号连接
func (tx *Tx) Table(tables string) (*Model) {
func (tx *TX) Table(tables string) (*Model) {
return &Model{
db: tx.db,
tx: tx,
tables: tables,
db : tx.db,
tx : tx,
tablesInit : tables,
tables : tables,
}
}
// (事务)链式操作,数据表字段,可支持多个表,以半角逗号连接
func (tx *Tx) From(tables string) (*Model) {
func (tx *TX) From(tables string) (*Model) {
return tx.Table(tables)
}
// 克隆一个当前对象
func (md *Model) Clone() *Model {
newModel := (*Model)(nil)
if md.tx != nil {
newModel = md.tx.Table(md.tablesInit)
} else {
newModel = md.db.Table(md.tablesInit)
}
*newModel = *md
return newModel
}
// 链式操作,左联表
func (md *Model) LeftJoin(joinTable string, on string) (*Model) {
md.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", joinTable, on)
return md
model := md.Clone()
model.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", joinTable, on)
return model
}
// 链式操作,右联表
func (md *Model) RightJoin(joinTable string, on string) (*Model) {
md.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", joinTable, on)
return md
model := md.Clone()
model.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", joinTable, on)
return model
}
// 链式操作,内联表
func (md *Model) InnerJoin(joinTable string, on string) (*Model) {
md.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", joinTable, on)
return md
model := md.Clone()
model.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", joinTable, on)
return model
}
// 链式操作,查询字段
func (md *Model) Fields(fields string) (*Model) {
md.fields = fields
return md
model := md.Clone()
model.fields = fields
return model
}
// 链式操作,过滤字段
func (md *Model) Filter() (*Model) {
model := md.Clone()
model.filter = true
return model
}
// 链式操作condition支持string & gdb.Map
func (md *Model) Where(where interface{}, args ...interface{}) (*Model) {
md.where = md.db.formatCondition(where)
md.whereArgs = append(md.whereArgs, args...)
return md
model := md.Clone()
newWhere, newArgs := formatCondition(where, args)
model.where = newWhere
model.whereArgs = append(model.whereArgs, newArgs...)
// 支持 Where("uid", 1)这种格式
if len(args) == 1 && strings.Index(model.where , "?") < 0 {
model.where += "=?"
}
return model
}
// 链式操作添加AND条件到Where中
func (md *Model) And(where interface{}, args ...interface{}) (*Model) {
md.where += " AND " + md.db.formatCondition(where)
md.whereArgs = append(md.whereArgs, args...)
return md
model := md.Clone()
newWhere, newArgs := formatCondition(where, args)
model.where += " AND " + newWhere
model.whereArgs = append(model.whereArgs, newArgs...)
return model
}
// 链式操作添加OR条件到Where中
func (md *Model) Or(where interface{}, args ...interface{}) (*Model) {
md.where += " OR " + md.db.formatCondition(where)
md.whereArgs = append(md.whereArgs, args...)
return md
model := md.Clone()
newWhere, newArgs := formatCondition(where, args)
model.where += " OR " + newWhere
model.whereArgs = append(model.whereArgs, newArgs...)
return model
}
// 链式操作group by
func (md *Model) GroupBy(groupBy string) (*Model) {
md.groupBy = groupBy
return md
model := md.Clone()
model.groupBy = groupBy
return model
}
// 链式操作order by
func (md *Model) OrderBy(orderBy string) (*Model) {
md.orderBy = orderBy
return md
model := md.Clone()
model.orderBy = orderBy
return model
}
// 链式操作limit
func (md *Model) Limit(start int, limit int) (*Model) {
md.start = start
md.limit = limit
return md
model := md.Clone()
model.start = start
model.limit = limit
return model
}
// 链式操作,翻页
// @author ymrjqyy
func (md *Model) ForPage(page, limit int) (*Model) {
md.start = (page - 1) * limit
md.limit = limit
return md
model := md.Clone()
model.start = (page - 1) * limit
model.limit = limit
return model
}
// 设置批处理的大小
func (md *Model) Batch(batch int) *Model {
model := md.Clone()
model.batch = batch
return model
}
// 查询缓存/清除缓存操作,需要注意的是,事务查询不支持缓存。
// 当time < 0时表示清除缓存 time=0时表示不过期, time > 0时表示过期时间time过期时间单位
// name表示自定义的缓存名称便于业务层精准定位缓存项(如果业务层需要手动清理时,必须指定缓存名称)
// 例如:查询缓存时设置名称,清理缓存时可以给定清理的缓存名称进行精准清理。
func (md *Model) Cache(time int, name ... string) *Model {
model := md.Clone()
model.cacheTime = time
if len(name) > 0 {
model.cacheName = name[0]
}
// 查询缓存特性不支持事务操作
if model.tx == nil {
model.cacheEnabled = true
}
return model
}
// 链式操作操作数据记录项可以是string/Map, 也可以是key,value,key,value,...
func (md *Model) Data(data ...interface{}) (*Model) {
model := md.Clone()
if len(data) > 1 {
m := make(map[string]interface{})
for i := 0; i < len(data); i += 2 {
m[gconv.String(data[i])] = data[i+1]
}
md.data = m
model.data = m
} else {
md.data = data[0]
switch data[0].(type) {
case List:
model.data = data[0]
case Map:
model.data = data[0]
default:
rv := reflect.ValueOf(data[0])
kind := rv.Kind()
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
}
switch kind {
case reflect.Slice: fallthrough
case reflect.Array:
list := make(List, rv.Len())
for i := 0; i < rv.Len(); i++ {
list[i] = gconv.Map(rv.Index(i).Interface())
}
model.data = list
case reflect.Map:
model.data = gconv.Map(data[0])
default:
model.data = data[0]
}
}
}
return md
return model
}
// 链式操作, CURD - Insert/BatchInsert
@ -163,16 +256,24 @@ func (md *Model) Insert() (result sql.Result, err error) {
if md.batch > 0 {
batch = md.batch
}
if md.filter {
for k, m := range list {
list[k] = md.db.filterFields(md.tables, m)
}
}
if md.tx == nil {
return md.db.BatchInsert(md.tables, list, batch)
} else {
return md.tx.BatchInsert(md.tables, list, batch)
}
} else if dataMap, ok := md.data.(Map); ok {
} else if data, ok := md.data.(Map); ok {
if md.filter {
data = md.db.filterFields(md.tables, data)
}
if md.tx == nil {
return md.db.Insert(md.tables, dataMap)
return md.db.Insert(md.tables, data)
} else {
return md.tx.Insert(md.tables, dataMap)
return md.tx.Insert(md.tables, data)
}
}
return nil, errors.New("inserting into table with invalid data type")
@ -194,16 +295,24 @@ func (md *Model) Replace() (result sql.Result, err error) {
if md.batch > 0 {
batch = md.batch
}
if md.filter {
for k, m := range list {
list[k] = md.db.filterFields(md.tables, m)
}
}
if md.tx == nil {
return md.db.BatchReplace(md.tables, list, batch)
} else {
return md.tx.BatchReplace(md.tables, list, batch)
}
} else if dataMap, ok := md.data.(Map); ok {
} else if data, ok := md.data.(Map); ok {
if md.filter {
data = md.db.filterFields(md.tables, data)
}
if md.tx == nil {
return md.db.Insert(md.tables, dataMap)
return md.db.Replace(md.tables, data)
} else {
return md.tx.Insert(md.tables, dataMap)
return md.tx.Replace(md.tables, data)
}
}
return nil, errors.New("replacing into table with invalid data type")
@ -225,16 +334,24 @@ func (md *Model) Save() (result sql.Result, err error) {
if md.batch > 0 {
batch = md.batch
}
if md.filter {
for k, m := range list {
list[k] = md.db.filterFields(md.tables, m)
}
}
if md.tx == nil {
return md.db.BatchSave(md.tables, list, batch)
} else {
return md.tx.BatchSave(md.tables, list, batch)
}
} else if dataMap, ok := md.data.(Map); ok {
} else if data, ok := md.data.(Map); ok {
if md.filter {
data = md.db.filterFields(md.tables, data)
}
if md.tx == nil {
return md.db.Save(md.tables, dataMap)
return md.db.Save(md.tables, data)
} else {
return md.tx.Save(md.tables, dataMap)
return md.tx.Save(md.tables, data)
}
}
return nil, errors.New("saving into table with invalid data type")
@ -250,6 +367,13 @@ func (md *Model) Update() (result sql.Result, err error) {
if md.data == nil {
return nil, errors.New("updating table with empty data")
}
if md.filter {
if data, ok := md.data.(Map); ok {
if md.filter {
md.data = md.db.filterFields(md.tables, data)
}
}
}
if md.tx == nil {
return md.db.Update(md.tables, md.data, md.where, md.whereArgs ...)
} else {
@ -264,9 +388,6 @@ func (md *Model) Delete() (result sql.Result, err error) {
md.checkAndRemoveCache()
}
}()
if md.where == "" {
return nil, errors.New("where is required while deleting")
}
if md.tx == nil {
return md.db.Delete(md.tables, md.where, md.whereArgs...)
} else {
@ -274,36 +395,14 @@ func (md *Model) Delete() (result sql.Result, err error) {
}
}
// 设置批处理的大小
func (md *Model) Batch(batch int) *Model {
md.batch = batch
return md
}
// 查询缓存/清除缓存操作,需要注意的是,事务查询不支持缓存。
// 当time < 0时表示清除缓存 time=0时表示不过期, time > 0时表示过期时间time过期时间单位
// name表示自定义的缓存名称便于业务层精准定位缓存项(如果业务层需要手动清理时,必须指定缓存名称)
// 例如:查询缓存时设置名称,清理缓存时可以给定清理的缓存名称进行精准清理。
func (md *Model) Cache(time int, name ... string) *Model {
md.cacheTime = time
if len(name) > 0 {
md.cacheName = name[0]
}
// 查询缓存特性不支持事务操作
if md.tx == nil {
md.cacheEnabled = true
}
return md
}
// 链式操作select
func (md *Model) Select() (Result, error) {
return md.getAll(md.getFormattedSql(), md.whereArgs...)
return md.All()
}
// 链式操作,查询所有记录
func (md *Model) All() (Result, error) {
return md.Select()
return md.getAll(md.getFormattedSql(), md.whereArgs...)
}
// 链式操作,查询单条记录
@ -342,6 +441,9 @@ func (md *Model) Struct(obj interface{}) error {
// 链式操作查询数量fields可以为空也可以自定义查询字段
// 当给定自定义查询字段时该字段必须为数量结果否则会引起歧义使用如md.Fields("COUNT(id)")
func (md *Model) Count() (int, error) {
defer func(fields string) {
md.fields = fields
}(md.fields)
if md.fields == "" || md.fields == "*" {
md.fields = "COUNT(1)"
} else {
@ -364,29 +466,30 @@ func (md *Model) Count() (int, error) {
}
// 查询操作对底层SQL操作的封装
func (md *Model) getAll(sql string, args ...interface{}) (result Result, err error) {
var cacheKey string
func (md *Model) getAll(query string, args ...interface{}) (result Result, err error) {
cacheKey := ""
// 查询缓存查询处理
if md.cacheEnabled {
cacheKey = md.cacheName
if len(cacheKey) == 0 {
cacheKey = sql + "/" + gconv.String(args)
cacheKey = query + "/" + gconv.String(args)
}
if v := md.db.cache.Get(cacheKey); v != nil {
if v := md.db.getCache().Get(cacheKey); v != nil {
return v.(Result), nil
}
}
if md.tx == nil {
result, err = md.db.GetAll(sql, args...)
result, err = md.db.GetAll(query, args...)
} else {
result, err = md.tx.GetAll(sql, args...)
result, err = md.tx.GetAll(query, args...)
}
// 查询缓存保存处理
if len(cacheKey) > 0 && err == nil {
if md.cacheTime < 0 {
md.db.cache.Remove(cacheKey)
md.db.getCache().Remove(cacheKey)
} else {
md.db.cache.Set(cacheKey, result, md.cacheTime*1000)
md.db.getCache().Set(cacheKey, result, md.cacheTime*1000)
}
}
return result, err
@ -395,7 +498,7 @@ func (md *Model) getAll(sql string, args ...interface{}) (result Result, err err
// 检查是否需要查询查询缓存
func (md *Model) checkAndRemoveCache() {
if md.cacheEnabled && md.cacheTime < 0 && len(md.cacheName) > 0 {
md.db.cache.Remove(md.cacheName)
md.db.getCache().Remove(md.cacheName)
}
}
@ -424,11 +527,10 @@ func (md *Model) getFormattedSql() string {
// @author ymrjqyy
// @author 2018-08-15
func (md *Model) Chunk(limit int, callback func(result Result, err error) bool) {
var page = 1
page := 1
for {
md.ForPage(page, limit)
sqls := md.getFormattedSql()
data, err := md.getAll(sqls, md.whereArgs...)
data, err := md.getAll(md.getFormattedSql(), md.whereArgs...)
if err != nil {
callback(nil, err)
break

View File

@ -22,20 +22,19 @@ import (
)
var linkMssql = &dbmssql{}
// 数据库链接对象
type dbmssql struct {
Db
type dbMssql struct {
*dbBase
}
// 创建SQL操作对象
func (db *dbmssql) Open(c *ConfigNode) (*sql.DB, error) {
var source string
if c.Linkinfo != "" {
source = c.Linkinfo
func (db *dbMssql) Open(config *ConfigNode) (*sql.DB, error) {
source := ""
if config.Linkinfo != "" {
source = config.Linkinfo
} else {
source = fmt.Sprintf("user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable", c.User, c.Pass, c.Host, c.Port, c.Name)
source = fmt.Sprintf("user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable",
config.User, config.Pass, config.Host, config.Port, config.Name)
}
if db, err := sql.Open("sqlserver", source); err == nil {
return db, nil
@ -44,43 +43,38 @@ func (db *dbmssql) Open(c *ConfigNode) (*sql.DB, error) {
}
}
// 获得关键字操作符 - 左
func (db *dbmssql) getQuoteCharLeft() string {
return "\""
}
// 获得关键字操作符 - 右
func (db *dbmssql) getQuoteCharRight() string {
return "\""
// 获得关键字操作符
func (db *dbMssql) getChars () (charLeft string, charRight string) {
return "\"", "\""
}
// 在执行sql之前对sql进行进一步处理
func (db *dbmssql) handleSqlBeforeExec(q *string) *string {
func (db *dbMssql) handleSqlBeforeExec(query string) string {
index := 0
str, _ := gregex.ReplaceStringFunc("\\?", *q, func(s string) string {
str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string {
index++
return fmt.Sprintf("@p%d", index)
})
str, _ = gregex.ReplaceString("\"", "", str)
return db.parseSql(&str)
return db.parseSql(str)
}
//将MYSQL的SQL语法转换为MSSQL的语法
//1.由于mssql不支持limit写法所以需要对mysql中的limit用法做转换
func (db *dbmssql) parseSql(sql *string) *string {
func (db *dbMssql) parseSql(sql string) string {
//下面的正则表达式匹配出SELECT和INSERT的关键字后分别做不同的处理如有LIMIT则将LIMIT的关键字也匹配出
patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))`
if gregex.IsMatchString(patten, *sql) == false {
if gregex.IsMatchString(patten, sql) == false {
fmt.Println("not matched..")
return sql
}
res, err := gregex.MatchAllString(patten, *sql)
res, err := gregex.MatchAllString(patten, sql)
if err != nil {
fmt.Println("MatchString error.", err)
return nil
return ""
}
index := 0
@ -96,17 +90,17 @@ func (db *dbmssql) parseSql(sql *string) *string {
}
//不含LIMIT则不处理
if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql) == false {
if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) == false {
break
}
//判断SQL中是否含有order by
selectStr := ""
orderbyStr := ""
haveOrderby := gregex.IsMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", *sql)
haveOrderby := gregex.IsMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql)
if haveOrderby {
//取order by 前面的字符串
queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)ORDER BY)", *sql)
queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql)
if len(queryExpr) != 4 || strings.EqualFold(queryExpr[1], "SELECT") == false || strings.EqualFold(queryExpr[3], "ORDER BY") == false{
break
@ -114,13 +108,13 @@ func (db *dbmssql) parseSql(sql *string) *string {
selectStr = queryExpr[2]
//取order by表达式的值
orderbyExpr, _ := gregex.MatchString("((?i)ORDER BY)(.+)((?i)LIMIT)", *sql)
orderbyExpr, _ := gregex.MatchString("((?i)ORDER BY)(.+)((?i)LIMIT)", sql)
if len(orderbyExpr) != 4 || strings.EqualFold(orderbyExpr[1], "ORDER BY") == false || strings.EqualFold(orderbyExpr[3], "LIMIT") == false{
break
}
orderbyStr = orderbyExpr[2]
} else {
queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql)
queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql)
if len(queryExpr) != 4 || strings.EqualFold(queryExpr[1], "SELECT") == false || strings.EqualFold(queryExpr[3], "LIMIT") == false{
break
}
@ -144,14 +138,14 @@ func (db *dbmssql) parseSql(sql *string) *string {
}
if haveOrderby {
*sql = fmt.Sprintf("SELECT * FROM (SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d", orderbyStr, selectStr, first, limit)
sql = fmt.Sprintf("SELECT * FROM (SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d", orderbyStr, selectStr, first, limit)
} else {
if first == 0 {
first = limit
} else {
first = limit - first
}
*sql = fmt.Sprintf("SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ", first, limit, selectStr)
sql = fmt.Sprintf("SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ", first, limit, selectStr)
}
default:
}

View File

@ -12,22 +12,19 @@ import (
"database/sql"
)
// MySQL接口对象
var linkMysql = &dbmysql{}
// 数据库链接对象
type dbmysql struct {
Db
type dbMysql struct {
*dbBase
}
// 创建SQL操作对象内部采用了lazy link处理
func (db *dbmysql) Open (c *ConfigNode) (*sql.DB, error) {
func (db *dbMysql) Open (config *ConfigNode) (*sql.DB, error) {
var source string
if c.Linkinfo != "" {
source = c.Linkinfo
if config.Linkinfo != "" {
source = config.Linkinfo
} else {
source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", c.User, c.Pass, c.Host, c.Port, c.Name)
source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&multiStatements=true",
config.User, config.Pass, config.Host, config.Port, config.Name, config.Charset)
}
if db, err := sql.Open("mysql", source); err == nil {
return db, nil
@ -36,17 +33,12 @@ func (db *dbmysql) Open (c *ConfigNode) (*sql.DB, error) {
}
}
// 获得关键字操作符 - 左
func (db *dbmysql) getQuoteCharLeft () string {
return "`"
}
// 获得关键字操作符 - 右
func (db *dbmysql) getQuoteCharRight () string {
return "`"
// 获得关键字操作符
func (db *dbMysql) getChars () (charLeft string, charRight string) {
return "`", "`"
}
// 在执行sql之前对sql进行进一步处理
func (db *dbmysql) handleSqlBeforeExec(q *string) *string {
return q
func (db *dbMysql) handleSqlBeforeExec(query string) string {
return query
}

View File

@ -21,20 +21,18 @@ import (
"strings"
)
var linkOracle = &dboracle{}
// 数据库链接对象
type dboracle struct {
Db
type dbOracle struct {
*dbBase
}
// 创建SQL操作对象
func (db *dboracle) Open(c *ConfigNode) (*sql.DB, error) {
func (db *dbOracle) Open(config *ConfigNode) (*sql.DB, error) {
var source string
if c.Linkinfo != "" {
source = c.Linkinfo
if config.Linkinfo != "" {
source = config.Linkinfo
} else {
source = fmt.Sprintf("%s/%s@%s", c.User, c.Pass, c.Name)
source = fmt.Sprintf("%s/%s@%s", config.User, config.Pass, config.Name)
}
if db, err := sql.Open("oci8", source); err == nil {
return db, nil
@ -43,42 +41,37 @@ func (db *dboracle) Open(c *ConfigNode) (*sql.DB, error) {
}
}
// 获得关键字操作符 - 左
func (db *dboracle) getQuoteCharLeft() string {
return "\""
}
// 获得关键字操作符 - 右
func (db *dboracle) getQuoteCharRight() string {
return "\""
// 获得关键字操作符
func (db *dbOracle) getChars () (charLeft string, charRight string) {
return "\"", "\""
}
// 在执行sql之前对sql进行进一步处理
func (db *dboracle) handleSqlBeforeExec(q *string) *string {
func (db *dbOracle) handleSqlBeforeExec(query string) string {
index := 0
str, _ := gregex.ReplaceStringFunc("\\?", *q, func(s string) string {
str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string {
index++
return fmt.Sprintf(":%d", index)
})
str, _ = gregex.ReplaceString("\"", "", str)
return db.parseSql(&str)
return db.parseSql(str)
}
//由于ORACLE中对LIMIT和批量插入的语法与MYSQL不一致所以这里需要对LIMIT和批量插入做语法上的转换
func (db *dboracle) parseSql(sql *string) *string {
func (db *dbOracle) parseSql(sql string) string {
//下面的正则表达式匹配出SELECT和INSERT的关键字后分别做不同的处理如有LIMIT则将LIMIT的关键字也匹配出
patten := `^\s*(?i)(SELECT)|(INSERT)|(LIMIT\s*(\d+)\s*,\s*(\d+))`
if gregex.IsMatchString(patten, *sql) == false {
if gregex.IsMatchString(patten, sql) == false {
fmt.Println("not matched..")
return sql
}
res, err := gregex.MatchAllString(patten, *sql)
res, err := gregex.MatchAllString(patten, sql)
if err != nil {
fmt.Println("MatchString error.", err)
return nil
return ""
}
index := 0
@ -94,11 +87,11 @@ func (db *dboracle) parseSql(sql *string) *string {
}
//取limit前面的字符串
if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql) == false {
if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) == false {
break
}
queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql)
queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql)
if len(queryExpr) != 4 || strings.EqualFold(queryExpr[1], "SELECT") == false || strings.EqualFold(queryExpr[3], "LIMIT") == false{
break
}
@ -118,10 +111,10 @@ func (db *dboracle) parseSql(sql *string) *string {
}
//也可以使用between,据说这种写法的性能会比between好点,里层SQL中的ROWNUM_ >= limit可以缩小查询后的数据集规模
*sql = fmt.Sprintf("SELECT * FROM (SELECT GFORM.*, ROWNUM ROWNUM_ FROM (%s %s) GFORM WHERE ROWNUM <= %d) WHERE ROWNUM_ >= %d", queryExpr[1], queryExpr[2], limit, first)
sql = fmt.Sprintf("SELECT * FROM (SELECT GFORM.*, ROWNUM ROWNUM_ FROM (%s %s) GFORM WHERE ROWNUM <= %d) WHERE ROWNUM_ >= %d", queryExpr[1], queryExpr[2], limit, first)
case "INSERT":
//获取VALUE的值匹配所有带括号的值,会将INSERT INTO后的值匹配到所以下面的判断语句会判断数组长度是否小于3
valueExpr, err := gregex.MatchAllString(`(\s*\(([^\(\)]*)\))`, *sql)
valueExpr, err := gregex.MatchAllString(`(\s*\(([^\(\)]*)\))`, sql)
if err != nil {
return sql
}
@ -132,17 +125,17 @@ func (db *dboracle) parseSql(sql *string) *string {
}
//获取INTO后面的值
tableExpr, err := gregex.MatchString(`(?i)\s*(INTO\s+\w+\(([^\(\)]*)\))`, *sql)
tableExpr, err := gregex.MatchString(`(?i)\s*(INTO\s+\w+\(([^\(\)]*)\))`, sql)
if err != nil {
return sql
}
tableExpr[0] = strings.TrimSpace(tableExpr[0])
*sql = "INSERT ALL"
sql = "INSERT ALL"
for i := 1; i < len(valueExpr); i++ {
*sql += fmt.Sprintf(" %s VALUES%s", tableExpr[0], strings.TrimSpace(valueExpr[i][0]))
sql += fmt.Sprintf(" %s VALUES%s", tableExpr[0], strings.TrimSpace(valueExpr[i][0]))
}
*sql += " SELECT 1 FROM DUAL"
sql += " SELECT 1 FROM DUAL"
default:
}

View File

@ -18,22 +18,18 @@ import (
// _ "gitee.com/johng/gf/third/github.com/lib/pq"
// @todo 需要完善replace和save的操作覆盖
// PostgreSQL接口对象
var linkPgsql = &dbpgsql{}
// 数据库链接对象
type dbpgsql struct {
Db
type dbPgsql struct {
*dbBase
}
// 创建SQL操作对象内部采用了lazy link处理
func (db *dbpgsql) Open (c *ConfigNode) (*sql.DB, error) {
func (db *dbPgsql) Open (config *ConfigNode) (*sql.DB, error) {
var source string
if c.Linkinfo != "" {
source = c.Linkinfo
if config.Linkinfo != "" {
source = config.Linkinfo
} else {
source = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s", c.User, c.Pass, c.Host, c.Port, c.Name)
source = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s", config.User, config.Pass, config.Host, config.Port, config.Name)
}
if db, err := sql.Open("postgres", source); err == nil {
return db, nil
@ -42,23 +38,18 @@ func (db *dbpgsql) Open (c *ConfigNode) (*sql.DB, error) {
}
}
// 获得关键字操作符 - 左
func (db *dbpgsql) getQuoteCharLeft () string {
return "\""
}
// 获得关键字操作符 - 右
func (db *dbpgsql) getQuoteCharRight () string {
return "\""
// 获得关键字操作符
func (db *dbPgsql) getChars () (charLeft string, charRight string) {
return "\"", "\""
}
// 在执行sql之前对sql进行进一步处理
func (db *dbpgsql) handleSqlBeforeExec(q *string) *string {
func (db *dbPgsql) handleSqlBeforeExec(query string) string {
reg := regexp.MustCompile("\\?")
index := 0
str := reg.ReplaceAllStringFunc(*q, func (s string) string {
str := reg.ReplaceAllStringFunc(query, func (s string) string {
index ++
return fmt.Sprintf("$%d", index)
})
return &str
return str
}

View File

@ -16,20 +16,18 @@ import (
// Sqlite接口对象
// @author wxkj<wxscz@qq.com>
var linkSqlite = &dbsqlite{}
// 数据库链接对象
type dbsqlite struct {
Db
type dbSqlite struct {
*dbBase
}
func (db *dbsqlite) Open(c *ConfigNode) (*sql.DB, error) {
func (db *dbSqlite) Open(config *ConfigNode) (*sql.DB, error) {
var source string
if c.Linkinfo != "" {
source = c.Linkinfo
if config.Linkinfo != "" {
source = config.Linkinfo
} else {
source = c.Name
source = config.Name
}
if db, err := sql.Open("sqlite3", source); err == nil {
return db, nil
@ -38,20 +36,14 @@ func (db *dbsqlite) Open(c *ConfigNode) (*sql.DB, error) {
}
}
// 获得关键字操作符 - 左
func (db *dbsqlite) getQuoteCharLeft() string {
return "`"
}
// 获得关键字操作符 - 右
func (db *dbsqlite) getQuoteCharRight() string {
return "`"
// 获得关键字操作符
func (db *dbSqlite) getChars () (charLeft string, charRight string) {
return "`", "`"
}
// 在执行sql之前对sql进行进一步处理
// @todo 需要增加对Save方法的支持可使用正则来实现替换
// @todo 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE)
func (db *dbsqlite) handleSqlBeforeExec(q *string) *string {
return q
func (db *dbSqlite) handleSqlBeforeExec(query string) string {
return query
}

View File

@ -7,128 +7,55 @@
package gdb
import (
"fmt"
"errors"
"strings"
"reflect"
"database/sql"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gconv"
"gitee.com/johng/gf/g/util/gregex"
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
"gitee.com/johng/gf/g/container/gvar"
)
// 数据库事务对象
type Tx struct {
db *Db
type TX struct {
db DB
tx *sql.Tx
master *sql.DB
}
// 事务操作,提交
func (tx *Tx) Commit() error {
func (tx *TX) Commit() error {
return tx.tx.Commit()
}
// 事务操作,回滚
func (tx *Tx) Rollback() error {
func (tx *TX) Rollback() error {
return tx.tx.Rollback()
}
// (事务)数据库sql查询操作主要执行查询
func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
var err error
var rows *sql.Rows
p := tx.db.link.handleSqlBeforeExec(&query)
if tx.db.debug.Val() {
militime1 := gtime.Millisecond()
rows, err = tx.tx.Query(*p, args ...)
militime2 := gtime.Millisecond()
s := &Sql{
Sql : *p,
Args : args,
Error : err,
Start : militime1,
End : militime2,
Func : "TX:Query",
}
tx.db.sqls.Put(s)
tx.db.printSql(s)
} else {
rows, err = tx.tx.Query(*p, args ...)
}
if err == nil {
return rows, nil
} else {
err = tx.db.formatError(err, p, args...)
}
return nil, err
func (tx *TX) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
return tx.db.doQuery(tx.tx, query, args...)
}
// (事务)执行一条sql并返回执行情况主要用于非查询操作
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
var err error
var result sql.Result
p := tx.db.link.handleSqlBeforeExec(&query)
if tx.db.debug.Val() {
militime1 := gtime.Millisecond()
result, err = tx.tx.Exec(*p, args ...)
militime2 := gtime.Millisecond()
s := &Sql{
Sql : *p,
Args : args,
Error : err,
Start : militime1,
End : militime2,
Func : "TX:Exec",
}
tx.db.sqls.Put(s)
tx.db.printSql(s)
} else {
result, err = tx.tx.Exec(*p, args ...)
}
return result, tx.db.formatError(err, p, args...)
func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.db.doExec(tx.tx, query, args...)
}
// sql预处理执行完成后调用返回值sql.Stmt.Exec完成sql操作
func (tx *TX) Prepare(query string) (*sql.Stmt, error) {
return tx.db.doPrepare(tx.tx, query)
}
// 数据库查询,获取查询结果集,以列表结构返回
func (tx *Tx) GetAll(query string, args ...interface{}) (Result, error) {
// 执行sql
func (tx *TX) GetAll(query string, args ...interface{}) (Result, error) {
rows, err := tx.Query(query, args ...)
if err != nil || rows == nil {
return nil, err
}
// 列名称列表
columns, err := rows.Columns()
if err != nil {
return nil, err
}
// 返回结构组装
values := make([]sql.RawBytes, len(columns))
scanArgs := make([]interface{}, len(values))
records := make(Result, 0)
for i := range values {
scanArgs[i] = &values[i]
}
for rows.Next() {
err = rows.Scan(scanArgs...)
if err != nil {
return records, err
}
row := make(Record)
// 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址
for i, col := range values {
v := make([]byte, len(col))
copy(v, col)
row[columns[i]] = gvar.New(v, false)
}
//fmt.Printf("%p\n", row["typeid"])
records = append(records, row)
}
return records, nil
defer rows.Close()
return rowsToResult(rows)
}
// 数据库查询,获取查询结果记录,以关联数组结构返回
func (tx *Tx) GetOne(query string, args ...interface{}) (Record, error) {
func (tx *TX) GetOne(query string, args ...interface{}) (Record, error) {
list, err := tx.GetAll(query, args ...)
if err != nil {
return nil, err
@ -140,7 +67,7 @@ func (tx *Tx) GetOne(query string, args ...interface{}) (Record, error) {
}
// 数据库查询获取查询结果记录自动映射数据到给定的struct对象中
func (tx *Tx) GetStruct(obj interface{}, query string, args ...interface{}) error {
func (tx *TX) GetStruct(obj interface{}, query string, args ...interface{}) error {
one, err := tx.GetOne(query, args...)
if err != nil {
return err
@ -148,9 +75,8 @@ func (tx *Tx) GetStruct(obj interface{}, query string, args ...interface{}) erro
return one.ToStruct(obj)
}
// 数据库查询,获取查询字段值
func (tx *Tx) GetValue(query string, args ...interface{}) (Value, error) {
func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) {
one, err := tx.GetOne(query, args ...)
if err != nil {
return nil, err
@ -162,186 +88,55 @@ func (tx *Tx) GetValue(query string, args ...interface{}) (Value, error) {
}
// 数据库查询,获取查询数量
func (tx *Tx) GetCount(query string, args ...interface{}) (int, error) {
val, err := tx.GetValue(query, args ...)
func (tx *TX) GetCount(query string, args ...interface{}) (int, error) {
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
}
value, err := tx.GetValue(query, args ...)
if err != nil {
return 0, err
}
return gconv.Int(val), nil
}
// 数据表查询其中tables可以是多个联表查询语句这种查询方式较复杂建议使用链式操作
func (tx *Tx) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) {
s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables)
if condition != nil {
s += fmt.Sprintf("WHERE %s ", tx.db.formatCondition(condition))
}
if len(groupBy) > 0 {
s += fmt.Sprintf("GROUP BY %s ", groupBy)
}
if len(orderBy) > 0 {
s += fmt.Sprintf("ORDER BY %s ", orderBy)
}
if limit > 0 {
s += fmt.Sprintf("LIMIT %d,%d ", first, limit)
}
return tx.GetAll(s, args ... )
}
// sql预处理执行完成后调用返回值sql.Stmt.Exec完成sql操作
func (tx *Tx) Prepare(query string) (*sql.Stmt, error) {
return tx.tx.Prepare(query)
}
// insert、replace, save ignore操作
// 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
// 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
// 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
// 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做
func (tx *Tx) insert(table string, data Map, option uint8) (sql.Result, error) {
var keys []string
var values []string
var params []interface{}
for k, v := range data {
keys = append(keys, tx.db.charl + k + tx.db.charr)
values = append(values, "?")
params = append(params, v)
}
operation := tx.db.getInsertOperationByOption(option)
updatestr := ""
if option == OPTION_SAVE {
var updates []string
for k, _ := range data {
updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k))
}
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
}
return tx.Exec(
fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
operation, table, strings.Join(keys, ","),
strings.Join(values, ","),
updatestr),
params...
)
return value.Int(), nil
}
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
func (tx *Tx) Insert(table string, data Map) (sql.Result, error) {
return tx.insert(table, data, OPTION_INSERT)
func (tx *TX) Insert(table string, data Map) (sql.Result, error) {
return tx.db.doInsert(tx.tx, table, data, OPTION_INSERT)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (tx *Tx) Replace(table string, data Map) (sql.Result, error) {
return tx.insert(table, data, OPTION_REPLACE)
func (tx *TX) Replace(table string, data Map) (sql.Result, error) {
return tx.db.doInsert(tx.tx, table, data, OPTION_REPLACE)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (tx *Tx) Save(table string, data Map) (sql.Result, error) {
return tx.insert(table, data, OPTION_SAVE)
}
// 批量写入数据
func (tx *Tx) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) {
var keys []string
var values []string
var bvalues []string
var params []interface{}
var result sql.Result
var size = len(list)
// 判断长度
if size < 1 {
return result, errors.New("empty data list")
}
// 首先获取字段名称及记录长度
for k, _ := range list[0] {
keys = append(keys, k)
values = append(values, "?")
}
keyStr := tx.db.charl + strings.Join(keys, tx.db.charl + "," + tx.db.charr) + tx.db.charr
valueHolderStr := "(" + strings.Join(values, ",") + ")"
// 操作判断
operation := tx.db.getInsertOperationByOption(option)
updatestr := ""
if option == OPTION_SAVE {
var updates []string
for _, k := range keys {
updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k))
}
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
}
// 构造批量写入数据格式(注意map的遍历是无序的)
for i := 0; i < size; i++ {
for _, k := range keys {
params = append(params, list[i][k])
}
bvalues = append(bvalues, valueHolderStr)
if len(bvalues) == batch {
r, err := tx.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
operation, table, keyStr, strings.Join(bvalues, ","),
updatestr),
params...)
if err != nil {
return result, err
}
result = r
params = params[:0]
bvalues = bvalues[:0]
}
}
// 处理最后不构成指定批量的数据
if len(bvalues) > 0 {
r, err := tx.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
operation, table, keyStr, strings.Join(bvalues, ","),
updatestr),
params...)
if err != nil {
return result, err
}
result = r
}
return result, nil
func (tx *TX) Save(table string, data Map) (sql.Result, error) {
return tx.db.doInsert(tx.tx, table, data, OPTION_SAVE)
}
// CURD操作:批量数据指定批次量写入
func (tx *Tx) BatchInsert(table string, list List, batch int) (sql.Result, error) {
return tx.batchInsert(table, list, batch, OPTION_INSERT)
func (tx *TX) BatchInsert(table string, list List, batch int) (sql.Result, error) {
return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_INSERT)
}
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (tx *Tx) BatchReplace(table string, list List, batch int) (sql.Result, error) {
return tx.batchInsert(table, list, batch, OPTION_REPLACE)
func (tx *TX) BatchReplace(table string, list List, batch int) (sql.Result, error) {
return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_REPLACE)
}
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (tx *Tx) BatchSave(table string, list List, batch int) (sql.Result, error) {
return tx.batchInsert(table, list, batch, OPTION_SAVE)
func (tx *TX) BatchSave(table string, list List, batch int) (sql.Result, error) {
return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_SAVE)
}
// CURD操作:数据更新统一采用sql预处理
// data参数支持字符串或者关联数组类型内部会自行做判断处理
func (tx *Tx) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
var params []interface{}
var updates string
refValue := reflect.ValueOf(data)
if refValue.Kind() == reflect.Map {
var fields []string
keys := refValue.MapKeys()
for _, k := range keys {
fields = append(fields, fmt.Sprintf("%s%s%s=?", tx.db.charl, k, tx.db.charr))
params = append(params, gconv.String(refValue.MapIndex(k).Interface()))
updates = strings.Join(fields, ",")
}
} else {
updates = gconv.String(data)
}
for _, v := range args {
params = append(params, gconv.String(v))
}
return tx.Exec(fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, tx.db.formatCondition(condition)), params...)
func (tx *TX) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
return tx.db.doUpdate(tx.tx, table, data, condition, args ...)
}
// CURD操作:删除数据
func (tx *Tx) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
return tx.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, tx.db.formatCondition(condition)), args...)
func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
return tx.db.doDelete(tx.tx, table, condition, args ...)
}

View File

@ -0,0 +1,52 @@
package gdb_test
import (
"gitee.com/johng/gf/g/database/gdb"
"gitee.com/johng/gf/g/util/gtest"
)
var (
// 数据库对象/接口
db gdb.DB
)
// 初始化连接参数。
// 测试前需要修改连接参数。
func init() {
gdb.AddDefaultConfigNode(gdb.ConfigNode{
Host: "127.0.0.1",
Port: "3306",
User: "root",
Pass: "",
Name: "",
Type: "mysql",
Role: "master",
Charset: "utf8",
Priority: 1,
})
if r, err := gdb.New(); err != nil {
gtest.Fatal(err)
} else {
db = r
}
// 准备测试数据结构
if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS `test` CHARACTER SET UTF8"); err != nil {
gtest.Fatal(err)
}
db.SetSchema("test")
if _, err := db.Exec("DROP TABLE IF EXISTS `user`"); err != nil {
gtest.Fatal(err)
}
if _, err := db.Exec(`
CREATE TABLE user (
id int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT '用户ID',
passport varchar(45) NOT NULL COMMENT '账号',
password char(32) NOT NULL COMMENT '密码',
nickname varchar(45) NOT NULL COMMENT '昵称',
create_time timestamp NOT NULL COMMENT '创建时间/注册时间',
PRIMARY KEY (id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
`); err != nil {
gtest.Fatal(err)
}
}

View File

@ -0,0 +1,172 @@
package gdb_test
import (
"gitee.com/johng/gf/g"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gtest"
"testing"
)
func TestDbBase_Query(t *testing.T) {
if _, err := db.Query("SELECT ?", 1); err != nil {
gtest.Fatal(err)
}
if _, err := db.Query("ERROR"); err == nil {
gtest.Fatal("FAIL")
}
}
func TestDbBase_Exec(t *testing.T) {
if _, err := db.Exec("SELECT ?", 1); err != nil {
gtest.Fatal(err)
}
if _, err := db.Exec("ERROR"); err == nil {
gtest.Fatal("FAIL")
}
}
func TestDbBase_Prepare(t *testing.T) {
st, err := db.Prepare("SELECT 100")
if err != nil {
gtest.Fatal(err)
}
rows, err := st.Query()
if err != nil {
gtest.Fatal(err)
}
array, err := rows.Columns()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(array[0], "100")
if err := rows.Close(); err != nil {
gtest.Fatal(err)
}
}
func TestDbBase_Insert(t *testing.T) {
if _, err := db.Insert("user", g.Map{
"id" : 1,
"passport" : "t1",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T1",
"create_time" : gtime.Now().String(),
}); err != nil {
gtest.Fatal(err)
}
}
func TestDbBase_BatchInsert(t *testing.T) {
if _, err := db.BatchInsert("user", g.List {
{
"id" : 2,
"passport" : "t2",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T2",
"create_time" : gtime.Now().String(),
},
{
"id" : 3,
"passport" : "t3",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T3",
"create_time" : gtime.Now().String(),
},
}, 10); err != nil {
gtest.Fatal(err)
}
}
func TestDbBase_Save(t *testing.T) {
if _, err := db.Save("user", g.Map{
"id" : 1,
"passport" : "t1",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T11",
"create_time" : gtime.Now().String(),
}); err != nil {
gtest.Fatal(err)
}
}
func TestDbBase_Replace(t *testing.T) {
if _, err := db.Save("user", g.Map{
"id" : 1,
"passport" : "t1",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T111",
"create_time" : gtime.Now().String(),
}); err != nil {
gtest.Fatal(err)
}
}
func TestDbBase_Update(t *testing.T) {
if result, err := db.Update("user", "create_time='2010-10-10 00:00:01'", "id=3"); err != nil {
gtest.Fatal(err)
} else {
n, _ := result.RowsAffected()
gtest.Assert(n, 1)
}
}
func TestDbBase_GetAll(t *testing.T) {
if result, err := db.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(len(result), 1)
}
}
func TestDbBase_GetOne(t *testing.T) {
if record, err := db.GetOne("SELECT * FROM user WHERE passport=?", "t1"); err != nil {
gtest.Fatal(err)
} else {
if record == nil {
gtest.Fatal("FAIL")
}
gtest.Assert(record["nickname"].String(), "T111")
}
}
func TestDbBase_GetValue(t *testing.T) {
if value, err := db.GetValue("SELECT id FROM user WHERE passport=?", "t3"); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(value.Int(), 3)
}
}
func TestDbBase_GetCount(t *testing.T) {
if count, err := db.GetCount("SELECT * FROM user"); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(count, 3)
}
}
func TestDbBase_GetStruct(t *testing.T) {
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime gtime.Time
}
user := new(User)
if err := db.GetStruct(user, "SELECT * FROM user WHERE id=?", 3); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(user.CreateTime.String(), "2010-10-10 00:00:01")
}
}
func TestDbBase_Delete(t *testing.T) {
if result, err := db.Delete("user", nil); err != nil {
gtest.Fatal(err)
} else {
n, _ := result.RowsAffected()
gtest.Assert(n, 3)
}
}

View File

@ -0,0 +1,220 @@
package gdb_test
import (
"gitee.com/johng/gf/g"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gtest"
"testing"
)
func TestModel_Insert(t *testing.T) {
result, err := db.Table("user").Filter().Data(g.Map{
"id" : 1,
"uid" : 1,
"passport" : "t1",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T1",
"create_time" : gtime.Now().String(),
}).Insert()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.LastInsertId()
gtest.Assert(n, 1)
}
func TestModel_Batch(t *testing.T) {
result, err := db.Table("user").Filter().Data(g.List{
{
"id" : 2,
"uid" : 2,
"passport" : "t2",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T2",
"create_time" : gtime.Now().String(),
},
{
"id" : 3,
"uid" : 3,
"passport" : "t3",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T3",
"create_time" : gtime.Now().String(),
},
}).Batch(10).Insert()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 2)
}
func TestModel_Replace(t *testing.T) {
result, err := db.Table("user").Data(g.Map{
"id" : 1,
"passport" : "t11",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T11",
"create_time" : gtime.Now().String(),
}).Replace()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 2)
}
func TestModel_Save(t *testing.T) {
result, err := db.Table("user").Data(g.Map{
"id" : 1,
"passport" : "t111",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T111",
"create_time" : gtime.Now().String(),
}).Save()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 2)
}
func TestModel_Update(t *testing.T) {
result, err := db.Table("user").Data("passport", "t22").Where("passport=?", "t2").Update()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 1)
}
func TestModel_Clone(t *testing.T) {
md := db.Table("user").Where("id IN(?)", g.Slice{1,3})
count, err := md.Count()
if err != nil {
gtest.Fatal(err)
}
record, err := md.OrderBy("id DESC").One()
if err != nil {
gtest.Fatal(err)
}
result, err := md.OrderBy("id ASC").All()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(count, 2)
gtest.Assert(record["id"].Int(), 3)
gtest.Assert(len(result), 2)
gtest.Assert(result[0]["id"].Int(), 1)
gtest.Assert(result[1]["id"].Int(), 3)
}
func TestModel_All(t *testing.T) {
result, err := db.Table("user").All()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(len(result), 3)
}
func TestModel_One(t *testing.T) {
record, err := db.Table("user").Where("id", 1).One()
if err != nil {
gtest.Fatal(err)
}
if record == nil {
gtest.Fatal("FAIL")
}
gtest.Assert(record["nickname"].String(), "T111")
}
func TestModel_Value(t *testing.T) {
value, err := db.Table("user").Fields("nickname").Where("id", 1).Value()
if err != nil {
gtest.Fatal(err)
}
if value == nil {
gtest.Fatal("FAIL")
}
gtest.Assert(value.String(), "T111")
}
func TestModel_Count(t *testing.T) {
count, err := db.Table("user").Count()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(count, 3)
}
func TestModel_Select(t *testing.T) {
result, err := db.Table("user").Select()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(len(result), 3)
}
func TestModel_Struct(t *testing.T) {
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime gtime.Time
}
user := new(User)
err := db.Table("user").Where("id=1").Struct(user)
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(user.NickName, "T111")
}
func TestModel_OrderBy(t *testing.T) {
result, err := db.Table("user").OrderBy("id DESC").Select()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(len(result), 3)
gtest.Assert(result[0]["nickname"].String(), "T3")
}
func TestModel_GroupBy(t *testing.T) {
result, err := db.Table("user").GroupBy("id").Select()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(len(result), 3)
gtest.Assert(result[0]["nickname"].String(), "T111")
}
func TestModel_Where1(t *testing.T) {
result, err := db.Table("user").Where("id IN(?)", g.Slice{1,3}).OrderBy("id ASC").All()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(len(result), 2)
gtest.Assert(result[0]["id"].Int(), 1)
gtest.Assert(result[1]["id"].Int(), 3)
}
func TestModel_Where2(t *testing.T) {
result, err := db.Table("user").Where("nickname=? AND id IN(?)", "T3", g.Slice{1,3}).OrderBy("id ASC").All()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(len(result), 1)
gtest.Assert(result[0]["id"].Int(), 3)
}
func TestModel_Delete(t *testing.T) {
result, err := db.Table("user").Delete()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 3)
}

View File

@ -0,0 +1,372 @@
package gdb_test
import (
"gitee.com/johng/gf/g"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gtest"
"testing"
)
func TestTX_Query(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if rows, err := tx.Query("SELECT ?", 1); err != nil {
gtest.Fatal(err)
} else {
rows.Close()
}
if _, err := tx.Query("ERROR"); err == nil {
gtest.Fatal("FAIL")
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_Exec(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.Exec("SELECT ?", 1); err != nil {
gtest.Fatal(err)
}
if _, err := tx.Exec("ERROR"); err == nil {
gtest.Fatal("FAIL")
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_Commit(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_Rollback(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if err := tx.Rollback(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_Prepare(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
st, err := tx.Prepare("SELECT 100")
if err != nil {
gtest.Fatal(err)
}
rows, err := st.Query()
if err != nil {
gtest.Fatal(err)
}
array, err := rows.Columns()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(array[0], "100")
if err := rows.Close(); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_Insert(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.Insert("user", g.Map {
"id" : 1,
"passport" : "t1",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T1",
"create_time" : gtime.Now().String(),
}); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
if n, err := db.Table("user").Count(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(n, 1)
}
}
func TestTX_BatchInsert(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.BatchInsert("user", g.List {
{
"id" : 2,
"passport" : "t",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T2",
"create_time" : gtime.Now().String(),
},
{
"id" : 3,
"passport" : "t3",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T3",
"create_time" : gtime.Now().String(),
},
}, 10); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
if n, err := db.Table("user").Count(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(n, 3)
}
}
func TestTX_BatchReplace(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.BatchReplace("user", g.List {
{
"id" : 2,
"passport" : "t2",
"password" : "p2",
"nickname" : "T2",
"create_time" : gtime.Now().String(),
},
{
"id" : 4,
"passport" : "t4",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T4",
"create_time" : gtime.Now().String(),
},
}, 10); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
// 数据数量
if n, err := db.Table("user").Count(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(n, 4)
}
// 检查replace后的数值
if value, err := db.Table("user").Fields("password").Where("id", 2).Value(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(value.String(), "p2")
}
}
func TestTX_BatchSave(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.BatchSave("user", g.List {
{
"id" : 4,
"passport" : "t4",
"password" : "p4",
"nickname" : "T4",
"create_time" : gtime.Now().String(),
},
}, 10); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
// 数据数量
if n, err := db.Table("user").Count(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(n, 4)
}
// 检查replace后的数值
if value, err := db.Table("user").Fields("password").Where("id", 4).Value(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(value.String(), "p4")
}
}
func TestTX_Replace(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.Replace("user", g.Map {
"id" : 1,
"passport" : "t11",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T11",
"create_time" : gtime.Now().String(),
}); err != nil {
gtest.Fatal(err)
}
if err := tx.Rollback(); err != nil {
gtest.Fatal(err)
}
if value, err := db.Table("user").Fields("nickname").Where("id", 1).Value(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(value.String(), "T1")
}
}
func TestTX_Save(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.Save("user", g.Map {
"id" : 1,
"passport" : "t11",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T11",
"create_time" : gtime.Now().String(),
}); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
if value, err := db.Table("user").Fields("nickname").Where("id", 1).Value(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(value.String(), "T11")
}
}
func TestTX_GetAll(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if result, err := tx.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(len(result), 1)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_GetOne(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if record, err := tx.GetOne("SELECT * FROM user WHERE passport=?", "t2"); err != nil {
gtest.Fatal(err)
} else {
if record == nil {
gtest.Fatal("FAIL")
}
gtest.Assert(record["nickname"].String(), "T2")
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_GetValue(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if value, err := tx.GetValue("SELECT id FROM user WHERE passport=?", "t3"); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(value.Int(), 3)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_GetCount(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if count, err := tx.GetCount("SELECT * FROM user"); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(count, 4)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_GetStruct(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime gtime.Time
}
user := new(User)
if err := tx.GetStruct(user, "SELECT * FROM user WHERE id=?", 1); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(user.NickName, "T11")
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_Delete(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.Delete("user", nil); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
if n, err := db.Table("user").Count(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(n, 0)
}
}

View File

@ -4,16 +4,15 @@
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://gitee.com/johng/gf.
// Kafka Client.
// Package gkafka provides producer and consumer client for kafka server/Kafka客户端.
package gkafka
import (
"gitee.com/johng/gf/g/os/glog"
"time"
"strings"
"gitee.com/johng/gf/third/github.com/Shopify/sarama"
"gitee.com/johng/gf/third/github.com/johng-cn/sarama-cluster"
"errors"
"strings"
"time"
)
var (
@ -177,8 +176,6 @@ func (client *Client) Receive() (*Message, error) {
case <-notifyChan:
}
}
return nil, errors.New("unknown error")
}
// Send data to kafka in synchronized way.

View File

@ -487,7 +487,6 @@ func (j *Json) convertValue(value interface{}) interface{} {
v, _ := Decode(b)
return v
}
return value
}
// 用于Set方法中对指针指向的内存地址进行赋值

View File

@ -9,18 +9,18 @@
package gins
import (
"fmt"
"gitee.com/johng/gf/g/container/gmap"
"gitee.com/johng/gf/g/database/gdb"
"gitee.com/johng/gf/g/database/gredis"
"gitee.com/johng/gf/g/os/gcfg"
"gitee.com/johng/gf/g/os/gcmd"
"gitee.com/johng/gf/g/os/genv"
"gitee.com/johng/gf/g/os/gfile"
"gitee.com/johng/gf/g/os/gfsnotify"
"gitee.com/johng/gf/g/os/glog"
"gitee.com/johng/gf/g/os/gview"
"gitee.com/johng/gf/g/os/gfile"
"gitee.com/johng/gf/g/container/gmap"
"gitee.com/johng/gf/g/util/gconv"
"gitee.com/johng/gf/g/database/gdb"
"gitee.com/johng/gf/g/os/gfsnotify"
"fmt"
"gitee.com/johng/gf/g/database/gredis"
"gitee.com/johng/gf/g/util/gregex"
)
@ -76,9 +76,7 @@ func View(name...string) *gview.View {
if path == "" {
path = genv.Get("GF_VIEWPATH")
if path == "" {
if gfile.SelfDir() != gfile.TempDir() {
path = gfile.SelfDir()
}
path = gfile.SelfDir()
}
}
view := gview.New(path)
@ -105,9 +103,7 @@ func Config(file...string) *gcfg.Config {
if path == "" {
path = genv.Get("GF_CFGPATH")
if path == "" {
if gfile.SelfDir() != gfile.TempDir() {
path = gfile.SelfDir()
}
path = gfile.SelfDir()
}
}
config := gcfg.New(path, configFile)
@ -120,7 +116,7 @@ func Config(file...string) *gcfg.Config {
}
// 数据库操作对象,使用了连接池
func Database(name...string) *gdb.Db {
func Database(name...string) gdb.DB {
config := Config()
group := gdb.DEFAULT_GROUP_NAME
if len(name) > 0 {
@ -128,65 +124,67 @@ func Database(name...string) *gdb.Db {
}
key := fmt.Sprintf("%s.%s", gFRAME_CORE_COMPONENT_NAME_DATABASE, group)
db := instances.GetOrSetFuncLock(key, func() interface{} {
m := config.GetMap("database")
if m == nil {
glog.Error(`database init failed: "database" node not found, is config file or configuration missing?`)
return nil
}
for group, v := range m {
cg := gdb.ConfigGroup{}
if list, ok := v.([]interface{}); ok {
for _, nodev := range list {
node := gdb.ConfigNode{}
nodem := nodev.(map[string]interface{})
if value, ok := nodem["host"]; ok {
node.Host = gconv.String(value)
}
if value, ok := nodem["port"]; ok {
node.Port = gconv.String(value)
}
if value, ok := nodem["user"]; ok {
node.User = gconv.String(value)
}
if value, ok := nodem["pass"]; ok {
node.Pass = gconv.String(value)
}
if value, ok := nodem["name"]; ok {
node.Name = gconv.String(value)
}
if value, ok := nodem["type"]; ok {
node.Type = gconv.String(value)
}
if value, ok := nodem["role"]; ok {
node.Role = gconv.String(value)
}
if value, ok := nodem["charset"]; ok {
node.Charset = gconv.String(value)
}
if value, ok := nodem["priority"]; ok {
node.Priority = gconv.Int(value)
}
if value, ok := nodem["linkinfo"]; ok {
node.Linkinfo = gconv.String(value)
}
if value, ok := nodem["max-idle"]; ok {
node.MaxIdleConnCount = gconv.Int(value)
}
if value, ok := nodem["max-open"]; ok {
node.MaxOpenConnCount = gconv.Int(value)
}
if value, ok := nodem["max-lifetime"]; ok {
node.MaxConnLifetime = gconv.Int(value)
}
cg = append(cg, node)
}
if gdb.GetConfig(group) == nil {
m := config.GetMap("database")
if m == nil {
glog.Error(`database init failed: "database" node not found, is config file or configuration missing?`)
return nil
}
gdb.AddConfigGroup(group, cg)
for group, v := range m {
cg := gdb.ConfigGroup{}
if list, ok := v.([]interface{}); ok {
for _, nodev := range list {
node := gdb.ConfigNode{}
nodem := nodev.(map[string]interface{})
if value, ok := nodem["host"]; ok {
node.Host = gconv.String(value)
}
if value, ok := nodem["port"]; ok {
node.Port = gconv.String(value)
}
if value, ok := nodem["user"]; ok {
node.User = gconv.String(value)
}
if value, ok := nodem["pass"]; ok {
node.Pass = gconv.String(value)
}
if value, ok := nodem["name"]; ok {
node.Name = gconv.String(value)
}
if value, ok := nodem["type"]; ok {
node.Type = gconv.String(value)
}
if value, ok := nodem["role"]; ok {
node.Role = gconv.String(value)
}
if value, ok := nodem["charset"]; ok {
node.Charset = gconv.String(value)
}
if value, ok := nodem["priority"]; ok {
node.Priority = gconv.Int(value)
}
if value, ok := nodem["linkinfo"]; ok {
node.Linkinfo = gconv.String(value)
}
if value, ok := nodem["max-idle"]; ok {
node.MaxIdleConnCount = gconv.Int(value)
}
if value, ok := nodem["max-open"]; ok {
node.MaxOpenConnCount = gconv.Int(value)
}
if value, ok := nodem["max-lifetime"]; ok {
node.MaxConnLifetime = gconv.Int(value)
}
cg = append(cg, node)
}
}
gdb.AddConfigGroup(group, cg)
}
// 使用gfsnotify进行文件监控当配置文件有任何变化时清空数据库配置缓存
gfsnotify.Add(config.GetFilePath(), func(event *gfsnotify.Event) {
instances.Remove(key)
})
}
// 使用gfsnotify进行文件监控当配置文件有任何变化时清空数据库配置缓存
gfsnotify.Add(config.GetFilePath(), func(event *gfsnotify.Event) {
instances.Remove(key)
})
if db, err := gdb.New(name...); err == nil {
return db
} else {
@ -195,7 +193,7 @@ func Database(name...string) *gdb.Db {
return nil
})
if db != nil {
return db.(*gdb.Db)
return db.(gdb.DB)
}
return nil
}

View File

@ -32,7 +32,7 @@ func (c *Controller) Init(r *ghttp.Request) {
}
// 控制器结束请求接口方法
func (c *Controller) Shut(r *ghttp.Request) {
func (c *Controller) Shut() {
}

23
g/g.go
View File

@ -10,14 +10,27 @@ package g
import "gitee.com/johng/gf/g/container/gvar"
// 框架动态变量可以用该类型替代interface{}类型
type Var = gvar.Var
type Var = gvar.Var
// 常用map数据结构(使用别名)
type Map = map[string]interface{}
type Map = map[string]interface{}
type MapStrStr = map[string]string
type MapStrInt = map[string]int
type MapIntStr = map[int]string
type MapIntInt = map[int]int
// 常用list数据结构(使用别名)
type List = []Map
type List = []Map
type ListStrStr = []map[string]string
type ListStrInt = []map[string]int
type ListIntStr = []map[int]string
type ListIntInt = []map[int]int
// 常用slice数据结构(使用别名)
type Slice = []interface{}
type Array = Slice
type Slice = []interface{}
type SliceStr = []string
type SliceInt = []int
type Array = Slice
type ArrayStr = SliceStr
type ArrayInt = SliceInt

View File

@ -23,12 +23,12 @@ func Server(name...interface{}) *ghttp.Server {
}
// TCPServer单例对象
func TcpServer(name...interface{}) *gtcp.Server {
func TCPServer(name...interface{}) *gtcp.Server {
return gtcp.GetServer(name...)
}
// UDPServer单例对象
func UdpServer(name...interface{}) *gudp.Server {
func UDPServer(name...interface{}) *gudp.Server {
return gudp.GetServer(name...)
}
@ -44,12 +44,12 @@ func Config(file...string) *gcfg.Config {
}
// 数据库操作对象,使用了连接池
func Database(name...string) *gdb.Db {
func Database(name...string) gdb.DB {
return gins.Database(name...)
}
// (别名)Database
func DB(name...string) *gdb.Db {
func DB(name...string) gdb.DB {
return gins.Database(name...)
}

View File

@ -10,5 +10,5 @@ package ghttp
// 控制器接口
type Controller interface {
Init(*Request)
Shut(*Request)
Shut()
}

View File

@ -29,7 +29,7 @@ func (r *Request) GetRequest(key string, def ... []string) []string {
func (r *Request) GetRequestVar(key string, def ... interface{}) gvar.VarRead {
value := r.GetRequest(key)
if value != nil {
return gvar.New(value, false)
return gvar.New(value[0], false)
}
if len(def) > 0 {
return gvar.New(def[0], false)

View File

@ -15,6 +15,7 @@ import (
"gitee.com/johng/gf/g/container/gtype"
"gitee.com/johng/gf/g/os/gcache"
"gitee.com/johng/gf/g/os/genv"
"gitee.com/johng/gf/g/os/gfile"
"gitee.com/johng/gf/g/os/glog"
"gitee.com/johng/gf/g/os/gproc"
"gitee.com/johng/gf/g/os/gtime"
@ -90,13 +91,13 @@ type (
}
// pattern与回调函数的绑定map
handlerMap map[string]*handlerItem
handlerMap = map[string]*handlerItem
// HTTP注册函数
HandlerFunc func(r *Request)
HandlerFunc = func(r *Request)
// 文件描述符map
listenerFdMap map[string]string
listenerFdMap = map[string]string
)
const (
@ -254,6 +255,10 @@ func (s *Server) Start() error {
}
})
}
// 是否处于开发环境
if gfile.MainPkgPath() != "" {
glog.Debug("GF notices that you're in develop environment, so error logs are auto enabled to stdout.")
}
// 打印展示路由表
s.DumpRoutesMap()

View File

@ -160,35 +160,46 @@ func (s *Server) searchStaticFile(uri string) (filePath string, isDir bool) {
// 初始化控制器
func (s *Server) callServeHandler(h *handlerItem, r *Request) {
if h.faddr == nil {
c := reflect.New(h.ctype)
s.niceCall(func() {
c.MethodByName("Init").Call([]reflect.Value{reflect.ValueOf(r)})
})
s.niceCall(func() {
c.MethodByName(h.fname).Call(nil)
})
s.niceCall(func() {
c.MethodByName("Shut").Call(nil)
})
} else {
if h.finit != nil {
s.niceCall(func() {
h.finit(r)
})
}
s.niceCall(func() {
h.faddr(r)
})
if h.fshut != nil {
s.niceCall(func() {
h.fshut(r)
})
}
}
}
// 友好地调用方法
func (s *Server) niceCall(f func()) {
defer func() {
if e := recover(); e != nil && e != gEXCEPTION_EXIT {
panic(e)
}
}()
if h.faddr == nil {
// 新建一个控制器对象处理请求
c := reflect.New(h.ctype)
c.MethodByName("Init").Call([]reflect.Value{reflect.ValueOf(r)})
if !r.IsExited() {
c.MethodByName(h.fname).Call(nil)
c.MethodByName("Shut").Call([]reflect.Value{reflect.ValueOf(r)})
}
} else {
// 是否有初始化及完成回调方法
if h.finit != nil {
h.finit(r)
}
if !r.IsExited() {
h.faddr(r)
if h.fshut != nil {
h.fshut(r)
}
}
}
f()
}
// http server静态文件处理path可以为相对路径也可以为绝对路径
func (s *Server)serveFile(r *Request, path string) {
func (s *Server) serveFile(r *Request, path string) {
f, err := os.Open(path)
if err != nil {
r.Response.WriteStatus(http.StatusForbidden)

View File

@ -9,6 +9,7 @@ package ghttp
import (
"fmt"
"gitee.com/johng/gf/g/os/gfile"
"net/http"
)
@ -36,7 +37,7 @@ func (s *Server) handleErrorLog(error interface{}, r *Request) {
r.Response.WriteStatus(http.StatusInternalServerError)
// 错误输出默认是开启的
if !s.IsErrorLogEnabled() {
if !s.IsErrorLogEnabled() && gfile.MainPkgPath() == "" {
return
}
@ -56,5 +57,9 @@ func (s *Server) handleErrorLog(error interface{}, r *Request) {
s.logger.Cat("error").Backtrace(true, 2).StdPrint(true).Error(content)
} else {
s.logger.Cat("error").Backtrace(true, 2).Error(content)
// 开发环境下(MainPkgPath)自动输出错误信息到标准输出
if gfile.MainPkgPath() != "" {
s.logger.Cat("error").Backtrace(true, 2).StdPrint(true).Error(content)
}
}
}

View File

@ -20,24 +20,28 @@ import (
// 解析pattern
func (s *Server)parsePattern(pattern string) (domain, method, uri string, err error) {
uri = pattern
func (s *Server)parsePattern(pattern string) (domain, method, path string, err error) {
path = strings.TrimSpace(pattern)
domain = gDEFAULT_DOMAIN
method = gDEFAULT_METHOD
if array, err := gregex.MatchString(`([a-zA-Z]+):(.+)`, pattern); len(array) > 1 && err == nil {
method = array[1]
uri = array[2]
path = strings.TrimSpace(array[2])
if v := strings.TrimSpace(array[1]); v != "" {
method = v
}
}
if array, err := gregex.MatchString(`(.+)@([\w\.\-]+)`, uri); len(array) > 1 && err == nil {
uri = array[1]
domain = array[2]
if array, err := gregex.MatchString(`(.+)@([\w\.\-]+)`, path); len(array) > 1 && err == nil {
path = strings.TrimSpace(array[1])
if v := strings.TrimSpace(array[2]); v != "" {
domain = v
}
}
if uri == "" {
if path == "" {
err = errors.New("invalid pattern")
}
// 去掉末尾的"/"符号,与路由匹配时处理一致
if uri != "/" {
uri = strings.TrimRight(uri, "/")
if path != "/" {
path = strings.TrimRight(path, "/")
}
return
}
@ -293,7 +297,6 @@ func (s *Server) patternToRegRule(rule string) (regrule string, names []string)
regrule += `/[^/]+`
break
}
fallthrough
case '*':
if len(v) > 1 {
regrule += `/{0,1}(.*)`
@ -303,7 +306,6 @@ func (s *Server) patternToRegRule(rule string) (regrule string, names []string)
regrule += `/{0,1}.*`
break
}
fallthrough
default:
// 特殊字符替换
v = gstr.ReplaceByMap(v, map[string]string{

View File

@ -0,0 +1,214 @@
// Copyright 2018 gf Author(https://gitee.com/johng/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://gitee.com/johng/gf.
// 分组路由管理.
package ghttp
import (
"gitee.com/johng/gf/g/os/glog"
"gitee.com/johng/gf/g/util/gconv"
"reflect"
"strings"
)
// 分组路由对象
type RouterGroup struct {
server *Server // Server
domain *Domain // Domain
prefix string // URI前缀
}
// 分组路由批量绑定项
type GroupItem = []interface{}
// 获取分组路由对象
func (s *Server) Group(prefix...string) *RouterGroup {
if len(prefix) > 0 {
return &RouterGroup{
server : s,
prefix : prefix[0],
}
}
return &RouterGroup{}
}
// 获取分组路由对象
func (d *Domain) Group(prefix...string) *RouterGroup {
if len(prefix) > 0 {
return &RouterGroup{
domain : d,
prefix : prefix[0],
}
}
return &RouterGroup{}
}
// 执行分组路由批量绑定
func (g *RouterGroup) Bind(group string, items []GroupItem) {
for _, item := range items {
if len(item) < 3 {
glog.Fatalfln("invalid router item: %s", item)
}
if strings.EqualFold(gconv.String(item[0]), "REST") {
g.bind("REST", gconv.String(item[0]) + ":" + gconv.String(item[1]), item[2])
} else {
if len(item) > 3 {
g.bind("HANDLER", gconv.String(item[0]) + ":" + gconv.String(item[1]), item[2], item[3])
} else {
g.bind("HANDLER", gconv.String(item[0]) + ":" + gconv.String(item[1]), item[2])
}
}
}
}
// 绑定所有的HTTP Method请求方式
func (g *RouterGroup) ALL(pattern string, object interface{}, params...interface{}) {
g.bind("HANDLER", gDEFAULT_METHOD + ":" + pattern, object, params...)
}
func (g *RouterGroup) GET(pattern string, object interface{}, params...interface{}) {
g.bind("HANDLER", "GET:" + pattern, object, params...)
}
func (g *RouterGroup) PUT(pattern string, object interface{}, params...interface{}) {
g.bind("HANDLER", "PUT:" + pattern, object, params...)
}
func (g *RouterGroup) POST(pattern string, object interface{}, params...interface{}) {
g.bind("HANDLER", "POST:" + pattern, object, params...)
}
func (g *RouterGroup) DELETE(pattern string, object interface{}, params...interface{}) {
g.bind("HANDLER", "DELETE:" + pattern, object, params...)
}
func (g *RouterGroup) PATCH(pattern string, object interface{}, params...interface{}) {
g.bind("HANDLER", "PATCH:" + pattern, object, params...)
}
func (g *RouterGroup) HEAD(pattern string, object interface{}, params...interface{}) {
g.bind("HANDLER", "HEAD:" + pattern, object, params...)
}
func (g *RouterGroup) CONNECT(pattern string, object interface{}, params...interface{}) {
g.bind("HANDLER", "CONNECT:" + pattern, object, params...)
}
func (g *RouterGroup) OPTIONS(pattern string, object interface{}, params...interface{}) {
g.bind("HANDLER", "OPTIONS:" + pattern, object, params...)
}
func (g *RouterGroup) TRACE(pattern string, object interface{}, params...interface{}) {
g.bind("HANDLER", "TRACE:" + pattern, object, params...)
}
// REST路由注册
func (g *RouterGroup) REST(pattern string, object interface{}) {
g.bind("REST", pattern, object)
}
// 执行路由绑定
func (g *RouterGroup) bind(bindType string, pattern string, object interface{}, params...interface{}) {
// 注册路由处理
if len(g.prefix) > 0 {
domain, method, path, err := g.server.parsePattern(pattern)
if err != nil {
glog.Fatalfln("invalid pattern: %s", pattern)
}
if bindType == "HANDLER" {
pattern = g.server.serveHandlerKey(method, g.prefix + "/" + strings.TrimLeft(path, "/"), domain)
} else {
pattern = g.prefix + "/" + strings.TrimLeft(path, "/")
}
}
methods := gconv.Strings(params)
// 判断是否事件回调注册
if _, ok := object.(HandlerFunc); ok && len(methods) > 0 {
bindType = "HOOK"
}
switch bindType {
case "HANDLER":
if h, ok := object.(HandlerFunc); ok {
if g.server != nil {
g.server.BindHandler(pattern, h)
} else {
g.domain.BindHandler(pattern, h)
}
} else if g.isController(object) {
if len(methods) > 0 {
if g.server != nil {
g.server.BindControllerMethod(pattern, object.(Controller), methods[0])
} else {
g.domain.BindControllerMethod(pattern, object.(Controller), methods[0])
}
} else {
if g.server != nil {
g.server.BindController(pattern, object.(Controller))
} else {
g.domain.BindController(pattern, object.(Controller))
}
}
} else {
if len(methods) > 0 {
if g.server != nil {
g.server.BindObjectMethod(pattern, object, methods[0])
} else {
g.domain.BindObjectMethod(pattern, object, methods[0])
}
} else {
if g.server != nil {
g.server.BindObject(pattern, object)
} else {
g.domain.BindObject(pattern, object)
}
}
}
case "REST":
if g.isController(object) {
if g.server != nil {
g.server.BindControllerRest(pattern, object.(Controller))
} else {
g.domain.BindControllerRest(pattern, object.(Controller))
}
} else {
if g.server != nil {
g.server.BindObjectRest(pattern, object)
} else {
g.domain.BindObjectRest(pattern, object)
}
}
case "HOOK":
if h, ok := object.(HandlerFunc); ok {
if g.server != nil {
g.server.BindHookHandler(pattern, methods[0], h)
} else {
g.domain.BindHookHandler(pattern, methods[0], h)
}
} else {
glog.Fatalfln("invalid hook handler for pattern:%s", pattern)
}
}
}
// 判断给定对象是否控制器对象:
// 控制器必须包含以下公开的属性对象Request/Response/Server/Cookie/Session/View.
func (g *RouterGroup) isController(value interface{}) bool {
// 首先判断是否满足控制器接口定义
if _, ok := value.(Controller); !ok {
return false
}
// 其次检查控制器的必需属性
v := reflect.ValueOf(value)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.FieldByName("Request").IsValid() && v.FieldByName("Response").IsValid() &&
v.FieldByName("Server").IsValid() && v.FieldByName("Cookie").IsValid() &&
v.FieldByName("Session").IsValid() && v.FieldByName("View").IsValid() {
return true
}
return false
}

View File

@ -25,7 +25,6 @@ func (s *Server)BindHookHandler(pattern string, hook string, handler HandlerFunc
fname : "",
faddr : handler,
}, hook)
return nil
}
// 通过map批量绑定回调函数

View File

@ -82,5 +82,4 @@ func (s *Server) Run() error {
go s.handler(NewConnByNetConn(conn))
}
}
return nil
}

View File

@ -85,7 +85,6 @@ func (c *Conn) Send(data []byte, retry...Retry) error {
}
}
}
return nil
}
// 接收数据

View File

@ -77,5 +77,4 @@ func (s *Server) Run() error {
for {
s.handler(NewConnByNetConn(conn))
}
return nil
}

View File

@ -63,8 +63,8 @@ func newMemCache(lruCap...int) *memCache {
closed : gtype.NewBool(),
}
if len(lruCap) > 0 {
c.lru = newMemCacheLru(c)
c.cap = lruCap[0]
c.lru = newMemCacheLru(c)
}
return c
}

View File

@ -370,7 +370,7 @@ func homeWindows() (string, error) {
return home, nil
}
// 获取入口函数文件所在目录(main包文件目录)
// 获取入口函数文件所在目录(main包文件目录),
// **仅对源码开发环境有效(即仅对生成该可执行文件的系统下有效)**
func MainPkgPath() string {
path := mainPkgPath.Val()
@ -401,6 +401,7 @@ func MainPkgPath() string {
if p == f {
break
}
// 会自动扫描源码寻找main包
if paths, err := ScanDir(p, "*.go"); err == nil && len(paths) > 0 {
for _, path := range paths {
if gregex.IsMatchString(`package\s+main`, GetContents(path)) {

View File

@ -105,8 +105,6 @@ func GetNextCharOffsetByPath(path string, char byte, start int64) int64 {
if f, err := OpenWithFlagPerm(path, os.O_RDONLY, gDEFAULT_PERM); err == nil {
defer f.Close()
return GetNextCharOffset(f, char, start)
} else {
panic(err)
}
return -1
}
@ -124,8 +122,6 @@ func GetBinContentsTilCharByPath(path string, char byte, start int64) ([]byte, i
if f, err := OpenWithFlagPerm(path, os.O_RDONLY, gDEFAULT_PERM); err == nil {
defer f.Close()
return GetBinContentsTilChar(f, char, start)
} else {
panic(err)
}
return nil, -1
}
@ -144,8 +140,6 @@ func GetBinContentsByTwoOffsetsByPath(path string, start int64, end int64) []byt
if f, err := OpenWithFlagPerm(path, os.O_RDONLY, gDEFAULT_PERM); err == nil {
defer f.Close()
return GetBinContentsByTwoOffsets(f, start, end)
} else {
panic(err)
}
return nil
}

View File

@ -7,7 +7,6 @@
package gfsnotify
import (
"fmt"
"gitee.com/johng/gf/g/container/glist"
)
@ -32,8 +31,8 @@ func (w *Watcher) startWatchLoop() {
return struct {}{}
}, REPEAT_EVENT_FILTER_INTERVAL)
case err := <- w.watcher.Errors:
fmt.Errorf("error: %s\n" + err.Error());
case <- w.watcher.Errors:
//fmt.Fprintf(os.Stderr, "[gfsnotify] error: %s\n", err.Error())
}
}
}()

View File

@ -31,7 +31,7 @@ type Logger struct {
file *gtype.String // 日志文件名称格式
level *gtype.Int // 日志输出等级
btSkip *gtype.Int // 错误产生时的backtrace回调信息skip条数
btEnabled *gtype.Bool // 是否当打印错误时同时开启backtrace打印
btStatus *gtype.Int // 是否当打印错误时同时开启backtrace打印(默认-1表示默认打印逻辑 - 错误才打印)
printHeader *gtype.Bool // 是否不打印前缀信息(时间,级别等)
alsoStdPrint *gtype.Bool // 控制台打印开关,当输出到文件/自定义输出时也同时打印到终端
}
@ -65,7 +65,7 @@ func New() *Logger {
file : gtype.NewString(gDEFAULT_FILE_FORMAT),
level : gtype.NewInt(defaultLevel.Val()),
btSkip : gtype.NewInt(),
btEnabled : gtype.NewBool(true),
btStatus : gtype.NewInt(-1),
printHeader : gtype.NewBool(true),
alsoStdPrint : gtype.NewBool(true),
}
@ -80,7 +80,7 @@ func (l *Logger) Clone() *Logger {
file : l.file.Clone(),
level : l.level.Clone(),
btSkip : l.btSkip.Clone(),
btEnabled : l.btEnabled.Clone(),
btStatus : l.btStatus.Clone(),
printHeader : l.printHeader.Clone(),
alsoStdPrint : l.alsoStdPrint.Clone(),
}
@ -106,7 +106,12 @@ func (l *Logger) SetDebug(debug bool) {
}
func (l *Logger) SetBacktrace(enabled bool) {
l.btEnabled.Set(enabled)
if enabled {
l.btStatus.Set(1)
} else {
l.btStatus.Set(0)
}
}
// 设置BacktraceSkip
@ -136,6 +141,13 @@ func (l *Logger) getFilePointer() *gfpool.File {
file, _ := gregex.ReplaceStringFunc(`{.+?}`, l.file.Val(), func(s string) string {
return gtime.Now().Format(strings.Trim(s, "{}"))
})
// 如果日志目录不存在则创建目录路径
if !gfile.Exists(path) {
if err := gfile.Mkdir(path); err != nil {
fmt.Fprintln(os.Stderr, fmt.Sprintf(`[glog] mkdir "%s" failed: %s`, path, err.Error()))
return nil
}
}
fpath := path + gfile.Separator + file
if fp, err := gfpool.Open(fpath, gDEFAULT_FILE_POOL_FLAGS, gDEFAULT_FPOOL_PERM, gDEFAULT_FPOOL_EXPIRE); err == nil {
return fp
@ -151,7 +163,7 @@ func (l *Logger) SetPath(path string) error {
// 如果目录不存在,则递归创建
if !gfile.Exists(path) {
if err := gfile.Mkdir(path); err != nil {
fmt.Fprintln(os.Stderr, fmt.Sprintf(`glog mkdir "%s" failed: %s`, path, err.Error()))
fmt.Fprintln(os.Stderr, fmt.Sprintf(`[glog] mkdir "%s" failed: %s`, path, err.Error()))
return err
}
}
@ -220,24 +232,35 @@ func (l *Logger) stdPrint(s string) {
// 核心打印数据方法(标准错误)
func (l *Logger) errPrint(s string) {
// 记录调用回溯信息
if l.btEnabled.Val() {
tracestr := l.GetBacktrace()
if tracestr != "" {
backtrace := "Backtrace:" + ln + tracestr
if s[len(s) - 1] == byte('\n') {
s = s + backtrace + ln
} else {
s = s + ln + backtrace + ln
}
}
status := l.btStatus.Val()
if status == -1 || status == 1 {
s = l.appendBacktrace(s)
}
// 防止串日志情况这里不使用stderr而是使用stdout
l.print(os.Stdout, s)
}
// 输出内容中添加回溯信息
func (l *Logger) appendBacktrace(s string, skip...int) string {
trace := l.GetBacktrace(skip...)
if trace != "" {
backtrace := "Backtrace:" + ln + trace
if len(s) > 0 {
if s[len(s)-1] == byte('\n') {
s = s + backtrace + ln
} else {
s = s + ln + backtrace + ln
}
} else {
s = backtrace
}
}
return s
}
// 直接打印回溯信息参数skip表示调用端往上多少级开始回溯
func (l *Logger) PrintBacktrace(skip...int) {
l.Println(l.GetBacktrace(skip...))
l.Println(l.appendBacktrace("", skip...))
}
// 获取文件调用回溯字符串参数skip表示调用端往上多少级开始回溯

View File

@ -121,7 +121,6 @@ func getShell() string {
}
return path
}
return ""
}
// 获取当前系统默认shell执行指令的option参数
@ -132,7 +131,6 @@ func getShellOption() string {
default:
return "-c"
}
return ""
}
// 从环境变量PATH中搜索可执行文件

View File

@ -62,7 +62,10 @@ func startTcpListening() {
// TCP数据通信处理回调函数
func tcpServiceHandler(conn *gtcp.Conn) {
var retry = gtcp.Retry{3, 10}
retry := gtcp.Retry {
Count : 3,
Interval: 10,
}
for {
var result []byte
buffer, err := conn.Recv(-1, retry)
@ -97,32 +100,32 @@ func tcpServiceHandler(conn *gtcp.Conn) {
}
// 数据解包,防止黏包
// 数据格式:总长度(24bit)|发送进程PID(16bit)|接收进程PID(16bit)|分组长度(8bit)|分组名称(变长)|校验(32bit)|参数(变长)
// 数据格式:总长度(24bit)|发送进程PID(24bit)|接收进程PID(24bit)|分组长度(8bit)|分组名称(变长)|校验(32bit)|参数(变长)
func bufferToMsgs(buffer []byte) []*Msg {
s := 0
msgs := make([]*Msg, 0)
for s < len(buffer) {
// 长度解析及校验
length := gbinary.DecodeToInt(buffer[s : s + 3])
if length < 12 || length > len(buffer) {
if length < 14 || length > len(buffer) {
s++
continue
}
// 分组信息解析
groupLen := gbinary.DecodeToInt(buffer[s + 7 : s + 8])
groupLen := gbinary.DecodeToInt(buffer[s + 9 : s + 10])
// checksum校验(仅对参数做校验,提高校验效率)
checksum1 := gbinary.DecodeToUint32(buffer[s + 8 + groupLen : s + 8 + groupLen + 4])
checksum2 := gtcp.Checksum(buffer[s + 8 + groupLen + 4 : s + length])
checksum1 := gbinary.DecodeToUint32(buffer[s + 10 + groupLen : s + 10 + groupLen + 4])
checksum2 := gtcp.Checksum(buffer[s + 10 + groupLen + 4 : s + length])
if checksum1 != checksum2 {
s++
continue
}
// 接收进程PID校验
if Pid() == gbinary.DecodeToInt(buffer[s + 5 : s + 7]) {
if Pid() == gbinary.DecodeToInt(buffer[s + 6 : s + 9]) {
msgs = append(msgs, &Msg {
Pid : gbinary.DecodeToInt(buffer[s + 3 : s + 5]),
Data : buffer[s + 8 + groupLen + 4 : s + length],
Group : string(buffer[s + 8 : s + 8 + groupLen]),
Pid : gbinary.DecodeToInt(buffer[s + 3 : s + 6]),
Data : buffer[s + 10 + groupLen + 4 : s + length],
Group : string(buffer[s + 10 : s + 10 + groupLen]),
})
}
s += length

View File

@ -27,16 +27,16 @@ const (
)
// 向指定gproc进程发送数据
// 数据格式:总长度(24bit)|发送进程PID(16bit)|接收进程PID(16bit)|分组长度(8bit)|分组名称(变长)|校验(32bit)|参数(变长)
// 数据格式:总长度(24bit)|发送进程PID(24bit)|接收进程PID(24bit)|分组长度(8bit)|分组名称(变长)|校验(32bit)|参数(变长)
func Send(pid int, data []byte, group...string) error {
groupName := gPROC_COMM_DEAFULT_GRUOP_NAME
if len(group) > 0 {
groupName = group[0]
}
buffer := make([]byte, 0)
buffer = append(buffer, gbinary.EncodeByLength(3, len(groupName) + len(data) + 12)...)
buffer = append(buffer, gbinary.EncodeByLength(2, Pid())...)
buffer = append(buffer, gbinary.EncodeByLength(2, pid)...)
buffer = append(buffer, gbinary.EncodeByLength(3, len(groupName) + len(data) + 14)...)
buffer = append(buffer, gbinary.EncodeByLength(3, Pid())...)
buffer = append(buffer, gbinary.EncodeByLength(3, pid)...)
buffer = append(buffer, gbinary.EncodeByLength(1, len(groupName))...)
buffer = append(buffer, []byte(groupName)...)
buffer = append(buffer, gbinary.EncodeUint32(gtcp.Checksum(data))...)

View File

@ -13,18 +13,18 @@ import (
"gitee.com/johng/gf/g/os/grpool"
)
func increment1() {
func increment() {
for i := 0; i < 1000000; i++ {}
}
func BenchmarkGrpool_1(b *testing.B) {
for i := 0; i < b.N; i++ {
grpool.Add(increment1)
grpool.Add(increment)
}
}
func BenchmarkGoroutine_1(b *testing.B) {
for i := 0; i < b.N; i++ {
go increment1()
go increment()
}
}

View File

@ -15,10 +15,6 @@ import (
var n = 500000
func increment2() {
for i := 0; i < 1000000; i++ {}
}
func BenchmarkGrpool2(b *testing.B) {
b.N = n
for i := 0; i < b.N; i++ {

View File

@ -1,35 +0,0 @@
// Copyright 2017 gf Author(https://gitee.com/johng/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://gitee.com/johng/gf.
package grpool_test
import (
"testing"
"runtime"
"fmt"
)
func increment() {
for i := 0; i < 100000; i++ {}
}
//func Test_GrpoolMemUsage(t *testing.T) {
// for i := 0; i < n; i++ {
// grpool.Add(increment)
// }
// mem := runtime.MemStats{}
// runtime.ReadMemStats(&mem)
// fmt.Println("mem usage:", mem.TotalAlloc/1024)
//}
func Test_GroroutineMemUsage(t *testing.T) {
for i := 0; i < n; i++ {
go increment()
}
mem := runtime.MemStats{}
runtime.ReadMemStats(&mem)
fmt.Println("mem usage:", mem.TotalAlloc/1024)
}

View File

@ -123,7 +123,7 @@ func (t *Time) ToZone(zone string) *Time {
t.Time = t.Time.In(l)
return t
} else {
panic(err)
//panic(err)
return nil
}
}

View File

@ -49,10 +49,12 @@ var viewObj *View
// 初始化默认的视图对象
func checkAndInitDefaultView() {
if viewObj == nil {
if gfile.SelfDir() != gfile.TempDir() {
// gfile.MainPkgPath() 用以判断是否开发环境
mainPkgPath := gfile.MainPkgPath()
if gfile.MainPkgPath() == "" {
viewObj = New(gfile.SelfDir())
} else {
viewObj = New()
viewObj = New(mainPkgPath)
}
}
}

View File

@ -99,8 +99,11 @@ func Map(i interface{}, noTagCheck...bool) map[string]interface{} {
rt := rv.Type()
name := ""
for i := 0; i < rv.NumField(); i++ {
if name = rt.Field(i).Tag.Get("json"); name == "" {
name = rt.Field(i).Name
// 检查json tag
if len(noTagCheck) == 0 || !noTagCheck[0] {
if name = rt.Field(i).Tag.Get("json"); name == "" {
name = rt.Field(i).Name
}
}
m[name] = rv.Field(i).Interface()
}

View File

@ -7,7 +7,6 @@
package gconv
import (
"fmt"
"reflect"
)
@ -77,12 +76,11 @@ func Ints(i interface{}) []int {
for _, v := range i.([]interface{}) {
array = append(array, Int(v))
}
default:
return []int{Int(i)}
}
if len(array) > 0 {
return array
}
return array
}
return []int{Int(i)}
}
// 任意类型转换为[]string类型
@ -95,71 +93,70 @@ func Strings(i interface{}) []string {
} else {
array := make([]string, 0)
switch i.(type) {
case []int:
for _, v := range i.([]int) {
array = append(array, String(v))
}
case []int8:
for _, v := range i.([]int8) {
array = append(array, String(v))
}
case []int16:
for _, v := range i.([]int16) {
array = append(array, String(v))
}
case []int32:
for _, v := range i.([]int32) {
array = append(array, String(v))
}
case []int64:
for _, v := range i.([]int64) {
array = append(array, String(v))
}
case []uint:
for _, v := range i.([]uint) {
array = append(array, String(v))
}
case []uint8:
for _, v := range i.([]uint8) {
array = append(array, String(v))
}
case []uint16:
for _, v := range i.([]uint16) {
array = append(array, String(v))
}
case []uint32:
for _, v := range i.([]uint32) {
array = append(array, String(v))
}
case []uint64:
for _, v := range i.([]uint64) {
array = append(array, String(v))
}
case []bool:
for _, v := range i.([]bool) {
array = append(array, String(v))
}
case []float32:
for _, v := range i.([]float32) {
array = append(array, String(v))
}
case []float64:
for _, v := range i.([]float64) {
array = append(array, String(v))
}
case []interface{}:
for _, v := range i.([]interface{}) {
array = append(array, String(v))
}
}
if len(array) > 0 {
return array
case []int:
for _, v := range i.([]int) {
array = append(array, String(v))
}
case []int8:
for _, v := range i.([]int8) {
array = append(array, String(v))
}
case []int16:
for _, v := range i.([]int16) {
array = append(array, String(v))
}
case []int32:
for _, v := range i.([]int32) {
array = append(array, String(v))
}
case []int64:
for _, v := range i.([]int64) {
array = append(array, String(v))
}
case []uint:
for _, v := range i.([]uint) {
array = append(array, String(v))
}
case []uint8:
for _, v := range i.([]uint8) {
array = append(array, String(v))
}
case []uint16:
for _, v := range i.([]uint16) {
array = append(array, String(v))
}
case []uint32:
for _, v := range i.([]uint32) {
array = append(array, String(v))
}
case []uint64:
for _, v := range i.([]uint64) {
array = append(array, String(v))
}
case []bool:
for _, v := range i.([]bool) {
array = append(array, String(v))
}
case []float32:
for _, v := range i.([]float32) {
array = append(array, String(v))
}
case []float64:
for _, v := range i.([]float64) {
array = append(array, String(v))
}
case []interface{}:
for _, v := range i.([]interface{}) {
array = append(array, String(v))
}
default:
return []string{String(i)}
}
return array
}
return []string{fmt.Sprintf("%v", i)}
}
// 任意类型转换为[]float64类型
// 类型转换为[]float64类型
func Floats(i interface{}) []float64 {
if i == nil {
return nil
@ -225,12 +222,11 @@ func Floats(i interface{}) []float64 {
for _, v := range i.([]interface{}) {
array = append(array, Float64(v))
}
default:
return []float64{Float64(i)}
}
if len(array) > 0 {
return array
}
return array
}
return []float64{Float64(i)}
}
// 任意类型转换为[]interface{}类型
@ -318,11 +314,10 @@ func Interfaces(i interface{}) []interface{} {
for i := 0; i < rv.NumField(); i++ {
array = append(array, rv.Field(i).Interface())
}
default:
return []interface{}{i}
}
}
if len(array) > 0 {
return array
}
return array
}
return []interface{}{i}
}

View File

@ -157,12 +157,10 @@ func bindVarToStruct(elem reflect.Value, name string, value interface{}) (err er
structFieldValue := elem.FieldByName(name)
// 键名与对象属性匹配检测map中如果有struct不存在的属性那么不做处理直接return
if !structFieldValue.IsValid() {
//return errors.New(fmt.Sprintf(`invalid struct attribute of name "%s"`, name))
return nil
}
// CanSet的属性必须为公开属性(首字母大写)
if !structFieldValue.CanSet() {
//return errors.New(fmt.Sprintf(`struct attribute of name "%s" cannot be set`, name))
return nil
}
// 必须将value转换为struct属性的数据类型这里必须用到gconv包
@ -181,12 +179,10 @@ func bindVarToStructByIndex(elem reflect.Value, index int, value interface{}) (e
structFieldValue := elem.FieldByIndex([]int{index})
// 键名与对象属性匹配检测
if !structFieldValue.IsValid() {
//return errors.New(fmt.Sprintf("invalid struct attribute at index %d", index))
return nil
}
// CanSet的属性必须为公开属性(首字母大写)
if !structFieldValue.CanSet() {
//return errors.New(fmt.Sprintf("struct attribute cannot be set at index %d", index))
return nil
}
// 必须将value转换为struct属性的数据类型这里必须用到gconv包

View File

@ -14,21 +14,26 @@ import (
// 将变量i转换为time.Time类型
func Time(i interface{}, format...string) time.Time {
s := String(i)
// 优先使用用户输入日期格式进行转换
if len(format) > 0 {
t, _ := gtime.StrToTimeFormat(s, format[0])
return t.Time
}
if gstr.IsNumeric(s) {
return gtime.NewFromTimeStamp(Int64(s)).Time
} else {
t, _ := gtime.StrToTime(s)
return t.Time
}
return GTime(i, format...).Time
}
// 将变量i转换为time.Time类型
func TimeDuration(i interface{}) time.Duration {
return time.Duration(Int64(i))
}
// 将变量i转换为time.Time类型
func GTime(i interface{}, format...string) *gtime.Time {
s := String(i)
// 优先使用用户输入日期格式进行转换
if len(format) > 0 {
t, _ := gtime.StrToTimeFormat(s, format[0])
return t
}
if gstr.IsNumeric(s) {
return gtime.NewFromTimeStamp(Int64(s))
} else {
t, _ := gtime.StrToTime(s)
return t
}
}

View File

@ -9,7 +9,6 @@ package grand
import (
"crypto/rand"
"encoding/binary"
"time"
)
const (
@ -34,8 +33,8 @@ func init() {
bufferChan <- binary.LittleEndian.Uint64(buffer[i : i + 8])
i ++
}
// 充分利用缓冲区数据,字节倒序生成,随机索引递增
step = int(time.Now().UnixNano())%n
// 充分利用缓冲区数据,随机索引递增
step = int(buffer[0])%10
for i := 0; i < n - 8; {
bufferChan <- binary.BigEndian.Uint64(buffer[i : i + 8])
i += step

View File

@ -13,7 +13,7 @@ import (
"strings"
)
// 字符串替换
// 字符串替换(大小写敏感)
func Replace(origin, search, replace string, count...int) string {
n := -1
if len(count) > 0 {
@ -22,7 +22,7 @@ func Replace(origin, search, replace string, count...int) string {
return strings.Replace(origin, search, replace, n)
}
// 使用map进行字符串替换
// 使用map进行字符串替换(大小写敏感)
func ReplaceByMap(origin string, replaces map[string]string) string {
result := origin
for k, v := range replaces {

32
g/util/gtest/gtest.go Normal file
View File

@ -0,0 +1,32 @@
// Copyright 2018 gf Author(https://gitee.com/johng/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://gitee.com/johng/gf.
// Package gtest provides useful test utils.
// 测试模块.
package gtest
import (
"fmt"
"gitee.com/johng/gf/g/os/glog"
"gitee.com/johng/gf/g/util/gconv"
"os"
)
// 断言判断
func Assert(value, expect interface{}) {
if gconv.String(value) != gconv.String(expect) {
glog.Printfln(`[ASSERT] VALUE: %v, EXPECT: %v`, value, expect)
glog.Header(false).PrintBacktrace(1)
os.Exit(1)
}
}
// 提示错误并退出
func Fatal(message...interface{}) {
glog.Println(`[FATAL]`, fmt.Sprint(message...))
glog.Header(false).PrintBacktrace(1)
os.Exit(1)
}

View File

@ -4,13 +4,14 @@
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://gitee.com/johng/gf.
// 其他工具包
// 工具包
package gutil
import (
"fmt"
"bytes"
"encoding/json"
"os"
"reflect"
"gitee.com/johng/gf/g/util/gconv"
"runtime"
@ -43,7 +44,7 @@ func Dump(i...interface{}) {
if err := encoder.Encode(v); err == nil {
fmt.Print(buffer.String())
} else {
fmt.Errorf("%s", err.Error())
fmt.Fprintln(os.Stderr, err.Error())
}
}
//fmt.Println()

View File

@ -14,7 +14,12 @@ import (
// 检测键值对参数Map
// rules参数支持 []string / map[string]string 类型,前面一种类型支持返回校验结果顺序(具体格式参考struct tag),后一种不支持;
// rules参数中得 map[string]string 是一个2维的关联数组第一维键名为参数键名第二维为带有错误的校验规则名称值为错误信息。
func CheckMap(params map[string]interface{}, rules interface{}, msgs...CustomMsg) *Error {
func CheckMap(params interface{}, rules interface{}, msgs...CustomMsg) *Error {
// 将参数转换为 map[string]interface{}类型
data := gconv.Map(params)
if data == nil {
return newErrorStr("invalid_params", "invalid params type: convert to map[string]interface{} failed")
}
// 真实校验规则数据结构
checkRules := make(map[string]string)
// 真实自定义错误信息数据结构
@ -73,11 +78,15 @@ func CheckMap(params map[string]interface{}, rules interface{}, msgs...CustomMsg
value := (interface{})(nil)
// 这里的rule变量为多条校验规则不包含名字或者错误信息定义
for key, rule := range checkRules {
// 如果规则为空,那么不执行校验
if len(rule) == 0 {
continue
}
value = nil
if v, ok := params[key]; ok {
if v, ok := data[key]; ok {
value = v
}
if e := Check(value, rule, customMsgs[key], params); e != nil {
if e := Check(value, rule, customMsgs[key], data); e != nil {
_, item := e.FirstItem()
// 如果值为nil|""并且不需要require*验证时,其他验证失效
if value == nil || gconv.String(value) == "" {

View File

@ -22,7 +22,7 @@ type Error struct {
type ErrorMap map[string]map[string]string
// 创建一个校验错误对象指针
// 创建一个校验错误对象指针(校验错误)
func newError(rules []string, errors map[string]map[string]string) *Error {
return &Error {
rules : rules,
@ -30,6 +30,18 @@ func newError(rules []string, errors map[string]map[string]string) *Error {
}
}
// 创建一个校验错误对象指针(内部错误)
func newErrorStr(key, err string) *Error {
return &Error {
rules : nil,
errors : map[string]map[string]string{
"__gvalid__" : {
key: err,
},
},
}
}
// 获得规则与错误信息的map; 当校验结果为多条数据校验时返回第一条错误map(此时类似FirstItem)
func (e *Error) Map() map[string]string {
_, m := e.FirstItem()

View File

@ -0,0 +1,20 @@
package main
import (
"gitee.com/johng/gf/g"
"gitee.com/johng/gf/g/container/garray"
)
func main() {
array := garray.NewSortedStringArray(0, false)
array.Add("1")
array.Add("2")
array.Add("3")
array.Add("4")
array.Add("5")
array.Add("6")
array.Add("7")
array.Add("8")
array.Add("9")
g.Dump(array.Slice())
}

View File

@ -0,0 +1,26 @@
package main
import (
"fmt"
"gitee.com/johng/gf/g"
"gitee.com/johng/gf/g/container/garray"
"strings"
)
func main() {
array := garray.NewSortedStringArray(0, false)
array.Add("/api/ctl/show")
array.Add("/api/ctl/post")
array.Add("/api/obj/rest")
array.Add("/api/handler")
array.Add("/api/obj/delete")
array.Add("/api/obj/show")
array.Add("/api/obj/my-show")
array.Add("/api/*")
array.Add("/api/ctl/rest")
array.Add("/api/ctl/my-show")
g.Dump(array.Slice())
fmt.Println(strings.Compare("/api/ctl/post", "/api/*"))
fmt.Println(strings.Compare("/api/*", "/api/ctl/my-show"))
}

View File

@ -9,7 +9,7 @@ import (
// 本文件用于gf框架的mysql数据库操作示例不作为单元测试使用
var db *gdb.Db
var db gdb.DB
// 初始化配置及创建数据库
func init () {
@ -17,7 +17,7 @@ func init () {
Host : "127.0.0.1",
Port : "3306",
User : "root",
Pass : "8692651",
Pass : "12345678",
Name : "test",
Type : "mysql",
Role : "master",

View File

@ -11,7 +11,7 @@ func main() {
Host: "127.0.0.1",
Port: "3306",
User: "root",
Pass: "123456",
Pass: "12345678",
Name: "test",
Type: "mysql",
Role: "master",

View File

@ -1,6 +1,7 @@
package main
import (
"gitee.com/johng/gf/g"
"gitee.com/johng/gf/g/database/gdb"
"fmt"
"gitee.com/johng/gf/g/encoding/gparser"
@ -11,18 +12,14 @@ func main() {
Host : "127.0.0.1",
Port : "3306",
User : "root",
Pass : "123456",
Pass : "12345678",
Name : "test",
Type : "mysql",
Role : "master",
Charset : "utf8",
})
db, err := gdb.New()
if err != nil {
panic(err)
}
one, err := db.Table("user").Where("uid=?", 1).One()
db := g.DB()
one, err := db.Table("user").Where("id=?", 1).One()
if err != nil {
panic(err)
}
@ -33,7 +30,7 @@ func main() {
// 自定义方法方法转换为json/xml
jsonContent, _ := gparser.VarToJson(one.ToMap())
fmt.Println(jsonContent)
fmt.Println(string(jsonContent))
xmlContent, _ := gparser.VarToXml(one.ToMap())
fmt.Println(xmlContent)
fmt.Println(string(xmlContent))
}

View File

@ -1,28 +1,16 @@
package main
import (
"gitee.com/johng/gf/g/database/gdb"
"gitee.com/johng/gf/g"
"time"
)
func main() {
gdb.AddDefaultConfigNode(gdb.ConfigNode {
Host : "127.0.0.1",
Port : "3306",
User : "root",
Pass : "12345678",
Name : "test",
Type : "mysql",
Role : "master",
Charset : "utf8",
MaxIdleConnCount : 10,
MaxOpenConnCount : 10,
MaxConnLifetime : 10,
})
db, err := gdb.New()
if err != nil {
panic(err)
}
db := g.DB()
db.SetMaxIdleConns(10)
db.SetMaxOpenConns(10)
db.SetConnMaxLifetime(10)
// 开启调试模式以便于记录所有执行的SQL
db.SetDebug(true)

View File

@ -2,28 +2,15 @@ package main
import (
"fmt"
"gitee.com/johng/gf/g/database/gdb"
"gitee.com/johng/gf/g"
)
func main() {
gdb.AddDefaultConfigNode(gdb.ConfigNode {
Host : "192.168.1.11",
Port : "3306",
User : "root",
Pass : "8692651",
Name : "test",
Type : "mysql",
Role : "master",
Charset : "utf8",
})
db, err := gdb.New()
if err != nil {
panic(err)
}
db := g.DB()
// 开启调试模式以便于记录所有执行的SQL
db.SetDebug(true)
r, _ := db.Table("user").All()
r, _ := db.Table("test").Where("id IN (?)", []interface{}{1, 2}).All()
if r != nil {
fmt.Println(r.ToList())
}

View File

@ -0,0 +1,78 @@
package main
import (
"gitee.com/johng/gf/g"
"gitee.com/johng/gf/g/frame/gmvc"
"gitee.com/johng/gf/g/net/ghttp"
)
type Object struct {}
func (o *Object) Show(r *ghttp.Request) {
r.Response.Writeln("Object Show")
}
func (o *Object) Delete(r *ghttp.Request) {
r.Response.Writeln("Object REST Delete")
}
func (o *Object) Shut(r *ghttp.Request) {
r.Response.Writeln("Object Shut")
}
type Controller struct {
gmvc.Controller
}
func (c *Controller) Show() {
c.Response.Writeln("Controller Show")
}
func (c *Controller) Post() {
c.Response.Writeln("Controller REST Post")
}
func (c *Controller) Shut() {
c.Response.Writeln("Controller Shut")
}
func Handler(r *ghttp.Request) {
r.Response.Writeln("Handler")
}
func HookHandler(r *ghttp.Request) {
r.Response.Writeln("Hook Handler")
}
func main() {
s := g.Server()
obj := new(Object)
ctl := new(Controller)
// 分组路由方法注册
//g := s.Group("/api")
//g.ALL ("*", HookHandler, ghttp.HOOK_BEFORE_SERVE)
//g.ALL ("/handler", Handler)
//g.ALL ("/ctl", ctl)
//g.GET ("/ctl/my-show", ctl, "Show")
//g.REST("/ctl/rest", ctl)
//g.ALL ("/obj", obj)
//g.GET ("/obj/my-show", obj, "Show")
//g.REST("/obj/rest", obj)
// 分组路由批量注册
s.Group("/api").Bind("/api", []ghttp.GroupItem{
{"ALL", "/handler", Handler},
{"ALL", "/ctl", ctl},
{"GET", "/ctl/my-show", ctl, "Show"},
{"REST", "/ctl/rest", ctl},
{"ALL", "/obj", obj},
{"GET", "/obj/my-show", obj, "Show"},
{"REST", "/obj/rest", obj},
{"ALL", "*", HookHandler, ghttp.HOOK_BEFORE_SERVE},
})
s.SetPort(8199)
s.Run()
}

View File

@ -5,6 +5,7 @@ import (
"bytes"
"gitee.com/johng/gf/g/net/gtcp"
"gitee.com/johng/gf/g/util/gconv"
"os"
)
func main() {
@ -39,7 +40,7 @@ func main() {
break
}
if err != nil {
fmt.Errorf("ERROR: %s\n", err.Error())
fmt.Fprintf(os.Stderr, "ERROR: %s\n", err.Error())
break
}
}

View File

@ -3,6 +3,7 @@ package main
import (
"gitee.com/johng/gf/g/net/gtcp"
"fmt"
"os"
)
func main() {
@ -11,6 +12,6 @@ func main() {
fmt.Println(string(data))
}
if err != nil {
fmt.Errorf("ERROR: %s\n", err.Error())
fmt.Fprintf(os.Stderr, "ERROR: %s\n", err.Error())
}
}

19
geg/os/glog/glog_pool.go Normal file
View File

@ -0,0 +1,19 @@
package main
import (
"gitee.com/johng/gf/g/os/glog"
"gitee.com/johng/gf/g/os/gtime"
"time"
)
// 测试删除日志文件是否会重建日志文件
func main() {
path := "/Users/john/Temp/test"
glog.SetPath(path)
for {
glog.Println(gtime.Now().String())
time.Sleep(time.Second)
}
}

View File

@ -1,19 +1,33 @@
package main
import (
"fmt"
"gitee.com/johng/gf/g"
"gitee.com/johng/gf/g/util/gregex"
)
type Registry struct {
Method string
Uri string
Handler interface{}
Object interface{}
}
func BindGroup(group string, routers [][]interface{}) {
}
type User struct { }
type Order struct { }
type Product struct { }
func HookFunc() {
}
func main() {
//s := `1544180795 -- s_has_sess -- 41570504 -decryptSess- 41570504__iuVycRYg9qE3y7CsSgGZH1K2nxTdjPZN4fXot65zHIEmULO0Ow6LweJp5raWl8Ft -postSess- eyJpdiI6IkFwSWZ3eXFMcGxBZE5JcWF4aXh0M3c9PSIsInZhbHVlIjoiV3ZLeGduMnRoRkFZdmxHTzM5ZzdyU1JHWDMycmZlRERvNnFkaUR0SitlRjBrZnlYR1JvS2puTGZNUThSeFR0bWtlT3pza0l0elFqRk5mdXF6XC9FWWpWZnljVjdJbHd3dTRybEhldHZHTk5DQ015dlpYNHljNmxKMWJTRUVpY0E4IiwibWFjIjoiOTkxMzIxOTRhMGUxZWZiODM4NWZjNDZjYmVhNWY2NjhlZDZkNmVlNjY1MTE2N2VhZDAzYzY4NDJmZGFkMjY5YyJ9 -- 0 -- B8105CF2-1588-4753-9F86-9B8C36EB1842 -- iPhone 7 -- 12.1 -- 6.8.7 -- i -- 10.111.153.5 -- medlinker -- service -- unknown
//`
// s := `[08-Dec-2018 13:35:03 Asia/Shanghai] Medlinker\Services\Message\MessageService|updateUserInfo|用户头像 URI 不能为空 in /var/www/med-d2d/app/Services/Message/RongCloudService.php on line 851`
s := `[2018-12-01 13:35:03 Asia/Shanghai] 1544180795 Medlinker\Services\Message\MessageService|updateUserInfo|用户头像 URI 不能为空 in /var/www/med-d2d/app/Services/Message/RongCloudService.php on line 851`
//m, e := gregex.MatchString(`/var/log/medlinker/[\w\-\_]+/(.+?)/{0,1}[\d\-\_]*\.log`, `/var/log/medlinker/med-questionnaire/nginx/error/access-20181206.log`)
//m, e := gregex.MatchString(`/var/log/medlinker/[\w\-\_]+/(.+?)/{0,1}[\d\-\_]*\.log`, `/var/log/medlinker/med-questionnaire/storagelogs/events/sqlLog/2018-12-06.log`)
m, e := gregex.MatchString(`(.*?((\d{4}[-/\.]\d{2}[-/\.]\d{2}|\d{1,2}[-/\.][A-Za-z]{3,}[-/\.]\d{4})[:\sT-]*\d{0,2}:{0,1}\d{0,2}:{0,1}\d{0,2}\.{0,1}\d{0,9}[\sZ]{0,1}[\+-]{0,1}[:\d]*|\d{10}).+)`, s)
fmt.Println(e)
g.Dump(m)
user := new(User)
BindGroup("/api", [][]interface{} {
{"ALL", "/*", HookFunc, "BeforeServe"},
{"ALL", "/order", new(Order)},
{"REST", "/product", new(Product)},
{"GET", "/user/register", "Register", user},
{"GET", "/user/reset-pass", "ResetPassword", user},
{"POST", "/user/reset-pass", "ResetPassword", user},
{"POST", "/user/login", "Login", user},
})
}

View File

@ -1,41 +0,0 @@
package test
import (
"strings"
"testing"
)
var s = "/name/john///."
var c = "./"
func t1(s string) string {
if len(s) == 0 {
return s
}
for _, cut := range c {
for s[len(s) - 1] == uint8(cut) {
s = s[:len(s) - 1]
if len(s) == 0 {
return s
}
}
}
return s
}
func t2(s string) string {
return strings.TrimRight(s, c)
}
func Benchmark_t1(b *testing.B) {
for i := 0; i < b.N; i++ {
t1(s)
}
}
func Benchmark_t2(b *testing.B) {
for i := 0; i < b.N; i++ {
t2(s)
}
}

View File

@ -1,35 +1,21 @@
package main
import (
"fmt"
"gitee.com/johng/gf/g"
"gitee.com/johng/gf/g/util/gvalid"
"gitee.com/johng/gf/g/database/gdb"
"gitee.com/johng/gf/g/os/glog"
)
type User struct {
Uid int `gvalid:"uid @integer|min:1"`
Name string `gvalid:"name @required|length:6,30#请输入用户名称|用户名称长度非法"`
Pass1 string `gvalid:"password1@required|password3"`
Pass2 string `gvalid:"password2@required|password3|same:password1#||两次密码不一致,请重新输入"`
}
func main() {
user := &User{
Name : "john",
Pass1: "Abc123!@#",
Pass2: "123",
}
gdb.AddDefaultConfigNode(gdb.ConfigNode{
Type : "mysql",
Linkinfo : "root:12345678@tcp(127.0.0.1:3306)/test",
})
// 使用结构体定义的校验规则和错误提示进行校验
g.Dump(gvalid.CheckStruct(user, nil).Maps())
// 自定义校验规则和错误提示,对定义的特定校验规则和错误提示进行覆盖
rules := map[string]string {
"Uid" : "required",
if r, err := g.Database().GetOne("select now() as time"); err != nil {
glog.Error("Mysql Init Select Now : %v", err)
} else {
fmt.Println(r.ToMap())
}
msgs := map[string]interface{} {
"Pass2" : map[string]string {
"password3" : "名称不能为空",
},
}
g.Dump(gvalid.CheckStruct(user, rules, msgs).Maps())
}
}

View File

@ -6,7 +6,7 @@ import (
)
func main() {
for i := 0; i < 10; i++ {
for i := 0; i < 100; i++ {
fmt.Println(grand.Rand(0, 99999))
}
}

1
go.mod
View File

@ -1,2 +1 @@
module gitee.com/johng/gf

View File

@ -0,0 +1,30 @@
language: go
go:
- 1.9.x
- 1.10.x
- 1.11.x
os:
- linux
- osx
matrix:
include:
name: "Go 1.11.x CentOS 32bits"
language: go
go: 1.11.x
os: linux
services:
- docker
script:
# Please update Go version in travis_test_32 as needed
- "docker run -i -v \"${PWD}:/zstd\" toopher/centos-i386:centos6 /bin/bash -c \"linux32 --32bit i386 /zstd/travis_test_32.sh\""
install:
- "wget https://github.com/DataDog/zstd/files/2246767/mr.zip"
- "unzip mr.zip"
script:
- "go build"
- "PAYLOAD=`pwd`/mr go test -v"
- "PAYLOAD=`pwd`/mr go test -bench ."

View File

@ -0,0 +1,27 @@
Simplified BSD License
Copyright (c) 2016, Datadog <info@datadoghq.com>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its contributors
may be used to endorse or promote products derived from this software
without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,120 @@
# Zstd Go Wrapper
[C Zstd Homepage](https://github.com/Cyan4973/zstd)
The current headers and C files are from *v1.3.4* (Commit
[2555975](https://github.com/facebook/zstd/releases/tag/v1.3.4)).
## Usage
There are two main APIs:
* simple Compress/Decompress
* streaming API (io.Reader/io.Writer)
The compress/decompress APIs mirror that of lz4, while the streaming API was
designed to be a drop-in replacement for zlib.
### Simple `Compress/Decompress`
```go
// Compress compresses the byte array given in src and writes it to dst.
// If you already have a buffer allocated, you can pass it to prevent allocation
// If not, you can pass nil as dst.
// If the buffer is too small, it will be reallocated, resized, and returned bu the function
// If dst is nil, this will allocate the worst case size (CompressBound(src))
Compress(dst, src []byte) ([]byte, error)
```
```go
// CompressLevel is the same as Compress but you can pass another compression level
CompressLevel(dst, src []byte, level int) ([]byte, error)
```
```go
// Decompress will decompress your payload into dst.
// If you already have a buffer allocated, you can pass it to prevent allocation
// If not, you can pass nil as dst (allocates a 4*src size as default).
// If the buffer is too small, it will retry 3 times by doubling the dst size
// After max retries, it will switch to the slower stream API to be sure to be able
// to decompress. Currently switches if compression ratio > 4*2**3=32.
Decompress(dst, src []byte) ([]byte, error)
```
### Stream API
```go
// NewWriter creates a new object that can optionally be initialized with
// a precomputed dictionary. If dict is nil, compress without a dictionary.
// The dictionary array should not be changed during the use of this object.
// You MUST CALL Close() to write the last bytes of a zstd stream and free C objects.
NewWriter(w io.Writer) *Writer
NewWriterLevel(w io.Writer, level int) *Writer
NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer
// Write compresses the input data and write it to the underlying writer
(w *Writer) Write(p []byte) (int, error)
// Close flushes the buffer and frees C zstd objects
(w *Writer) Close() error
```
```go
// NewReader returns a new io.ReadCloser that will decompress data from the
// underlying reader. If a dictionary is provided to NewReaderDict, it must
// not be modified until Close is called. It is the caller's responsibility
// to call Close, which frees up C objects.
NewReader(r io.Reader) io.ReadCloser
NewReaderDict(r io.Reader, dict []byte) io.ReadCloser
```
### Benchmarks (benchmarked with v0.5.0)
The author of Zstd also wrote lz4. Zstd is intended to occupy a speed/ratio
level similar to what zlib currently provides. In our tests, the can always
be made to be better than zlib by chosing an appropriate level while still
keeping compression and decompression time faster than zlib.
You can run the benchmarks against your own payloads by using the Go benchmarks tool.
Just export your payload filepath as the `PAYLOAD` environment variable and run the benchmarks:
```go
go test -bench .
```
Compression of a 7Mb pdf zstd (this wrapper) vs [czlib](https://github.com/DataDog/czlib):
```
BenchmarkCompression 5 221056624 ns/op 67.34 MB/s
BenchmarkDecompression 100 18370416 ns/op 810.32 MB/s
BenchmarkFzlibCompress 2 610156603 ns/op 24.40 MB/s
BenchmarkFzlibDecompress 20 81195246 ns/op 183.33 MB/s
```
Ratio is also better by a margin of ~20%.
Compression speed is always better than zlib on all the payloads we tested;
However, [czlib](https://github.com/DataDog/czlib) has optimisations that make it
faster at decompressiong small payloads:
```
Testing with size: 11... czlib: 8.97 MB/s, zstd: 3.26 MB/s
Testing with size: 27... czlib: 23.3 MB/s, zstd: 8.22 MB/s
Testing with size: 62... czlib: 31.6 MB/s, zstd: 19.49 MB/s
Testing with size: 141... czlib: 74.54 MB/s, zstd: 42.55 MB/s
Testing with size: 323... czlib: 155.14 MB/s, zstd: 99.39 MB/s
Testing with size: 739... czlib: 235.9 MB/s, zstd: 216.45 MB/s
Testing with size: 1689... czlib: 116.45 MB/s, zstd: 345.64 MB/s
Testing with size: 3858... czlib: 176.39 MB/s, zstd: 617.56 MB/s
Testing with size: 8811... czlib: 254.11 MB/s, zstd: 824.34 MB/s
Testing with size: 20121... czlib: 197.43 MB/s, zstd: 1339.11 MB/s
Testing with size: 45951... czlib: 201.62 MB/s, zstd: 1951.57 MB/s
```
zstd starts to shine with payloads > 1KB
### Stability - Current state: STABLE
The C library seems to be pretty stable and according to the author has been tested and fuzzed.
For the Go wrapper, the test cover most usual cases and we have succesfully tested it on all staging and prod data.

View File

@ -0,0 +1,30 @@
BSD License
For Zstandard software
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name Facebook nor the names of its contributors may be used to
endorse or promote products derived from this software without specific
prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,471 @@
/* ******************************************************************
bitstream
Part of FSE library
header file (to include)
Copyright (C) 2013-2017, Yann Collet.
BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php)
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
You can contact the author at :
- Source repository : https://github.com/Cyan4973/FiniteStateEntropy
****************************************************************** */
#ifndef BITSTREAM_H_MODULE
#define BITSTREAM_H_MODULE
#if defined (__cplusplus)
extern "C" {
#endif
/*
* This API consists of small unitary functions, which must be inlined for best performance.
* Since link-time-optimization is not available for all compilers,
* these functions are defined into a .h to be included.
*/
/*-****************************************
* Dependencies
******************************************/
#include "mem.h" /* unaligned access routines */
#include "error_private.h" /* error codes and messages */
/*-*************************************
* Debug
***************************************/
#if defined(BIT_DEBUG) && (BIT_DEBUG>=1)
# include <assert.h>
#else
# ifndef assert
# define assert(condition) ((void)0)
# endif
#endif
/*=========================================
* Target specific
=========================================*/
#if defined(__BMI__) && defined(__GNUC__)
# include <immintrin.h> /* support for bextr (experimental) */
#endif
#define STREAM_ACCUMULATOR_MIN_32 25
#define STREAM_ACCUMULATOR_MIN_64 57
#define STREAM_ACCUMULATOR_MIN ((U32)(MEM_32bits() ? STREAM_ACCUMULATOR_MIN_32 : STREAM_ACCUMULATOR_MIN_64))
/*-******************************************
* bitStream encoding API (write forward)
********************************************/
/* bitStream can mix input from multiple sources.
* A critical property of these streams is that they encode and decode in **reverse** direction.
* So the first bit sequence you add will be the last to be read, like a LIFO stack.
*/
typedef struct
{
size_t bitContainer;
unsigned bitPos;
char* startPtr;
char* ptr;
char* endPtr;
} BIT_CStream_t;
MEM_STATIC size_t BIT_initCStream(BIT_CStream_t* bitC, void* dstBuffer, size_t dstCapacity);
MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC, size_t value, unsigned nbBits);
MEM_STATIC void BIT_flushBits(BIT_CStream_t* bitC);
MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC);
/* Start with initCStream, providing the size of buffer to write into.
* bitStream will never write outside of this buffer.
* `dstCapacity` must be >= sizeof(bitD->bitContainer), otherwise @return will be an error code.
*
* bits are first added to a local register.
* Local register is size_t, hence 64-bits on 64-bits systems, or 32-bits on 32-bits systems.
* Writing data into memory is an explicit operation, performed by the flushBits function.
* Hence keep track how many bits are potentially stored into local register to avoid register overflow.
* After a flushBits, a maximum of 7 bits might still be stored into local register.
*
* Avoid storing elements of more than 24 bits if you want compatibility with 32-bits bitstream readers.
*
* Last operation is to close the bitStream.
* The function returns the final size of CStream in bytes.
* If data couldn't fit into `dstBuffer`, it will return a 0 ( == not storable)
*/
/*-********************************************
* bitStream decoding API (read backward)
**********************************************/
typedef struct
{
size_t bitContainer;
unsigned bitsConsumed;
const char* ptr;
const char* start;
const char* limitPtr;
} BIT_DStream_t;
typedef enum { BIT_DStream_unfinished = 0,
BIT_DStream_endOfBuffer = 1,
BIT_DStream_completed = 2,
BIT_DStream_overflow = 3 } BIT_DStream_status; /* result of BIT_reloadDStream() */
/* 1,2,4,8 would be better for bitmap combinations, but slows down performance a bit ... :( */
MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, size_t srcSize);
MEM_STATIC size_t BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits);
MEM_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD);
MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* bitD);
/* Start by invoking BIT_initDStream().
* A chunk of the bitStream is then stored into a local register.
* Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (size_t).
* You can then retrieve bitFields stored into the local register, **in reverse order**.
* Local register is explicitly reloaded from memory by the BIT_reloadDStream() method.
* A reload guarantee a minimum of ((8*sizeof(bitD->bitContainer))-7) bits when its result is BIT_DStream_unfinished.
* Otherwise, it can be less than that, so proceed accordingly.
* Checking if DStream has reached its end can be performed with BIT_endOfDStream().
*/
/*-****************************************
* unsafe API
******************************************/
MEM_STATIC void BIT_addBitsFast(BIT_CStream_t* bitC, size_t value, unsigned nbBits);
/* faster, but works only if value is "clean", meaning all high bits above nbBits are 0 */
MEM_STATIC void BIT_flushBitsFast(BIT_CStream_t* bitC);
/* unsafe version; does not check buffer overflow */
MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits);
/* faster, but works only if nbBits >= 1 */
/*-**************************************************************
* Internal functions
****************************************************************/
MEM_STATIC unsigned BIT_highbit32 (U32 val)
{
assert(val != 0);
{
# if defined(_MSC_VER) /* Visual */
unsigned long r=0;
_BitScanReverse ( &r, val );
return (unsigned) r;
# elif defined(__GNUC__) && (__GNUC__ >= 3) /* Use GCC Intrinsic */
return 31 - __builtin_clz (val);
# else /* Software version */
static const unsigned DeBruijnClz[32] = { 0, 9, 1, 10, 13, 21, 2, 29,
11, 14, 16, 18, 22, 25, 3, 30,
8, 12, 20, 28, 15, 17, 24, 7,
19, 27, 23, 6, 26, 5, 4, 31 };
U32 v = val;
v |= v >> 1;
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
v |= v >> 16;
return DeBruijnClz[ (U32) (v * 0x07C4ACDDU) >> 27];
# endif
}
}
/*===== Local Constants =====*/
static const unsigned BIT_mask[] = {
0, 1, 3, 7, 0xF, 0x1F,
0x3F, 0x7F, 0xFF, 0x1FF, 0x3FF, 0x7FF,
0xFFF, 0x1FFF, 0x3FFF, 0x7FFF, 0xFFFF, 0x1FFFF,
0x3FFFF, 0x7FFFF, 0xFFFFF, 0x1FFFFF, 0x3FFFFF, 0x7FFFFF,
0xFFFFFF, 0x1FFFFFF, 0x3FFFFFF, 0x7FFFFFF, 0xFFFFFFF, 0x1FFFFFFF,
0x3FFFFFFF, 0x7FFFFFFF}; /* up to 31 bits */
#define BIT_MASK_SIZE (sizeof(BIT_mask) / sizeof(BIT_mask[0]))
/*-**************************************************************
* bitStream encoding
****************************************************************/
/*! BIT_initCStream() :
* `dstCapacity` must be > sizeof(size_t)
* @return : 0 if success,
* otherwise an error code (can be tested using ERR_isError()) */
MEM_STATIC size_t BIT_initCStream(BIT_CStream_t* bitC,
void* startPtr, size_t dstCapacity)
{
bitC->bitContainer = 0;
bitC->bitPos = 0;
bitC->startPtr = (char*)startPtr;
bitC->ptr = bitC->startPtr;
bitC->endPtr = bitC->startPtr + dstCapacity - sizeof(bitC->bitContainer);
if (dstCapacity <= sizeof(bitC->bitContainer)) return ERROR(dstSize_tooSmall);
return 0;
}
/*! BIT_addBits() :
* can add up to 31 bits into `bitC`.
* Note : does not check for register overflow ! */
MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC,
size_t value, unsigned nbBits)
{
MEM_STATIC_ASSERT(BIT_MASK_SIZE == 32);
assert(nbBits < BIT_MASK_SIZE);
assert(nbBits + bitC->bitPos < sizeof(bitC->bitContainer) * 8);
bitC->bitContainer |= (value & BIT_mask[nbBits]) << bitC->bitPos;
bitC->bitPos += nbBits;
}
/*! BIT_addBitsFast() :
* works only if `value` is _clean_, meaning all high bits above nbBits are 0 */
MEM_STATIC void BIT_addBitsFast(BIT_CStream_t* bitC,
size_t value, unsigned nbBits)
{
assert((value>>nbBits) == 0);
assert(nbBits + bitC->bitPos < sizeof(bitC->bitContainer) * 8);
bitC->bitContainer |= value << bitC->bitPos;
bitC->bitPos += nbBits;
}
/*! BIT_flushBitsFast() :
* assumption : bitContainer has not overflowed
* unsafe version; does not check buffer overflow */
MEM_STATIC void BIT_flushBitsFast(BIT_CStream_t* bitC)
{
size_t const nbBytes = bitC->bitPos >> 3;
assert(bitC->bitPos < sizeof(bitC->bitContainer) * 8);
MEM_writeLEST(bitC->ptr, bitC->bitContainer);
bitC->ptr += nbBytes;
assert(bitC->ptr <= bitC->endPtr);
bitC->bitPos &= 7;
bitC->bitContainer >>= nbBytes*8;
}
/*! BIT_flushBits() :
* assumption : bitContainer has not overflowed
* safe version; check for buffer overflow, and prevents it.
* note : does not signal buffer overflow.
* overflow will be revealed later on using BIT_closeCStream() */
MEM_STATIC void BIT_flushBits(BIT_CStream_t* bitC)
{
size_t const nbBytes = bitC->bitPos >> 3;
assert(bitC->bitPos < sizeof(bitC->bitContainer) * 8);
MEM_writeLEST(bitC->ptr, bitC->bitContainer);
bitC->ptr += nbBytes;
if (bitC->ptr > bitC->endPtr) bitC->ptr = bitC->endPtr;
bitC->bitPos &= 7;
bitC->bitContainer >>= nbBytes*8;
}
/*! BIT_closeCStream() :
* @return : size of CStream, in bytes,
* or 0 if it could not fit into dstBuffer */
MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC)
{
BIT_addBitsFast(bitC, 1, 1); /* endMark */
BIT_flushBits(bitC);
if (bitC->ptr >= bitC->endPtr) return 0; /* overflow detected */
return (bitC->ptr - bitC->startPtr) + (bitC->bitPos > 0);
}
/*-********************************************************
* bitStream decoding
**********************************************************/
/*! BIT_initDStream() :
* Initialize a BIT_DStream_t.
* `bitD` : a pointer to an already allocated BIT_DStream_t structure.
* `srcSize` must be the *exact* size of the bitStream, in bytes.
* @return : size of stream (== srcSize), or an errorCode if a problem is detected
*/
MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, size_t srcSize)
{
if (srcSize < 1) { memset(bitD, 0, sizeof(*bitD)); return ERROR(srcSize_wrong); }
bitD->start = (const char*)srcBuffer;
bitD->limitPtr = bitD->start + sizeof(bitD->bitContainer);
if (srcSize >= sizeof(bitD->bitContainer)) { /* normal case */
bitD->ptr = (const char*)srcBuffer + srcSize - sizeof(bitD->bitContainer);
bitD->bitContainer = MEM_readLEST(bitD->ptr);
{ BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1];
bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; /* ensures bitsConsumed is always set */
if (lastByte == 0) return ERROR(GENERIC); /* endMark not present */ }
} else {
bitD->ptr = bitD->start;
bitD->bitContainer = *(const BYTE*)(bitD->start);
switch(srcSize)
{
case 7: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[6]) << (sizeof(bitD->bitContainer)*8 - 16);
/* fall-through */
case 6: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[5]) << (sizeof(bitD->bitContainer)*8 - 24);
/* fall-through */
case 5: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[4]) << (sizeof(bitD->bitContainer)*8 - 32);
/* fall-through */
case 4: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[3]) << 24;
/* fall-through */
case 3: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[2]) << 16;
/* fall-through */
case 2: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[1]) << 8;
/* fall-through */
default: break;
}
{ BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1];
bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0;
if (lastByte == 0) return ERROR(corruption_detected); /* endMark not present */
}
bitD->bitsConsumed += (U32)(sizeof(bitD->bitContainer) - srcSize)*8;
}
return srcSize;
}
MEM_STATIC size_t BIT_getUpperBits(size_t bitContainer, U32 const start)
{
return bitContainer >> start;
}
MEM_STATIC size_t BIT_getMiddleBits(size_t bitContainer, U32 const start, U32 const nbBits)
{
#if defined(__BMI__) && defined(__GNUC__) && __GNUC__*1000+__GNUC_MINOR__ >= 4008 /* experimental */
# if defined(__x86_64__)
if (sizeof(bitContainer)==8)
return _bextr_u64(bitContainer, start, nbBits);
else
# endif
return _bextr_u32(bitContainer, start, nbBits);
#else
assert(nbBits < BIT_MASK_SIZE);
return (bitContainer >> start) & BIT_mask[nbBits];
#endif
}
MEM_STATIC size_t BIT_getLowerBits(size_t bitContainer, U32 const nbBits)
{
assert(nbBits < BIT_MASK_SIZE);
return bitContainer & BIT_mask[nbBits];
}
/*! BIT_lookBits() :
* Provides next n bits from local register.
* local register is not modified.
* On 32-bits, maxNbBits==24.
* On 64-bits, maxNbBits==56.
* @return : value extracted */
MEM_STATIC size_t BIT_lookBits(const BIT_DStream_t* bitD, U32 nbBits)
{
#if defined(__BMI__) && defined(__GNUC__) /* experimental; fails if bitD->bitsConsumed + nbBits > sizeof(bitD->bitContainer)*8 */
return BIT_getMiddleBits(bitD->bitContainer, (sizeof(bitD->bitContainer)*8) - bitD->bitsConsumed - nbBits, nbBits);
#else
U32 const regMask = sizeof(bitD->bitContainer)*8 - 1;
return ((bitD->bitContainer << (bitD->bitsConsumed & regMask)) >> 1) >> ((regMask-nbBits) & regMask);
#endif
}
/*! BIT_lookBitsFast() :
* unsafe version; only works if nbBits >= 1 */
MEM_STATIC size_t BIT_lookBitsFast(const BIT_DStream_t* bitD, U32 nbBits)
{
U32 const regMask = sizeof(bitD->bitContainer)*8 - 1;
assert(nbBits >= 1);
return (bitD->bitContainer << (bitD->bitsConsumed & regMask)) >> (((regMask+1)-nbBits) & regMask);
}
MEM_STATIC void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits)
{
bitD->bitsConsumed += nbBits;
}
/*! BIT_readBits() :
* Read (consume) next n bits from local register and update.
* Pay attention to not read more than nbBits contained into local register.
* @return : extracted value. */
MEM_STATIC size_t BIT_readBits(BIT_DStream_t* bitD, U32 nbBits)
{
size_t const value = BIT_lookBits(bitD, nbBits);
BIT_skipBits(bitD, nbBits);
return value;
}
/*! BIT_readBitsFast() :
* unsafe version; only works only if nbBits >= 1 */
MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, U32 nbBits)
{
size_t const value = BIT_lookBitsFast(bitD, nbBits);
assert(nbBits >= 1);
BIT_skipBits(bitD, nbBits);
return value;
}
/*! BIT_reloadDStream() :
* Refill `bitD` from buffer previously set in BIT_initDStream() .
* This function is safe, it guarantees it will not read beyond src buffer.
* @return : status of `BIT_DStream_t` internal register.
* when status == BIT_DStream_unfinished, internal register is filled with at least 25 or 57 bits */
MEM_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD)
{
if (bitD->bitsConsumed > (sizeof(bitD->bitContainer)*8)) /* overflow detected, like end of stream */
return BIT_DStream_overflow;
if (bitD->ptr >= bitD->limitPtr) {
bitD->ptr -= bitD->bitsConsumed >> 3;
bitD->bitsConsumed &= 7;
bitD->bitContainer = MEM_readLEST(bitD->ptr);
return BIT_DStream_unfinished;
}
if (bitD->ptr == bitD->start) {
if (bitD->bitsConsumed < sizeof(bitD->bitContainer)*8) return BIT_DStream_endOfBuffer;
return BIT_DStream_completed;
}
/* start < ptr < limitPtr */
{ U32 nbBytes = bitD->bitsConsumed >> 3;
BIT_DStream_status result = BIT_DStream_unfinished;
if (bitD->ptr - nbBytes < bitD->start) {
nbBytes = (U32)(bitD->ptr - bitD->start); /* ptr > start */
result = BIT_DStream_endOfBuffer;
}
bitD->ptr -= nbBytes;
bitD->bitsConsumed -= nbBytes*8;
bitD->bitContainer = MEM_readLEST(bitD->ptr); /* reminder : srcSize > sizeof(bitD->bitContainer), otherwise bitD->ptr == bitD->start */
return result;
}
}
/*! BIT_endOfDStream() :
* @return : 1 if DStream has _exactly_ reached its end (all bits consumed).
*/
MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* DStream)
{
return ((DStream->ptr == DStream->start) && (DStream->bitsConsumed == sizeof(DStream->bitContainer)*8));
}
#if defined (__cplusplus)
}
#endif
#endif /* BITSTREAM_H_MODULE */

View File

@ -0,0 +1,111 @@
/*
* Copyright (c) 2016-present, Yann Collet, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under both the BSD-style license (found in the
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
* in the COPYING file in the root directory of this source tree).
* You may select, at your option, one of the above-listed licenses.
*/
#ifndef ZSTD_COMPILER_H
#define ZSTD_COMPILER_H
/*-*******************************************************
* Compiler specifics
*********************************************************/
/* force inlining */
#if defined (__GNUC__) || defined(__cplusplus) || defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* C99 */
# define INLINE_KEYWORD inline
#else
# define INLINE_KEYWORD
#endif
#if defined(__GNUC__)
# define FORCE_INLINE_ATTR __attribute__((always_inline))
#elif defined(_MSC_VER)
# define FORCE_INLINE_ATTR __forceinline
#else
# define FORCE_INLINE_ATTR
#endif
/**
* FORCE_INLINE_TEMPLATE is used to define C "templates", which take constant
* parameters. They must be inlined for the compiler to elimininate the constant
* branches.
*/
#define FORCE_INLINE_TEMPLATE static INLINE_KEYWORD FORCE_INLINE_ATTR
/**
* HINT_INLINE is used to help the compiler generate better code. It is *not*
* used for "templates", so it can be tweaked based on the compilers
* performance.
*
* gcc-4.8 and gcc-4.9 have been shown to benefit from leaving off the
* always_inline attribute.
*
* clang up to 5.0.0 (trunk) benefit tremendously from the always_inline
* attribute.
*/
#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 4 && __GNUC_MINOR__ >= 8 && __GNUC__ < 5
# define HINT_INLINE static INLINE_KEYWORD
#else
# define HINT_INLINE static INLINE_KEYWORD FORCE_INLINE_ATTR
#endif
/* force no inlining */
#ifdef _MSC_VER
# define FORCE_NOINLINE static __declspec(noinline)
#else
# ifdef __GNUC__
# define FORCE_NOINLINE static __attribute__((__noinline__))
# else
# define FORCE_NOINLINE static
# endif
#endif
/* target attribute */
#ifndef __has_attribute
#define __has_attribute(x) 0 /* Compatibility with non-clang compilers. */
#endif
#if defined(__GNUC__)
# define TARGET_ATTRIBUTE(target) __attribute__((__target__(target)))
#else
# define TARGET_ATTRIBUTE(target)
#endif
/* Enable runtime BMI2 dispatch based on the CPU.
* Enabled for clang & gcc >=4.8 on x86 when BMI2 isn't enabled by default.
*/
#ifndef DYNAMIC_BMI2
#if (defined(__clang__) && __has_attribute(__target__)) \
|| (defined(__GNUC__) \
&& (__GNUC__ >= 5 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) \
&& (defined(__x86_64__) || defined(_M_X86)) \
&& !defined(__BMI2__)
# define DYNAMIC_BMI2 1
#else
# define DYNAMIC_BMI2 0
#endif
#endif
/* prefetch */
#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_I86)) /* _mm_prefetch() is not defined outside of x86/x64 */
# include <mmintrin.h> /* https://msdn.microsoft.com/fr-fr/library/84szxsww(v=vs.90).aspx */
# define PREFETCH(ptr) _mm_prefetch((const char*)ptr, _MM_HINT_T0)
#elif defined(__GNUC__)
# define PREFETCH(ptr) __builtin_prefetch(ptr, 0, 0)
#else
# define PREFETCH(ptr) /* disabled */
#endif
/* disable warnings */
#ifdef _MSC_VER /* Visual Studio */
# include <intrin.h> /* For Visual 2005 */
# pragma warning(disable : 4100) /* disable: C4100: unreferenced formal parameter */
# pragma warning(disable : 4127) /* disable: C4127: conditional expression is constant */
# pragma warning(disable : 4204) /* disable: C4204: non-constant aggregate initializer */
# pragma warning(disable : 4214) /* disable: C4214: non-int bitfields */
# pragma warning(disable : 4324) /* disable: C4324: padded structure */
#endif
#endif /* ZSTD_COMPILER_H */

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,216 @@
/*
* Copyright (c) 2018-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under both the BSD-style license (found in the
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
* in the COPYING file in the root directory of this source tree).
* You may select, at your option, one of the above-listed licenses.
*/
#ifndef ZSTD_COMMON_CPU_H
#define ZSTD_COMMON_CPU_H
/**
* Implementation taken from folly/CpuId.h
* https://github.com/facebook/folly/blob/master/folly/CpuId.h
*/
#include <string.h>
#include "mem.h"
#ifdef _MSC_VER
#include <intrin.h>
#endif
typedef struct {
U32 f1c;
U32 f1d;
U32 f7b;
U32 f7c;
} ZSTD_cpuid_t;
MEM_STATIC ZSTD_cpuid_t ZSTD_cpuid(void) {
U32 f1c = 0;
U32 f1d = 0;
U32 f7b = 0;
U32 f7c = 0;
#ifdef _MSC_VER
int reg[4];
__cpuid((int*)reg, 0);
{
int const n = reg[0];
if (n >= 1) {
__cpuid((int*)reg, 1);
f1c = (U32)reg[2];
f1d = (U32)reg[3];
}
if (n >= 7) {
__cpuidex((int*)reg, 7, 0);
f7b = (U32)reg[1];
f7c = (U32)reg[2];
}
}
#elif defined(__i386__) && defined(__PIC__) && !defined(__clang__) && defined(__GNUC__)
/* The following block like the normal cpuid branch below, but gcc
* reserves ebx for use of its pic register so we must specially
* handle the save and restore to avoid clobbering the register
*/
U32 n;
__asm__(
"pushl %%ebx\n\t"
"cpuid\n\t"
"popl %%ebx\n\t"
: "=a"(n)
: "a"(0)
: "ecx", "edx");
if (n >= 1) {
U32 f1a;
__asm__(
"pushl %%ebx\n\t"
"cpuid\n\t"
"popl %%ebx\n\t"
: "=a"(f1a), "=c"(f1c), "=d"(f1d)
: "a"(1)
:);
}
if (n >= 7) {
__asm__(
"pushl %%ebx\n\t"
"cpuid\n\t"
"movl %%ebx, %%eax\n\r"
"popl %%ebx"
: "=a"(f7b), "=c"(f7c)
: "a"(7), "c"(0)
: "edx");
}
#elif defined(__x86_64__) || defined(_M_X64) || defined(__i386__)
U32 n;
__asm__("cpuid" : "=a"(n) : "a"(0) : "ebx", "ecx", "edx");
if (n >= 1) {
U32 f1a;
__asm__("cpuid" : "=a"(f1a), "=c"(f1c), "=d"(f1d) : "a"(1) : "ebx");
}
if (n >= 7) {
U32 f7a;
__asm__("cpuid"
: "=a"(f7a), "=b"(f7b), "=c"(f7c)
: "a"(7), "c"(0)
: "edx");
}
#endif
{
ZSTD_cpuid_t cpuid;
cpuid.f1c = f1c;
cpuid.f1d = f1d;
cpuid.f7b = f7b;
cpuid.f7c = f7c;
return cpuid;
}
}
#define X(name, r, bit) \
MEM_STATIC int ZSTD_cpuid_##name(ZSTD_cpuid_t const cpuid) { \
return ((cpuid.r) & (1U << bit)) != 0; \
}
/* cpuid(1): Processor Info and Feature Bits. */
#define C(name, bit) X(name, f1c, bit)
C(sse3, 0)
C(pclmuldq, 1)
C(dtes64, 2)
C(monitor, 3)
C(dscpl, 4)
C(vmx, 5)
C(smx, 6)
C(eist, 7)
C(tm2, 8)
C(ssse3, 9)
C(cnxtid, 10)
C(fma, 12)
C(cx16, 13)
C(xtpr, 14)
C(pdcm, 15)
C(pcid, 17)
C(dca, 18)
C(sse41, 19)
C(sse42, 20)
C(x2apic, 21)
C(movbe, 22)
C(popcnt, 23)
C(tscdeadline, 24)
C(aes, 25)
C(xsave, 26)
C(osxsave, 27)
C(avx, 28)
C(f16c, 29)
C(rdrand, 30)
#undef C
#define D(name, bit) X(name, f1d, bit)
D(fpu, 0)
D(vme, 1)
D(de, 2)
D(pse, 3)
D(tsc, 4)
D(msr, 5)
D(pae, 6)
D(mce, 7)
D(cx8, 8)
D(apic, 9)
D(sep, 11)
D(mtrr, 12)
D(pge, 13)
D(mca, 14)
D(cmov, 15)
D(pat, 16)
D(pse36, 17)
D(psn, 18)
D(clfsh, 19)
D(ds, 21)
D(acpi, 22)
D(mmx, 23)
D(fxsr, 24)
D(sse, 25)
D(sse2, 26)
D(ss, 27)
D(htt, 28)
D(tm, 29)
D(pbe, 31)
#undef D
/* cpuid(7): Extended Features. */
#define B(name, bit) X(name, f7b, bit)
B(bmi1, 3)
B(hle, 4)
B(avx2, 5)
B(smep, 7)
B(bmi2, 8)
B(erms, 9)
B(invpcid, 10)
B(rtm, 11)
B(mpx, 14)
B(avx512f, 16)
B(avx512dq, 17)
B(rdseed, 18)
B(adx, 19)
B(smap, 20)
B(avx512ifma, 21)
B(pcommit, 22)
B(clflushopt, 23)
B(clwb, 24)
B(avx512pf, 26)
B(avx512er, 27)
B(avx512cd, 28)
B(sha, 29)
B(avx512bw, 30)
B(avx512vl, 31)
#undef B
#define C(name, bit) X(name, f7c, bit)
C(prefetchwt1, 0)
C(avx512vbmi, 1)
#undef C
#undef X
#endif /* ZSTD_COMMON_CPU_H */

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,67 @@
/*
* divsufsort.h for libdivsufsort-lite
* Copyright (c) 2003-2008 Yuta Mori All Rights Reserved.
*
* Permission is hereby granted, free of charge, to any person
* obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without
* restriction, including without limitation the rights to use,
* copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
* OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef _DIVSUFSORT_H
#define _DIVSUFSORT_H 1
#ifdef __cplusplus
extern "C" {
#endif /* __cplusplus */
/*- Prototypes -*/
/**
* Constructs the suffix array of a given string.
* @param T [0..n-1] The input string.
* @param SA [0..n-1] The output array of suffixes.
* @param n The length of the given string.
* @param openMP enables OpenMP optimization.
* @return 0 if no error occurred, -1 or -2 otherwise.
*/
int
divsufsort(const unsigned char *T, int *SA, int n, int openMP);
/**
* Constructs the burrows-wheeler transformed string of a given string.
* @param T [0..n-1] The input string.
* @param U [0..n-1] The output string. (can be T)
* @param A [0..n-1] The temporary array. (can be NULL)
* @param n The length of the given string.
* @param num_indexes The length of secondary indexes array. (can be NULL)
* @param indexes The secondary indexes array. (can be NULL)
* @param openMP enables OpenMP optimization.
* @return The primary index if no error occurred, -1 or -2 otherwise.
*/
int
divbwt(const unsigned char *T, unsigned char *U, int *A, int n, unsigned char * num_indexes, int * indexes, int openMP);
#ifdef __cplusplus
} /* extern "C" */
#endif /* __cplusplus */
#endif /* _DIVSUFSORT_H */

View File

@ -0,0 +1,221 @@
/*
Common functions of New Generation Entropy library
Copyright (C) 2016, Yann Collet.
BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php)
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
You can contact the author at :
- FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy
- Public forum : https://groups.google.com/forum/#!forum/lz4c
*************************************************************************** */
/* *************************************
* Dependencies
***************************************/
#include "mem.h"
#include "error_private.h" /* ERR_*, ERROR */
#define FSE_STATIC_LINKING_ONLY /* FSE_MIN_TABLELOG */
#include "fse.h"
#define HUF_STATIC_LINKING_ONLY /* HUF_TABLELOG_ABSOLUTEMAX */
#include "huf.h"
/*=== Version ===*/
unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; }
/*=== Error Management ===*/
unsigned FSE_isError(size_t code) { return ERR_isError(code); }
const char* FSE_getErrorName(size_t code) { return ERR_getErrorName(code); }
unsigned HUF_isError(size_t code) { return ERR_isError(code); }
const char* HUF_getErrorName(size_t code) { return ERR_getErrorName(code); }
/*-**************************************************************
* FSE NCount encoding-decoding
****************************************************************/
size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
const void* headerBuffer, size_t hbSize)
{
const BYTE* const istart = (const BYTE*) headerBuffer;
const BYTE* const iend = istart + hbSize;
const BYTE* ip = istart;
int nbBits;
int remaining;
int threshold;
U32 bitStream;
int bitCount;
unsigned charnum = 0;
int previous0 = 0;
if (hbSize < 4) return ERROR(srcSize_wrong);
bitStream = MEM_readLE32(ip);
nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG; /* extract tableLog */
if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge);
bitStream >>= 4;
bitCount = 4;
*tableLogPtr = nbBits;
remaining = (1<<nbBits)+1;
threshold = 1<<nbBits;
nbBits++;
while ((remaining>1) & (charnum<=*maxSVPtr)) {
if (previous0) {
unsigned n0 = charnum;
while ((bitStream & 0xFFFF) == 0xFFFF) {
n0 += 24;
if (ip < iend-5) {
ip += 2;
bitStream = MEM_readLE32(ip) >> bitCount;
} else {
bitStream >>= 16;
bitCount += 16;
} }
while ((bitStream & 3) == 3) {
n0 += 3;
bitStream >>= 2;
bitCount += 2;
}
n0 += bitStream & 3;
bitCount += 2;
if (n0 > *maxSVPtr) return ERROR(maxSymbolValue_tooSmall);
while (charnum < n0) normalizedCounter[charnum++] = 0;
if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
ip += bitCount>>3;
bitCount &= 7;
bitStream = MEM_readLE32(ip) >> bitCount;
} else {
bitStream >>= 2;
} }
{ int const max = (2*threshold-1) - remaining;
int count;
if ((bitStream & (threshold-1)) < (U32)max) {
count = bitStream & (threshold-1);
bitCount += nbBits-1;
} else {
count = bitStream & (2*threshold-1);
if (count >= threshold) count -= max;
bitCount += nbBits;
}
count--; /* extra accuracy */
remaining -= count < 0 ? -count : count; /* -1 means +1 */
normalizedCounter[charnum++] = (short)count;
previous0 = !count;
while (remaining < threshold) {
nbBits--;
threshold >>= 1;
}
if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
ip += bitCount>>3;
bitCount &= 7;
} else {
bitCount -= (int)(8 * (iend - 4 - ip));
ip = iend - 4;
}
bitStream = MEM_readLE32(ip) >> (bitCount & 31);
} } /* while ((remaining>1) & (charnum<=*maxSVPtr)) */
if (remaining != 1) return ERROR(corruption_detected);
if (bitCount > 32) return ERROR(corruption_detected);
*maxSVPtr = charnum-1;
ip += (bitCount+7)>>3;
return ip-istart;
}
/*! HUF_readStats() :
Read compact Huffman tree, saved by HUF_writeCTable().
`huffWeight` is destination buffer.
`rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32.
@return : size read from `src` , or an error Code .
Note : Needed by HUF_readCTable() and HUF_readDTableX?() .
*/
size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats,
U32* nbSymbolsPtr, U32* tableLogPtr,
const void* src, size_t srcSize)
{
U32 weightTotal;
const BYTE* ip = (const BYTE*) src;
size_t iSize;
size_t oSize;
if (!srcSize) return ERROR(srcSize_wrong);
iSize = ip[0];
/* memset(huffWeight, 0, hwSize); *//* is not necessary, even though some analyzer complain ... */
if (iSize >= 128) { /* special header */
oSize = iSize - 127;
iSize = ((oSize+1)/2);
if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
if (oSize >= hwSize) return ERROR(corruption_detected);
ip += 1;
{ U32 n;
for (n=0; n<oSize; n+=2) {
huffWeight[n] = ip[n/2] >> 4;
huffWeight[n+1] = ip[n/2] & 15;
} } }
else { /* header compressed with FSE (normal case) */
FSE_DTable fseWorkspace[FSE_DTABLE_SIZE_U32(6)]; /* 6 is max possible tableLog for HUF header (maybe even 5, to be tested) */
if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
oSize = FSE_decompress_wksp(huffWeight, hwSize-1, ip+1, iSize, fseWorkspace, 6); /* max (hwSize-1) values decoded, as last one is implied */
if (FSE_isError(oSize)) return oSize;
}
/* collect weight stats */
memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32));
weightTotal = 0;
{ U32 n; for (n=0; n<oSize; n++) {
if (huffWeight[n] >= HUF_TABLELOG_MAX) return ERROR(corruption_detected);
rankStats[huffWeight[n]]++;
weightTotal += (1 << huffWeight[n]) >> 1;
} }
if (weightTotal == 0) return ERROR(corruption_detected);
/* get last non-null symbol weight (implied, total must be 2^n) */
{ U32 const tableLog = BIT_highbit32(weightTotal) + 1;
if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
*tableLogPtr = tableLog;
/* determine last weight */
{ U32 const total = 1 << tableLog;
U32 const rest = total - weightTotal;
U32 const verif = 1 << BIT_highbit32(rest);
U32 const lastWeight = BIT_highbit32(rest) + 1;
if (verif != rest) return ERROR(corruption_detected); /* last value must be a clean power of 2 */
huffWeight[oSize] = (BYTE)lastWeight;
rankStats[lastWeight]++;
} }
/* check tree construction validity */
if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected); /* by construction : at least 2 elts of rank 1, must be even */
/* results */
*nbSymbolsPtr = (U32)(oSize+1);
return iSize+1;
}

View File

@ -0,0 +1,48 @@
/*
* Copyright (c) 2016-present, Yann Collet, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under both the BSD-style license (found in the
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
* in the COPYING file in the root directory of this source tree).
* You may select, at your option, one of the above-listed licenses.
*/
/* The purpose of this file is to have a single list of error strings embedded in binary */
#include "error_private.h"
const char* ERR_getErrorString(ERR_enum code)
{
static const char* const notErrorCode = "Unspecified error code";
switch( code )
{
case PREFIX(no_error): return "No error detected";
case PREFIX(GENERIC): return "Error (generic)";
case PREFIX(prefix_unknown): return "Unknown frame descriptor";
case PREFIX(version_unsupported): return "Version not supported";
case PREFIX(frameParameter_unsupported): return "Unsupported frame parameter";
case PREFIX(frameParameter_windowTooLarge): return "Frame requires too much memory for decoding";
case PREFIX(corruption_detected): return "Corrupted block detected";
case PREFIX(checksum_wrong): return "Restored data doesn't match checksum";
case PREFIX(parameter_unsupported): return "Unsupported parameter";
case PREFIX(parameter_outOfBound): return "Parameter is out of bound";
case PREFIX(init_missing): return "Context should be init first";
case PREFIX(memory_allocation): return "Allocation error : not enough memory";
case PREFIX(workSpace_tooSmall): return "workSpace buffer is not large enough";
case PREFIX(stage_wrong): return "Operation not authorized at current processing stage";
case PREFIX(tableLog_tooLarge): return "tableLog requires too much memory : unsupported";
case PREFIX(maxSymbolValue_tooLarge): return "Unsupported max Symbol Value : too large";
case PREFIX(maxSymbolValue_tooSmall): return "Specified maxSymbolValue is too small";
case PREFIX(dictionary_corrupted): return "Dictionary is corrupted";
case PREFIX(dictionary_wrong): return "Dictionary mismatch";
case PREFIX(dictionaryCreation_failed): return "Cannot create Dictionary from provided samples";
case PREFIX(dstSize_tooSmall): return "Destination buffer is too small";
case PREFIX(srcSize_wrong): return "Src size is incorrect";
/* following error codes are not stable and may be removed or changed in a future version */
case PREFIX(frameIndex_tooLarge): return "Frame index is too large";
case PREFIX(seekableIO): return "An I/O error occurred when reading/seeking";
case PREFIX(maxCode):
default: return notErrorCode;
}
}

View File

@ -0,0 +1,76 @@
/*
* Copyright (c) 2016-present, Yann Collet, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under both the BSD-style license (found in the
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
* in the COPYING file in the root directory of this source tree).
* You may select, at your option, one of the above-listed licenses.
*/
/* Note : this module is expected to remain private, do not expose it */
#ifndef ERROR_H_MODULE
#define ERROR_H_MODULE
#if defined (__cplusplus)
extern "C" {
#endif
/* ****************************************
* Dependencies
******************************************/
#include <stddef.h> /* size_t */
#include "zstd_errors.h" /* enum list */
/* ****************************************
* Compiler-specific
******************************************/
#if defined(__GNUC__)
# define ERR_STATIC static __attribute__((unused))
#elif defined (__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */)
# define ERR_STATIC static inline
#elif defined(_MSC_VER)
# define ERR_STATIC static __inline
#else
# define ERR_STATIC static /* this version may generate warnings for unused static functions; disable the relevant warning */
#endif
/*-****************************************
* Customization (error_public.h)
******************************************/
typedef ZSTD_ErrorCode ERR_enum;
#define PREFIX(name) ZSTD_error_##name
/*-****************************************
* Error codes handling
******************************************/
#undef ERROR /* reported already defined on VS 2015 (Rich Geldreich) */
#define ERROR(name) ZSTD_ERROR(name)
#define ZSTD_ERROR(name) ((size_t)-PREFIX(name))
ERR_STATIC unsigned ERR_isError(size_t code) { return (code > ERROR(maxCode)); }
ERR_STATIC ERR_enum ERR_getErrorCode(size_t code) { if (!ERR_isError(code)) return (ERR_enum)0; return (ERR_enum) (0-code); }
/*-****************************************
* Error Strings
******************************************/
const char* ERR_getErrorString(ERR_enum code); /* error_private.c */
ERR_STATIC const char* ERR_getErrorName(size_t code)
{
return ERR_getErrorString(ERR_getErrorCode(code));
}
#if defined (__cplusplus)
}
#endif
#endif /* ERROR_H_MODULE */

View File

@ -0,0 +1,35 @@
package zstd
/*
#define ZSTD_STATIC_LINKING_ONLY
#include "zstd.h"
*/
import "C"
// ErrorCode is an error returned by the zstd library.
type ErrorCode int
// Error returns the error string given by zstd
func (e ErrorCode) Error() string {
return C.GoString(C.ZSTD_getErrorName(C.size_t(e)))
}
func cIsError(code int) bool {
return int(C.ZSTD_isError(C.size_t(code))) != 0
}
// getError returns an error for the return code, or nil if it's not an error
func getError(code int) error {
if code < 0 && cIsError(code) {
return ErrorCode(code)
}
return nil
}
// IsDstSizeTooSmallError returns whether the error correspond to zstd standard sDstSizeTooSmall error
func IsDstSizeTooSmallError(e error) bool {
if e != nil && e.Error() == "Destination buffer is too small" {
return true
}
return false
}

View File

@ -0,0 +1,29 @@
package zstd
import (
"testing"
)
const (
// ErrorUpperBound is the upper bound to error number, currently only used in test
// If this needs to be updated, check in zstd_errors.h what the max is
ErrorUpperBound = 1000
)
// TestFindIsDstSizeTooSmallError tests that there is at least one error code that
// corresponds to dst size too small
func TestFindIsDstSizeTooSmallError(t *testing.T) {
found := 0
for i := -1; i > -ErrorUpperBound; i-- {
e := ErrorCode(i)
if IsDstSizeTooSmallError(e) {
found++
}
}
if found == 0 {
t.Fatal("Couldn't find an error code for DstSizeTooSmall error, please make sure we didn't change the error string")
} else if found > 1 {
t.Fatal("IsDstSizeTooSmallError found multiple error codes matching, this shouldn't be the case")
}
}

View File

@ -0,0 +1,704 @@
/* ******************************************************************
FSE : Finite State Entropy codec
Public Prototypes declaration
Copyright (C) 2013-2016, Yann Collet.
BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php)
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
You can contact the author at :
- Source repository : https://github.com/Cyan4973/FiniteStateEntropy
****************************************************************** */
#if defined (__cplusplus)
extern "C" {
#endif
#ifndef FSE_H
#define FSE_H
/*-*****************************************
* Dependencies
******************************************/
#include <stddef.h> /* size_t, ptrdiff_t */
/*-*****************************************
* FSE_PUBLIC_API : control library symbols visibility
******************************************/
#if defined(FSE_DLL_EXPORT) && (FSE_DLL_EXPORT==1) && defined(__GNUC__) && (__GNUC__ >= 4)
# define FSE_PUBLIC_API __attribute__ ((visibility ("default")))
#elif defined(FSE_DLL_EXPORT) && (FSE_DLL_EXPORT==1) /* Visual expected */
# define FSE_PUBLIC_API __declspec(dllexport)
#elif defined(FSE_DLL_IMPORT) && (FSE_DLL_IMPORT==1)
# define FSE_PUBLIC_API __declspec(dllimport) /* It isn't required but allows to generate better code, saving a function pointer load from the IAT and an indirect jump.*/
#else
# define FSE_PUBLIC_API
#endif
/*------ Version ------*/
#define FSE_VERSION_MAJOR 0
#define FSE_VERSION_MINOR 9
#define FSE_VERSION_RELEASE 0
#define FSE_LIB_VERSION FSE_VERSION_MAJOR.FSE_VERSION_MINOR.FSE_VERSION_RELEASE
#define FSE_QUOTE(str) #str
#define FSE_EXPAND_AND_QUOTE(str) FSE_QUOTE(str)
#define FSE_VERSION_STRING FSE_EXPAND_AND_QUOTE(FSE_LIB_VERSION)
#define FSE_VERSION_NUMBER (FSE_VERSION_MAJOR *100*100 + FSE_VERSION_MINOR *100 + FSE_VERSION_RELEASE)
FSE_PUBLIC_API unsigned FSE_versionNumber(void); /**< library version number; to be used when checking dll version */
/*-****************************************
* FSE simple functions
******************************************/
/*! FSE_compress() :
Compress content of buffer 'src', of size 'srcSize', into destination buffer 'dst'.
'dst' buffer must be already allocated. Compression runs faster is dstCapacity >= FSE_compressBound(srcSize).
@return : size of compressed data (<= dstCapacity).
Special values : if return == 0, srcData is not compressible => Nothing is stored within dst !!!
if return == 1, srcData is a single byte symbol * srcSize times. Use RLE compression instead.
if FSE_isError(return), compression failed (more details using FSE_getErrorName())
*/
FSE_PUBLIC_API size_t FSE_compress(void* dst, size_t dstCapacity,
const void* src, size_t srcSize);
/*! FSE_decompress():
Decompress FSE data from buffer 'cSrc', of size 'cSrcSize',
into already allocated destination buffer 'dst', of size 'dstCapacity'.
@return : size of regenerated data (<= maxDstSize),
or an error code, which can be tested using FSE_isError() .
** Important ** : FSE_decompress() does not decompress non-compressible nor RLE data !!!
Why ? : making this distinction requires a header.
Header management is intentionally delegated to the user layer, which can better manage special cases.
*/
FSE_PUBLIC_API size_t FSE_decompress(void* dst, size_t dstCapacity,
const void* cSrc, size_t cSrcSize);
/*-*****************************************
* Tool functions
******************************************/
FSE_PUBLIC_API size_t FSE_compressBound(size_t size); /* maximum compressed size */
/* Error Management */
FSE_PUBLIC_API unsigned FSE_isError(size_t code); /* tells if a return value is an error code */
FSE_PUBLIC_API const char* FSE_getErrorName(size_t code); /* provides error code string (useful for debugging) */
/*-*****************************************
* FSE advanced functions
******************************************/
/*! FSE_compress2() :
Same as FSE_compress(), but allows the selection of 'maxSymbolValue' and 'tableLog'
Both parameters can be defined as '0' to mean : use default value
@return : size of compressed data
Special values : if return == 0, srcData is not compressible => Nothing is stored within cSrc !!!
if return == 1, srcData is a single byte symbol * srcSize times. Use RLE compression.
if FSE_isError(return), it's an error code.
*/
FSE_PUBLIC_API size_t FSE_compress2 (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog);
/*-*****************************************
* FSE detailed API
******************************************/
/*!
FSE_compress() does the following:
1. count symbol occurrence from source[] into table count[]
2. normalize counters so that sum(count[]) == Power_of_2 (2^tableLog)
3. save normalized counters to memory buffer using writeNCount()
4. build encoding table 'CTable' from normalized counters
5. encode the data stream using encoding table 'CTable'
FSE_decompress() does the following:
1. read normalized counters with readNCount()
2. build decoding table 'DTable' from normalized counters
3. decode the data stream using decoding table 'DTable'
The following API allows targeting specific sub-functions for advanced tasks.
For example, it's possible to compress several blocks using the same 'CTable',
or to save and provide normalized distribution using external method.
*/
/* *** COMPRESSION *** */
/*! FSE_count():
Provides the precise count of each byte within a table 'count'.
'count' is a table of unsigned int, of minimum size (*maxSymbolValuePtr+1).
*maxSymbolValuePtr will be updated if detected smaller than initial value.
@return : the count of the most frequent symbol (which is not identified).
if return == srcSize, there is only one symbol.
Can also return an error code, which can be tested with FSE_isError(). */
FSE_PUBLIC_API size_t FSE_count(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize);
/*! FSE_optimalTableLog():
dynamically downsize 'tableLog' when conditions are met.
It saves CPU time, by using smaller tables, while preserving or even improving compression ratio.
@return : recommended tableLog (necessarily <= 'maxTableLog') */
FSE_PUBLIC_API unsigned FSE_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue);
/*! FSE_normalizeCount():
normalize counts so that sum(count[]) == Power_of_2 (2^tableLog)
'normalizedCounter' is a table of short, of minimum size (maxSymbolValue+1).
@return : tableLog,
or an errorCode, which can be tested using FSE_isError() */
FSE_PUBLIC_API size_t FSE_normalizeCount(short* normalizedCounter, unsigned tableLog, const unsigned* count, size_t srcSize, unsigned maxSymbolValue);
/*! FSE_NCountWriteBound():
Provides the maximum possible size of an FSE normalized table, given 'maxSymbolValue' and 'tableLog'.
Typically useful for allocation purpose. */
FSE_PUBLIC_API size_t FSE_NCountWriteBound(unsigned maxSymbolValue, unsigned tableLog);
/*! FSE_writeNCount():
Compactly save 'normalizedCounter' into 'buffer'.
@return : size of the compressed table,
or an errorCode, which can be tested using FSE_isError(). */
FSE_PUBLIC_API size_t FSE_writeNCount (void* buffer, size_t bufferSize, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog);
/*! Constructor and Destructor of FSE_CTable.
Note that FSE_CTable size depends on 'tableLog' and 'maxSymbolValue' */
typedef unsigned FSE_CTable; /* don't allocate that. It's only meant to be more restrictive than void* */
FSE_PUBLIC_API FSE_CTable* FSE_createCTable (unsigned maxSymbolValue, unsigned tableLog);
FSE_PUBLIC_API void FSE_freeCTable (FSE_CTable* ct);
/*! FSE_buildCTable():
Builds `ct`, which must be already allocated, using FSE_createCTable().
@return : 0, or an errorCode, which can be tested using FSE_isError() */
FSE_PUBLIC_API size_t FSE_buildCTable(FSE_CTable* ct, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog);
/*! FSE_compress_usingCTable():
Compress `src` using `ct` into `dst` which must be already allocated.
@return : size of compressed data (<= `dstCapacity`),
or 0 if compressed data could not fit into `dst`,
or an errorCode, which can be tested using FSE_isError() */
FSE_PUBLIC_API size_t FSE_compress_usingCTable (void* dst, size_t dstCapacity, const void* src, size_t srcSize, const FSE_CTable* ct);
/*!
Tutorial :
----------
The first step is to count all symbols. FSE_count() does this job very fast.
Result will be saved into 'count', a table of unsigned int, which must be already allocated, and have 'maxSymbolValuePtr[0]+1' cells.
'src' is a table of bytes of size 'srcSize'. All values within 'src' MUST be <= maxSymbolValuePtr[0]
maxSymbolValuePtr[0] will be updated, with its real value (necessarily <= original value)
FSE_count() will return the number of occurrence of the most frequent symbol.
This can be used to know if there is a single symbol within 'src', and to quickly evaluate its compressibility.
If there is an error, the function will return an ErrorCode (which can be tested using FSE_isError()).
The next step is to normalize the frequencies.
FSE_normalizeCount() will ensure that sum of frequencies is == 2 ^'tableLog'.
It also guarantees a minimum of 1 to any Symbol with frequency >= 1.
You can use 'tableLog'==0 to mean "use default tableLog value".
If you are unsure of which tableLog value to use, you can ask FSE_optimalTableLog(),
which will provide the optimal valid tableLog given sourceSize, maxSymbolValue, and a user-defined maximum (0 means "default").
The result of FSE_normalizeCount() will be saved into a table,
called 'normalizedCounter', which is a table of signed short.
'normalizedCounter' must be already allocated, and have at least 'maxSymbolValue+1' cells.
The return value is tableLog if everything proceeded as expected.
It is 0 if there is a single symbol within distribution.
If there is an error (ex: invalid tableLog value), the function will return an ErrorCode (which can be tested using FSE_isError()).
'normalizedCounter' can be saved in a compact manner to a memory area using FSE_writeNCount().
'buffer' must be already allocated.
For guaranteed success, buffer size must be at least FSE_headerBound().
The result of the function is the number of bytes written into 'buffer'.
If there is an error, the function will return an ErrorCode (which can be tested using FSE_isError(); ex : buffer size too small).
'normalizedCounter' can then be used to create the compression table 'CTable'.
The space required by 'CTable' must be already allocated, using FSE_createCTable().
You can then use FSE_buildCTable() to fill 'CTable'.
If there is an error, both functions will return an ErrorCode (which can be tested using FSE_isError()).
'CTable' can then be used to compress 'src', with FSE_compress_usingCTable().
Similar to FSE_count(), the convention is that 'src' is assumed to be a table of char of size 'srcSize'
The function returns the size of compressed data (without header), necessarily <= `dstCapacity`.
If it returns '0', compressed data could not fit into 'dst'.
If there is an error, the function will return an ErrorCode (which can be tested using FSE_isError()).
*/
/* *** DECOMPRESSION *** */
/*! FSE_readNCount():
Read compactly saved 'normalizedCounter' from 'rBuffer'.
@return : size read from 'rBuffer',
or an errorCode, which can be tested using FSE_isError().
maxSymbolValuePtr[0] and tableLogPtr[0] will also be updated with their respective values */
FSE_PUBLIC_API size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSymbolValuePtr, unsigned* tableLogPtr, const void* rBuffer, size_t rBuffSize);
/*! Constructor and Destructor of FSE_DTable.
Note that its size depends on 'tableLog' */
typedef unsigned FSE_DTable; /* don't allocate that. It's just a way to be more restrictive than void* */
FSE_PUBLIC_API FSE_DTable* FSE_createDTable(unsigned tableLog);
FSE_PUBLIC_API void FSE_freeDTable(FSE_DTable* dt);
/*! FSE_buildDTable():
Builds 'dt', which must be already allocated, using FSE_createDTable().
return : 0, or an errorCode, which can be tested using FSE_isError() */
FSE_PUBLIC_API size_t FSE_buildDTable (FSE_DTable* dt, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog);
/*! FSE_decompress_usingDTable():
Decompress compressed source `cSrc` of size `cSrcSize` using `dt`
into `dst` which must be already allocated.
@return : size of regenerated data (necessarily <= `dstCapacity`),
or an errorCode, which can be tested using FSE_isError() */
FSE_PUBLIC_API size_t FSE_decompress_usingDTable(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, const FSE_DTable* dt);
/*!
Tutorial :
----------
(Note : these functions only decompress FSE-compressed blocks.
If block is uncompressed, use memcpy() instead
If block is a single repeated byte, use memset() instead )
The first step is to obtain the normalized frequencies of symbols.
This can be performed by FSE_readNCount() if it was saved using FSE_writeNCount().
'normalizedCounter' must be already allocated, and have at least 'maxSymbolValuePtr[0]+1' cells of signed short.
In practice, that means it's necessary to know 'maxSymbolValue' beforehand,
or size the table to handle worst case situations (typically 256).
FSE_readNCount() will provide 'tableLog' and 'maxSymbolValue'.
The result of FSE_readNCount() is the number of bytes read from 'rBuffer'.
Note that 'rBufferSize' must be at least 4 bytes, even if useful information is less than that.
If there is an error, the function will return an error code, which can be tested using FSE_isError().
The next step is to build the decompression tables 'FSE_DTable' from 'normalizedCounter'.
This is performed by the function FSE_buildDTable().
The space required by 'FSE_DTable' must be already allocated using FSE_createDTable().
If there is an error, the function will return an error code, which can be tested using FSE_isError().
`FSE_DTable` can then be used to decompress `cSrc`, with FSE_decompress_usingDTable().
`cSrcSize` must be strictly correct, otherwise decompression will fail.
FSE_decompress_usingDTable() result will tell how many bytes were regenerated (<=`dstCapacity`).
If there is an error, the function will return an error code, which can be tested using FSE_isError(). (ex: dst buffer too small)
*/
#endif /* FSE_H */
#if defined(FSE_STATIC_LINKING_ONLY) && !defined(FSE_H_FSE_STATIC_LINKING_ONLY)
#define FSE_H_FSE_STATIC_LINKING_ONLY
/* *** Dependency *** */
#include "bitstream.h"
/* *****************************************
* Static allocation
*******************************************/
/* FSE buffer bounds */
#define FSE_NCOUNTBOUND 512
#define FSE_BLOCKBOUND(size) (size + (size>>7))
#define FSE_COMPRESSBOUND(size) (FSE_NCOUNTBOUND + FSE_BLOCKBOUND(size)) /* Macro version, useful for static allocation */
/* It is possible to statically allocate FSE CTable/DTable as a table of FSE_CTable/FSE_DTable using below macros */
#define FSE_CTABLE_SIZE_U32(maxTableLog, maxSymbolValue) (1 + (1<<(maxTableLog-1)) + ((maxSymbolValue+1)*2))
#define FSE_DTABLE_SIZE_U32(maxTableLog) (1 + (1<<maxTableLog))
/* or use the size to malloc() space directly. Pay attention to alignment restrictions though */
#define FSE_CTABLE_SIZE(maxTableLog, maxSymbolValue) (FSE_CTABLE_SIZE_U32(maxTableLog, maxSymbolValue) * sizeof(FSE_CTable))
#define FSE_DTABLE_SIZE(maxTableLog) (FSE_DTABLE_SIZE_U32(maxTableLog) * sizeof(FSE_DTable))
/* *****************************************
* FSE advanced API
*******************************************/
/* FSE_count_wksp() :
* Same as FSE_count(), but using an externally provided scratch buffer.
* `workSpace` size must be table of >= `1024` unsigned
*/
size_t FSE_count_wksp(unsigned* count, unsigned* maxSymbolValuePtr,
const void* source, size_t sourceSize, unsigned* workSpace);
/** FSE_countFast() :
* same as FSE_count(), but blindly trusts that all byte values within src are <= *maxSymbolValuePtr
*/
size_t FSE_countFast(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize);
/* FSE_countFast_wksp() :
* Same as FSE_countFast(), but using an externally provided scratch buffer.
* `workSpace` must be a table of minimum `1024` unsigned
*/
size_t FSE_countFast_wksp(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize, unsigned* workSpace);
/*! FSE_count_simple() :
* Same as FSE_countFast(), but does not use any additional memory (not even on stack).
* This function is unsafe, and will segfault if any value within `src` is `> *maxSymbolValuePtr` (presuming it's also the size of `count`).
*/
size_t FSE_count_simple(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize);
unsigned FSE_optimalTableLog_internal(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, unsigned minus);
/**< same as FSE_optimalTableLog(), which used `minus==2` */
/* FSE_compress_wksp() :
* Same as FSE_compress2(), but using an externally allocated scratch buffer (`workSpace`).
* FSE_WKSP_SIZE_U32() provides the minimum size required for `workSpace` as a table of FSE_CTable.
*/
#define FSE_WKSP_SIZE_U32(maxTableLog, maxSymbolValue) ( FSE_CTABLE_SIZE_U32(maxTableLog, maxSymbolValue) + ((maxTableLog > 12) ? (1 << (maxTableLog - 2)) : 1024) )
size_t FSE_compress_wksp (void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize);
size_t FSE_buildCTable_raw (FSE_CTable* ct, unsigned nbBits);
/**< build a fake FSE_CTable, designed for a flat distribution, where each symbol uses nbBits */
size_t FSE_buildCTable_rle (FSE_CTable* ct, unsigned char symbolValue);
/**< build a fake FSE_CTable, designed to compress always the same symbolValue */
/* FSE_buildCTable_wksp() :
* Same as FSE_buildCTable(), but using an externally allocated scratch buffer (`workSpace`).
* `wkspSize` must be >= `(1<<tableLog)`.
*/
size_t FSE_buildCTable_wksp(FSE_CTable* ct, const short* normalizedCounter, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize);
size_t FSE_buildDTable_raw (FSE_DTable* dt, unsigned nbBits);
/**< build a fake FSE_DTable, designed to read a flat distribution where each symbol uses nbBits */
size_t FSE_buildDTable_rle (FSE_DTable* dt, unsigned char symbolValue);
/**< build a fake FSE_DTable, designed to always generate the same symbolValue */
size_t FSE_decompress_wksp(void* dst, size_t dstCapacity, const void* cSrc, size_t cSrcSize, FSE_DTable* workSpace, unsigned maxLog);
/**< same as FSE_decompress(), using an externally allocated `workSpace` produced with `FSE_DTABLE_SIZE_U32(maxLog)` */
typedef enum {
FSE_repeat_none, /**< Cannot use the previous table */
FSE_repeat_check, /**< Can use the previous table but it must be checked */
FSE_repeat_valid /**< Can use the previous table and it is asumed to be valid */
} FSE_repeat;
/* *****************************************
* FSE symbol compression API
*******************************************/
/*!
This API consists of small unitary functions, which highly benefit from being inlined.
Hence their body are included in next section.
*/
typedef struct {
ptrdiff_t value;
const void* stateTable;
const void* symbolTT;
unsigned stateLog;
} FSE_CState_t;
static void FSE_initCState(FSE_CState_t* CStatePtr, const FSE_CTable* ct);
static void FSE_encodeSymbol(BIT_CStream_t* bitC, FSE_CState_t* CStatePtr, unsigned symbol);
static void FSE_flushCState(BIT_CStream_t* bitC, const FSE_CState_t* CStatePtr);
/**<
These functions are inner components of FSE_compress_usingCTable().
They allow the creation of custom streams, mixing multiple tables and bit sources.
A key property to keep in mind is that encoding and decoding are done **in reverse direction**.
So the first symbol you will encode is the last you will decode, like a LIFO stack.
You will need a few variables to track your CStream. They are :
FSE_CTable ct; // Provided by FSE_buildCTable()
BIT_CStream_t bitStream; // bitStream tracking structure
FSE_CState_t state; // State tracking structure (can have several)
The first thing to do is to init bitStream and state.
size_t errorCode = BIT_initCStream(&bitStream, dstBuffer, maxDstSize);
FSE_initCState(&state, ct);
Note that BIT_initCStream() can produce an error code, so its result should be tested, using FSE_isError();
You can then encode your input data, byte after byte.
FSE_encodeSymbol() outputs a maximum of 'tableLog' bits at a time.
Remember decoding will be done in reverse direction.
FSE_encodeByte(&bitStream, &state, symbol);
At any time, you can also add any bit sequence.
Note : maximum allowed nbBits is 25, for compatibility with 32-bits decoders
BIT_addBits(&bitStream, bitField, nbBits);
The above methods don't commit data to memory, they just store it into local register, for speed.
Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (size_t).
Writing data to memory is a manual operation, performed by the flushBits function.
BIT_flushBits(&bitStream);
Your last FSE encoding operation shall be to flush your last state value(s).
FSE_flushState(&bitStream, &state);
Finally, you must close the bitStream.
The function returns the size of CStream in bytes.
If data couldn't fit into dstBuffer, it will return a 0 ( == not compressible)
If there is an error, it returns an errorCode (which can be tested using FSE_isError()).
size_t size = BIT_closeCStream(&bitStream);
*/
/* *****************************************
* FSE symbol decompression API
*******************************************/
typedef struct {
size_t state;
const void* table; /* precise table may vary, depending on U16 */
} FSE_DState_t;
static void FSE_initDState(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD, const FSE_DTable* dt);
static unsigned char FSE_decodeSymbol(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD);
static unsigned FSE_endOfDState(const FSE_DState_t* DStatePtr);
/**<
Let's now decompose FSE_decompress_usingDTable() into its unitary components.
You will decode FSE-encoded symbols from the bitStream,
and also any other bitFields you put in, **in reverse order**.
You will need a few variables to track your bitStream. They are :
BIT_DStream_t DStream; // Stream context
FSE_DState_t DState; // State context. Multiple ones are possible
FSE_DTable* DTablePtr; // Decoding table, provided by FSE_buildDTable()
The first thing to do is to init the bitStream.
errorCode = BIT_initDStream(&DStream, srcBuffer, srcSize);
You should then retrieve your initial state(s)
(in reverse flushing order if you have several ones) :
errorCode = FSE_initDState(&DState, &DStream, DTablePtr);
You can then decode your data, symbol after symbol.
For information the maximum number of bits read by FSE_decodeSymbol() is 'tableLog'.
Keep in mind that symbols are decoded in reverse order, like a LIFO stack (last in, first out).
unsigned char symbol = FSE_decodeSymbol(&DState, &DStream);
You can retrieve any bitfield you eventually stored into the bitStream (in reverse order)
Note : maximum allowed nbBits is 25, for 32-bits compatibility
size_t bitField = BIT_readBits(&DStream, nbBits);
All above operations only read from local register (which size depends on size_t).
Refueling the register from memory is manually performed by the reload method.
endSignal = FSE_reloadDStream(&DStream);
BIT_reloadDStream() result tells if there is still some more data to read from DStream.
BIT_DStream_unfinished : there is still some data left into the DStream.
BIT_DStream_endOfBuffer : Dstream reached end of buffer. Its container may no longer be completely filled.
BIT_DStream_completed : Dstream reached its exact end, corresponding in general to decompression completed.
BIT_DStream_tooFar : Dstream went too far. Decompression result is corrupted.
When reaching end of buffer (BIT_DStream_endOfBuffer), progress slowly, notably if you decode multiple symbols per loop,
to properly detect the exact end of stream.
After each decoded symbol, check if DStream is fully consumed using this simple test :
BIT_reloadDStream(&DStream) >= BIT_DStream_completed
When it's done, verify decompression is fully completed, by checking both DStream and the relevant states.
Checking if DStream has reached its end is performed by :
BIT_endOfDStream(&DStream);
Check also the states. There might be some symbols left there, if some high probability ones (>50%) are possible.
FSE_endOfDState(&DState);
*/
/* *****************************************
* FSE unsafe API
*******************************************/
static unsigned char FSE_decodeSymbolFast(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD);
/* faster, but works only if nbBits is always >= 1 (otherwise, result will be corrupted) */
/* *****************************************
* Implementation of inlined functions
*******************************************/
typedef struct {
int deltaFindState;
U32 deltaNbBits;
} FSE_symbolCompressionTransform; /* total 8 bytes */
MEM_STATIC void FSE_initCState(FSE_CState_t* statePtr, const FSE_CTable* ct)
{
const void* ptr = ct;
const U16* u16ptr = (const U16*) ptr;
const U32 tableLog = MEM_read16(ptr);
statePtr->value = (ptrdiff_t)1<<tableLog;
statePtr->stateTable = u16ptr+2;
statePtr->symbolTT = ((const U32*)ct + 1 + (tableLog ? (1<<(tableLog-1)) : 1));
statePtr->stateLog = tableLog;
}
/*! FSE_initCState2() :
* Same as FSE_initCState(), but the first symbol to include (which will be the last to be read)
* uses the smallest state value possible, saving the cost of this symbol */
MEM_STATIC void FSE_initCState2(FSE_CState_t* statePtr, const FSE_CTable* ct, U32 symbol)
{
FSE_initCState(statePtr, ct);
{ const FSE_symbolCompressionTransform symbolTT = ((const FSE_symbolCompressionTransform*)(statePtr->symbolTT))[symbol];
const U16* stateTable = (const U16*)(statePtr->stateTable);
U32 nbBitsOut = (U32)((symbolTT.deltaNbBits + (1<<15)) >> 16);
statePtr->value = (nbBitsOut << 16) - symbolTT.deltaNbBits;
statePtr->value = stateTable[(statePtr->value >> nbBitsOut) + symbolTT.deltaFindState];
}
}
MEM_STATIC void FSE_encodeSymbol(BIT_CStream_t* bitC, FSE_CState_t* statePtr, U32 symbol)
{
FSE_symbolCompressionTransform const symbolTT = ((const FSE_symbolCompressionTransform*)(statePtr->symbolTT))[symbol];
const U16* const stateTable = (const U16*)(statePtr->stateTable);
U32 const nbBitsOut = (U32)((statePtr->value + symbolTT.deltaNbBits) >> 16);
BIT_addBits(bitC, statePtr->value, nbBitsOut);
statePtr->value = stateTable[ (statePtr->value >> nbBitsOut) + symbolTT.deltaFindState];
}
MEM_STATIC void FSE_flushCState(BIT_CStream_t* bitC, const FSE_CState_t* statePtr)
{
BIT_addBits(bitC, statePtr->value, statePtr->stateLog);
BIT_flushBits(bitC);
}
/* ====== Decompression ====== */
typedef struct {
U16 tableLog;
U16 fastMode;
} FSE_DTableHeader; /* sizeof U32 */
typedef struct
{
unsigned short newState;
unsigned char symbol;
unsigned char nbBits;
} FSE_decode_t; /* size == U32 */
MEM_STATIC void FSE_initDState(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD, const FSE_DTable* dt)
{
const void* ptr = dt;
const FSE_DTableHeader* const DTableH = (const FSE_DTableHeader*)ptr;
DStatePtr->state = BIT_readBits(bitD, DTableH->tableLog);
BIT_reloadDStream(bitD);
DStatePtr->table = dt + 1;
}
MEM_STATIC BYTE FSE_peekSymbol(const FSE_DState_t* DStatePtr)
{
FSE_decode_t const DInfo = ((const FSE_decode_t*)(DStatePtr->table))[DStatePtr->state];
return DInfo.symbol;
}
MEM_STATIC void FSE_updateState(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD)
{
FSE_decode_t const DInfo = ((const FSE_decode_t*)(DStatePtr->table))[DStatePtr->state];
U32 const nbBits = DInfo.nbBits;
size_t const lowBits = BIT_readBits(bitD, nbBits);
DStatePtr->state = DInfo.newState + lowBits;
}
MEM_STATIC BYTE FSE_decodeSymbol(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD)
{
FSE_decode_t const DInfo = ((const FSE_decode_t*)(DStatePtr->table))[DStatePtr->state];
U32 const nbBits = DInfo.nbBits;
BYTE const symbol = DInfo.symbol;
size_t const lowBits = BIT_readBits(bitD, nbBits);
DStatePtr->state = DInfo.newState + lowBits;
return symbol;
}
/*! FSE_decodeSymbolFast() :
unsafe, only works if no symbol has a probability > 50% */
MEM_STATIC BYTE FSE_decodeSymbolFast(FSE_DState_t* DStatePtr, BIT_DStream_t* bitD)
{
FSE_decode_t const DInfo = ((const FSE_decode_t*)(DStatePtr->table))[DStatePtr->state];
U32 const nbBits = DInfo.nbBits;
BYTE const symbol = DInfo.symbol;
size_t const lowBits = BIT_readBitsFast(bitD, nbBits);
DStatePtr->state = DInfo.newState + lowBits;
return symbol;
}
MEM_STATIC unsigned FSE_endOfDState(const FSE_DState_t* DStatePtr)
{
return DStatePtr->state == 0;
}
#ifndef FSE_COMMONDEFS_ONLY
/* **************************************************************
* Tuning parameters
****************************************************************/
/*!MEMORY_USAGE :
* Memory usage formula : N->2^N Bytes (examples : 10 -> 1KB; 12 -> 4KB ; 16 -> 64KB; 20 -> 1MB; etc.)
* Increasing memory usage improves compression ratio
* Reduced memory usage can improve speed, due to cache effect
* Recommended max value is 14, for 16KB, which nicely fits into Intel x86 L1 cache */
#ifndef FSE_MAX_MEMORY_USAGE
# define FSE_MAX_MEMORY_USAGE 14
#endif
#ifndef FSE_DEFAULT_MEMORY_USAGE
# define FSE_DEFAULT_MEMORY_USAGE 13
#endif
/*!FSE_MAX_SYMBOL_VALUE :
* Maximum symbol value authorized.
* Required for proper stack allocation */
#ifndef FSE_MAX_SYMBOL_VALUE
# define FSE_MAX_SYMBOL_VALUE 255
#endif
/* **************************************************************
* template functions type & suffix
****************************************************************/
#define FSE_FUNCTION_TYPE BYTE
#define FSE_FUNCTION_EXTENSION
#define FSE_DECODE_TYPE FSE_decode_t
#endif /* !FSE_COMMONDEFS_ONLY */
/* ***************************************************************
* Constants
*****************************************************************/
#define FSE_MAX_TABLELOG (FSE_MAX_MEMORY_USAGE-2)
#define FSE_MAX_TABLESIZE (1U<<FSE_MAX_TABLELOG)
#define FSE_MAXTABLESIZE_MASK (FSE_MAX_TABLESIZE-1)
#define FSE_DEFAULT_TABLELOG (FSE_DEFAULT_MEMORY_USAGE-2)
#define FSE_MIN_TABLELOG 5
#define FSE_TABLELOG_ABSOLUTE_MAX 15
#if FSE_MAX_TABLELOG > FSE_TABLELOG_ABSOLUTE_MAX
# error "FSE_MAX_TABLELOG > FSE_TABLELOG_ABSOLUTE_MAX is not supported"
#endif
#define FSE_TABLESTEP(tableSize) ((tableSize>>1) + (tableSize>>3) + 3)
#endif /* FSE_STATIC_LINKING_ONLY */
#if defined (__cplusplus)
}
#endif

Some files were not shown because too many files have changed in this diff Show More