// Package timewheel provides a single-level timing wheel that runs callbacks
// at approximately a requested point in time, with a resolution of one tick.
package timewheel

import (
	"container/list"
	"sync"
	"time"
)

// taskSlice is one scheduled entry stored in a wheel slot.
type taskSlice struct {
	// id is the caller-supplied key; it is also the key in TimeWheel.taskMap.
	id string
	// pos is the index of the slot that holds the entry.
	pos int
	// cycle is the number of full wheel revolutions still to wait.
	cycle int
	// task is the callback to run once the entry is due.
	task func()
	// deadline is the requested run time. The run loop converts it into
	// cycle and pos when it receives the entry from addTaskC.
	deadline time.Time
}

// TimeWheel schedules callbacks on a ring of slots that is advanced by a
// ticker. The slots, taskMap and curtSlot are only touched by the run loop
// goroutine; the public methods talk to that goroutine through channels.
type TimeWheel struct {
	// Once makes Stop idempotent.
	sync.Once

	curtSlot int
	interval time.Duration
	ticker   *time.Ticker
	slots    []*list.List
	taskMap  map[string]*list.Element

	stopC       chan struct{}
	addTaskC    chan *taskSlice
	removeTaskC chan string
}

// New creates a TimeWheel with the given number of slots and tick interval
// and starts its run loop in a new goroutine. A non-positive slots value
// defaults to 10 and a non-positive interval defaults to one second.
// Call Stop to release the ticker and end the run loop.
func New(slots int, interval time.Duration) *TimeWheel {
	if slots <= 0 {
		slots = 10
	}
	if interval <= 0 {
		interval = time.Second
	}

	t := &TimeWheel{
		interval:    interval,
		ticker:      time.NewTicker(interval),
		slots:       make([]*list.List, 0, slots),
		taskMap:     map[string]*list.Element{},
		stopC:       make(chan struct{}),
		addTaskC:    make(chan *taskSlice),
		removeTaskC: make(chan string),
	}
	for i := 0; i < slots; i++ {
		t.slots = append(t.slots, list.New())
	}

	go t.run()
	return t
}

// AddTask schedules task to run at approximately flower, registered under id.
// An existing task with the same id is replaced. A flower that is already in
// the past is treated as due on the next tick. If the wheel has been stopped,
// AddTask returns without scheduling instead of blocking.
func (t *TimeWheel) AddTask(id string, task func(), flower time.Time) {
	entry := &taskSlice{id: id, task: task, deadline: flower}
	select {
	case t.addTaskC <- entry:
	case <-t.stopC:
		// Nobody is receiving any more; drop the request.
	}
}

// RemoveTask cancels the task registered under key, if there is one. If the
// wheel has been stopped, RemoveTask returns immediately.
func (t *TimeWheel) RemoveTask(key string) {
	select {
	case t.removeTaskC <- key:
	case <-t.stopC:
	}
}

// Stop stops the ticker and signals the run loop to exit. It is safe to call
// more than once; only the first call has an effect.
func (t *TimeWheel) Stop() {
	t.Do(func() {
		t.ticker.Stop()
		close(t.stopC)
	})
}

// getSlotPosAndCycle converts the absolute time flower into the number of
// full wheel revolutions to wait (first result) and the slot index to store
// the entry in (second result), relative to the current slot. It reads
// curtSlot and therefore must be called from the run loop goroutine.
func (t *TimeWheel) getSlotPosAndCycle(flower time.Time) (int, int) {
	delay := time.Until(flower)
	if delay < 0 {
		// A past deadline would otherwise yield a negative slot index.
		delay = 0
	}

	n := int64(len(t.slots))
	ticks := int64(delay / t.interval)
	cycle := int(ticks / n)
	// Reduce ticks modulo the wheel size before adding curtSlot so that far
	// future deadlines cannot overflow int.
	pos := (t.curtSlot + int(ticks%n)) % len(t.slots)
	return cycle, pos
}

// run is the event loop that owns all mutable wheel state. It returns once
// stopC is closed.
func (t *TimeWheel) run() {
	for {
		select {
		case <-t.stopC:
			return
		case <-t.ticker.C:
			t.tick()
		case task := <-t.addTaskC:
			// Resolve the slot here, on the goroutine that owns curtSlot, so
			// the computation cannot race with tick.
			task.cycle, task.pos = t.getSlotPosAndCycle(task.deadline)
			t.addTask(task)
		case key := <-t.removeTaskC:
			t.removeTask(key)
		}
	}
}

// tick processes the current slot and then advances the wheel by one slot.
func (t *TimeWheel) tick() {
	due := t.slots[t.curtSlot]
	defer t.updateCurtSlot()
	t.execute(due)
}

// updateCurtSlot moves curtSlot to the next slot, wrapping at the end.
func (t *TimeWheel) updateCurtSlot() {
	t.curtSlot = (t.curtSlot + 1) % len(t.slots)
}

// execute walks every entry in l. Entries with remaining cycles have their
// cycle count decremented; entries that are due are started in their own
// goroutine and removed from both the slot and taskMap.
func (t *TimeWheel) execute(l *list.List) {
	for e := l.Front(); e != nil; {
		// Capture the successor first: Remove clears the element's links.
		next := e.Next()

		event := e.Value.(*taskSlice)
		if event.cycle > 0 {
			event.cycle--
			e = next
			continue
		}

		go runTask(event.task)
		l.Remove(e)
		delete(t.taskMap, event.id)
		e = next
	}
}

// runTask invokes task and discards any panic it raises.
func runTask(task func()) {
	defer func() {
		// Deliberate best-effort: a panicking (or nil) task must not take the
		// process down, and this package has nowhere to report the failure.
		_ = recover()
	}()
	task()
}

// addTask stores task in the slot named by task.pos and indexes it by id,
// replacing any entry already registered under the same id.
func (t *TimeWheel) addTask(task *taskSlice) {
	if _, ok := t.taskMap[task.id]; ok {
		t.removeTask(task.id)
	}
	t.taskMap[task.id] = t.slots[task.pos].PushBack(task)
}

// removeTask deletes the entry registered under key from taskMap and from its
// slot. Unknown keys are ignored.
func (t *TimeWheel) removeTask(key string) {
	elem, ok := t.taskMap[key]
	if !ok {
		return
	}
	delete(t.taskMap, key)

	task := elem.Value.(*taskSlice)
	t.slots[task.pos].Remove(elem)
}