diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 4442b5cb9..8407c7178 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -5,9 +5,6 @@ // You can obtain one at https://github.com/gogf/gf. // Package gdb provides ORM features for popular relationship databases. -// -// 数据库ORM, -// 默认内置支持MySQL, 其他数据库需要手动import对应的数据库引擎第三方包. package gdb import ( @@ -36,6 +33,7 @@ type DB interface { // 内部实现API的方法(不同数据库可覆盖这些方法实现自定义的操作) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) + doGetAll(link dbLink, query string, args ...interface{}) (result Result, err error) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) doPrepare(link dbLink, query string) (*sql.Stmt, error) doInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) diff --git a/database/gdb/gdb_base.go b/database/gdb/gdb_base.go index 73679282a..07a5953a6 100644 --- a/database/gdb/gdb_base.go +++ b/database/gdb/gdb_base.go @@ -169,7 +169,18 @@ func (bs *dbBase) doPrepare(link dbLink, query string) (*sql.Stmt, error) { // 数据库查询,获取查询结果集,以列表结构返回 func (bs *dbBase) GetAll(query string, args ...interface{}) (Result, error) { - rows, err := bs.Query(query, args...) + return bs.db.doGetAll(nil, query, args...) +} + +// 数据库查询,获取查询结果集,以列表结构返回,给定连接对象 +func (bs *dbBase) doGetAll(link dbLink, query string, args ...interface{}) (result Result, err error) { + if link == nil { + link, err = bs.db.Slave() + if err != nil { + return nil, err + } + } + rows, err := bs.doQuery(link, query, args...) if err != nil || rows == nil { return nil, err } diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index dbbb641e7..22eae4b81 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -19,6 +19,7 @@ import ( type Model struct { db DB // 数据库操作对象 tx *TX // 数据库事务对象 + linkType int // 连接对象类型(用于主从集群时开发者自定义操作对象) tablesInit string // 初始化Model时的表名称(可以是多个) tables string // 数据库操作表 fields string // 操作字段 @@ -38,6 +39,11 @@ type Model struct { safe bool // 当前模型是否安全模式(默认非安全表示链式操作直接修改当前模型属性;否则每一次链式操作都是返回新的模型对象) } +const ( + gLINK_TYPE_MASTER = 1 // 主节点类型 + gLINK_TYPE_SLAVE = 2 // 从节点类型 +) + // 链式操作,数据表字段,可支持多个表,以半角逗号连接 func (bs *dbBase) Table(tables string) *Model { return &Model{ @@ -87,6 +93,20 @@ func (md *Model) Clone() *Model { return newModel } +// 设置本次链式操作在主节点上 +func (md *Model) Master() *Model { + model := md.getModel() + model.linkType = gLINK_TYPE_MASTER + return model +} + +// 设置本次链式操作在从节点上 +func (md *Model) Slave() *Model { + model := md.getModel() + model.linkType = gLINK_TYPE_SLAVE + return model +} + // 标识当前对象运行安全模式(可被修改)。 // 1. 默认情况下,模型对象的对象属性无法被修改, // 每一次链式操作都是克隆一个新的模型对象,这样所有的操作都不会污染模型对象。 @@ -329,20 +349,12 @@ func (md *Model) Insert() (result sql.Result, err error) { list[k] = md.db.filterFields(md.tables, m) } } - if md.tx == nil { - return md.db.BatchInsert(md.tables, list, batch) - } else { - return md.tx.BatchInsert(md.tables, list, batch) - } + return md.db.doBatchInsert(md.getLink(), md.tables, list, OPTION_INSERT, batch) } else if data, ok := md.data.(Map); ok { if md.filter { data = md.db.filterFields(md.tables, data) } - if md.tx == nil { - return md.db.Insert(md.tables, data) - } else { - return md.tx.Insert(md.tables, data) - } + return md.db.doInsert(md.getLink(), md.tables, data, OPTION_INSERT) } return nil, errors.New("inserting into table with invalid data type") } @@ -370,20 +382,12 @@ func (md *Model) Replace() (result sql.Result, err error) { list[k] = md.db.filterFields(md.tables, m) } } - if md.tx == nil { - return md.db.BatchReplace(md.tables, list, batch) - } else { - return md.tx.BatchReplace(md.tables, list, batch) - } + return md.db.doBatchInsert(md.getLink(), md.tables, list, OPTION_REPLACE, batch) } else if data, ok := md.data.(Map); ok { if md.filter { data = md.db.filterFields(md.tables, data) } - if md.tx == nil { - return md.db.Replace(md.tables, data) - } else { - return md.tx.Replace(md.tables, data) - } + return md.db.doInsert(md.getLink(), md.tables, data, OPTION_REPLACE) } return nil, errors.New("replacing into table with invalid data type") } @@ -411,20 +415,12 @@ func (md *Model) Save() (result sql.Result, err error) { list[k] = md.db.filterFields(md.tables, m) } } - if md.tx == nil { - return md.db.BatchSave(md.tables, list, batch) - } else { - return md.tx.BatchSave(md.tables, list, batch) - } + return md.db.doBatchInsert(md.getLink(), md.tables, list, OPTION_SAVE, batch) } else if data, ok := md.data.(Map); ok { if md.filter { data = md.db.filterFields(md.tables, data) } - if md.tx == nil { - return md.db.Save(md.tables, data) - } else { - return md.tx.Save(md.tables, data) - } + return md.db.doInsert(md.getLink(), md.tables, data, OPTION_SAVE) } return nil, errors.New("saving into table with invalid data type") } @@ -446,11 +442,7 @@ func (md *Model) Update() (result sql.Result, err error) { } } } - if md.tx == nil { - return md.db.doUpdate(nil, md.tables, md.data, md.getConditionSql(), md.whereArgs...) - } else { - return md.tx.doUpdate(md.tables, md.data, md.getConditionSql(), md.whereArgs...) - } + return md.db.doUpdate(md.getLink(), md.tables, md.data, md.getConditionSql(), md.whereArgs...) } // 链式操作, CURD - Delete @@ -460,11 +452,7 @@ func (md *Model) Delete() (result sql.Result, err error) { md.checkAndRemoveCache() } }() - if md.tx == nil { - return md.db.doDelete(nil, md.tables, md.getConditionSql(), md.whereArgs...) - } else { - return md.tx.doDelete(md.tables, md.getConditionSql(), md.whereArgs...) - } + return md.db.doDelete(md.getLink(), md.tables, md.getConditionSql(), md.whereArgs...) } // 链式操作,select @@ -565,6 +553,22 @@ func (md *Model) Count() (int, error) { return 0, nil } +// 获得操作的连接对象 +func (md *Model) getLink() dbLink { + if md.tx != nil { + return md.tx.tx + } + switch md.linkType { + case gLINK_TYPE_MASTER: + link, _ := md.db.Master() + return link + case gLINK_TYPE_SLAVE: + link, _ := md.db.Slave() + return link + } + return nil +} + // 查询操作,对底层SQL操作的封装 func (md *Model) getAll(query string, args ...interface{}) (result Result, err error) { cacheKey := "" @@ -578,12 +582,7 @@ func (md *Model) getAll(query string, args ...interface{}) (result Result, err e return v.(Result), nil } } - - if md.tx == nil { - result, err = md.db.GetAll(query, args...) - } else { - result, err = md.tx.GetAll(query, args...) - } + result, err = md.db.doGetAll(md.getLink(), query, args...) // 查询缓存保存处理 if len(cacheKey) > 0 && err == nil { if md.cacheTime < 0 {