diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index bc532c59b..cd7900d9b 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -53,6 +53,9 @@ type Model struct { onDuplicateEx interface{} // onDuplicateEx is used for excluding some columns ON "DUPLICATE KEY UPDATE" statement. } +// ModelHandler is a function that handles given Model and returns a new Model that is custom modified. +type ModelHandler func(m *Model) *Model + // whereHolder is the holder for where condition preparing. type whereHolder struct { operator int // Operator for this holder. @@ -297,3 +300,13 @@ func (m *Model) Args(args ...interface{}) *Model { model.extraArgs = append(model.extraArgs, args) return model } + +// Handler calls each of `handlers` on current Model and returns a new Model. +// ModelHandler is a function that handles given Model and returns a new Model that is custom modified. +func (m *Model) Handler(handlers ...ModelHandler) *Model { + model := m.getModel() + for _, handler := range handlers { + model = handler(model) + } + return model +} diff --git a/database/gdb/gdb_z_mysql_model_test.go b/database/gdb/gdb_z_mysql_model_test.go index dc83d4d95..8121c2d49 100644 --- a/database/gdb/gdb_z_mysql_model_test.go +++ b/database/gdb/gdb_z_mysql_model_test.go @@ -3729,3 +3729,27 @@ func Test_Model_Raw(t *testing.T) { t.Assert(count, 6) }) } + +func Test_Model_Handler(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + m := db.Model(table).Safe().Handler( + func(m *gdb.Model) *gdb.Model { + return m.Page(0, 3) + }, + func(m *gdb.Model) *gdb.Model { + return m.Where("id", g.Slice{1, 2, 3, 4, 5, 6}) + }, + func(m *gdb.Model) *gdb.Model { + return m.OrderDesc("id") + }, + ) + all, err := m.All() + t.AssertNil(err) + t.Assert(len(all), 3) + t.Assert(all[0]["id"], 6) + t.Assert(all[2]["id"], 4) + }) +} diff --git a/os/glog/glog_z_unit_chaining_test.go b/os/glog/glog_z_unit_chaining_test.go index f6110e491..1be9d3a86 100644 --- a/os/glog/glog_z_unit_chaining_test.go +++ b/os/glog/glog_z_unit_chaining_test.go @@ -145,6 +145,8 @@ func Test_StackWithFilter(t *testing.T) { t.Assert(gstr.Count(content, defaultLevelPrefixes[LEVEL_ERRO]), 1) t.Assert(gstr.Count(content, "1 2 3"), 1) t.Assert(gstr.Count(content, "Stack"), 0) + fmt.Println("Content:") + fmt.Println(content) }) }