This commit is contained in:
redhat 2025-06-09 19:54:37 +08:00
commit 37491afb03
3 changed files with 185 additions and 0 deletions

3
go.mod Normal file
View File

@ -0,0 +1,3 @@
module git.zhangshuocauc.cn/redhat/timewheel
go 1.24.3

158
timewheel.go Normal file
View File

@ -0,0 +1,158 @@
package timewheel
import (
"container/list"
"sync"
"time"
)
type taskSlice struct {
id string
pos int
cycle int
task func()
}
type TimeWheel struct {
sync.Once
curtSlot int
interval time.Duration
ticker *time.Ticker
slots []*list.List
taskMap map[string]*list.Element
stopC chan struct{}
addTaskC chan *taskSlice
removeTaskC chan string
}
func New(slots int, interval time.Duration) *TimeWheel {
if slots <= 0 {
slots = 10
}
if interval <= 0 {
interval = time.Second
}
t := &TimeWheel{
interval: interval,
ticker: time.NewTicker(interval),
slots: make([]*list.List, 0, slots),
taskMap: map[string]*list.Element{},
stopC: make(chan struct{}),
addTaskC: make(chan *taskSlice),
removeTaskC: make(chan string),
}
for i := 0; i < slots; i++ {
t.slots = append(t.slots, list.New())
}
go t.run()
return t
}
func (t *TimeWheel) AddTask(id string, task func(), flower time.Time) {
cycle, pos := t.getSlotPosAndCycle(flower)
t.addTaskC <- &taskSlice{
id: id,
task: task,
cycle: cycle,
pos: pos,
}
}
func (t *TimeWheel) RemoveTask(key string) {
t.removeTaskC <- key
}
func (t *TimeWheel) Stop() {
t.Do(func() {
t.ticker.Stop()
close(t.stopC)
})
}
func (t *TimeWheel) getSlotPosAndCycle(flower time.Time) (int, int) {
delay := time.Until(flower)
cycle := delay / (t.interval * time.Duration(len(t.slots)))
pos := (t.curtSlot + int(delay/t.interval)) % len(t.slots)
return int(cycle), pos
}
func (t *TimeWheel) run() {
for {
select {
case <-t.stopC:
return
case <-t.ticker.C:
t.tick()
case task := <-t.addTaskC:
t.addTask(task)
case key := <-t.removeTaskC:
t.removeTask(key)
}
}
}
func (t *TimeWheel) tick() {
list := t.slots[t.curtSlot]
defer t.updateCurtSlot()
t.execute(list)
}
func (t *TimeWheel) updateCurtSlot() {
t.curtSlot = (t.curtSlot + 1) % len(t.slots)
}
func (t *TimeWheel) execute(l *list.List) {
for e := l.Back(); e != nil; {
event := e.Value.(*taskSlice)
if event.cycle > 0 {
event.cycle--
e = e.Next()
continue
}
go func() {
defer func() {
if err := recover(); err != nil {
}
}()
event.task()
}()
next := e.Next()
l.Remove(e)
delete(t.taskMap, event.id)
e = next
}
}
func (t *TimeWheel) addTask(task *taskSlice) {
if _, ok := t.taskMap[task.id]; ok {
t.removeTask(task.id)
}
list := t.slots[task.pos]
etask := list.PushBack(task)
t.taskMap[task.id] = etask
}
func (t *TimeWheel) removeTask(key string) {
etask, ok := t.taskMap[key]
if !ok {
return
}
delete(t.taskMap, key)
task := etask.Value.(*taskSlice)
_ = t.slots[task.pos].Remove(etask)
}

24
timewheel_test.go Normal file
View File

@ -0,0 +1,24 @@
package timewheel
import (
"testing"
"time"
)
func Test_timeWheel(t *testing.T) {
timeWheel := New(10, 500*time.Millisecond)
defer timeWheel.Stop()
t.Errorf("test2, %v", time.Now())
timeWheel.AddTask("test1", func() {
t.Errorf("test1, %v", time.Now())
}, time.Now().Add(time.Second))
timeWheel.AddTask("test2", func() {
t.Errorf("test2, %v", time.Now())
}, time.Now().Add(5*time.Second))
timeWheel.AddTask("test2", func() {
t.Errorf("test2, %v", time.Now())
}, time.Now().Add(3*time.Second))
<-time.After(6 * time.Second)
}