commit 37491afb03f3fcc2f10ee7510105fcfa81ef44c4 Author: redhat <2292650292@qq.com> Date: Mon Jun 9 19:54:37 2025 +0800 1、init diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..39b3745 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.zhangshuocauc.cn/redhat/timewheel + +go 1.24.3 diff --git a/timewheel.go b/timewheel.go new file mode 100644 index 0000000..65b669b --- /dev/null +++ b/timewheel.go @@ -0,0 +1,158 @@ +package timewheel + +import ( + "container/list" + "sync" + "time" +) + +type taskSlice struct { + id string + pos int + cycle int + task func() +} + +type TimeWheel struct { + sync.Once + curtSlot int + interval time.Duration + ticker *time.Ticker + slots []*list.List + taskMap map[string]*list.Element + stopC chan struct{} + addTaskC chan *taskSlice + removeTaskC chan string +} + +func New(slots int, interval time.Duration) *TimeWheel { + if slots <= 0 { + slots = 10 + } + + if interval <= 0 { + interval = time.Second + } + + t := &TimeWheel{ + interval: interval, + ticker: time.NewTicker(interval), + slots: make([]*list.List, 0, slots), + taskMap: map[string]*list.Element{}, + stopC: make(chan struct{}), + addTaskC: make(chan *taskSlice), + removeTaskC: make(chan string), + } + + for i := 0; i < slots; i++ { + t.slots = append(t.slots, list.New()) + } + + go t.run() + + return t +} + +func (t *TimeWheel) AddTask(id string, task func(), flower time.Time) { + cycle, pos := t.getSlotPosAndCycle(flower) + + t.addTaskC <- &taskSlice{ + id: id, + task: task, + cycle: cycle, + pos: pos, + } +} + +func (t *TimeWheel) RemoveTask(key string) { + t.removeTaskC <- key +} + +func (t *TimeWheel) Stop() { + t.Do(func() { + t.ticker.Stop() + close(t.stopC) + }) +} + +func (t *TimeWheel) getSlotPosAndCycle(flower time.Time) (int, int) { + delay := time.Until(flower) + cycle := delay / (t.interval * time.Duration(len(t.slots))) + pos := (t.curtSlot + int(delay/t.interval)) % len(t.slots) + + return int(cycle), pos +} + +func (t *TimeWheel) run() { + for { + select { + case <-t.stopC: + return + case <-t.ticker.C: + t.tick() + case task := <-t.addTaskC: + t.addTask(task) + case key := <-t.removeTaskC: + t.removeTask(key) + } + } +} + +func (t *TimeWheel) tick() { + list := t.slots[t.curtSlot] + defer t.updateCurtSlot() + + t.execute(list) +} + +func (t *TimeWheel) updateCurtSlot() { + t.curtSlot = (t.curtSlot + 1) % len(t.slots) +} + +func (t *TimeWheel) execute(l *list.List) { + for e := l.Back(); e != nil; { + event := e.Value.(*taskSlice) + if event.cycle > 0 { + event.cycle-- + e = e.Next() + continue + } + + go func() { + defer func() { + if err := recover(); err != nil { + + } + }() + + event.task() + }() + + next := e.Next() + l.Remove(e) + delete(t.taskMap, event.id) + e = next + } +} + +func (t *TimeWheel) addTask(task *taskSlice) { + if _, ok := t.taskMap[task.id]; ok { + t.removeTask(task.id) + } + + list := t.slots[task.pos] + etask := list.PushBack(task) + t.taskMap[task.id] = etask +} + +func (t *TimeWheel) removeTask(key string) { + etask, ok := t.taskMap[key] + if !ok { + return + } + + delete(t.taskMap, key) + + task := etask.Value.(*taskSlice) + _ = t.slots[task.pos].Remove(etask) +} diff --git a/timewheel_test.go b/timewheel_test.go new file mode 100644 index 0000000..fc43921 --- /dev/null +++ b/timewheel_test.go @@ -0,0 +1,24 @@ +package timewheel + +import ( + "testing" + "time" +) + +func Test_timeWheel(t *testing.T) { + timeWheel := New(10, 500*time.Millisecond) + defer timeWheel.Stop() + + t.Errorf("test2, %v", time.Now()) + timeWheel.AddTask("test1", func() { + t.Errorf("test1, %v", time.Now()) + }, time.Now().Add(time.Second)) + timeWheel.AddTask("test2", func() { + t.Errorf("test2, %v", time.Now()) + }, time.Now().Add(5*time.Second)) + timeWheel.AddTask("test2", func() { + t.Errorf("test2, %v", time.Now()) + }, time.Now().Add(3*time.Second)) + + <-time.After(6 * time.Second) +}