package timewheel import ( "container/list" "sync" "time" ) type taskSlice struct { id string pos int cycle int rawCycle int mode TimeMode task func() } type TimeWheel struct { sync.Once curtSlot int slotLen int interval time.Duration ticker *time.Ticker slots []*list.List tinyWheel *list.List taskMap map[string]*list.Element stopC chan struct{} addTaskC chan *taskSlice removeTaskC chan string } type TimeMode int const ( TimerTypeOnce TimeMode = iota TimeTypeLoop ) func New(slots int, interval time.Duration) *TimeWheel { if slots <= 0 { slots = 10 } if interval <= 0 { interval = time.Second } t := &TimeWheel{ interval: interval, ticker: time.NewTicker(interval), slots: make([]*list.List, 0, slots), taskMap: make(map[string]*list.Element), stopC: make(chan struct{}), addTaskC: make(chan *taskSlice), removeTaskC: make(chan string), tinyWheel: list.New(), slotLen: slots, } for i := 0; i < slots; i++ { t.slots = append(t.slots, list.New()) } go t.run() return t } func (t *TimeWheel) AddTask(id string, mode TimeMode, task func(), flower time.Duration) { cycle, pos := t.getSlotPosAndCycle(flower) t.addTaskC <- &taskSlice{ id: id, task: task, cycle: cycle, mode: mode, rawCycle: cycle, pos: pos, } } func (t *TimeWheel) RemoveTask(key string) { t.removeTaskC <- key } func (t *TimeWheel) Stop() { t.Do(func() { t.ticker.Stop() close(t.stopC) }) } func (t *TimeWheel) getSlotPosAndCycle(flower time.Duration) (int, int) { delay := time.Until(time.Now().Add(flower)) cycle := delay / (t.interval * time.Duration(t.slotLen)) pos := (t.curtSlot + int(delay/t.interval)) % t.slotLen return int(cycle), pos } func (t *TimeWheel) run() { for { select { case <-t.stopC: return case <-t.ticker.C: t.tick() case task := <-t.addTaskC: t.addTask(task) case key := <-t.removeTaskC: t.removeTask(key) } } } func (t *TimeWheel) tick() { list := t.slots[t.curtSlot] defer t.updateCurtSlot() t.execute(list) t.execute(t.tinyWheel) } func (t *TimeWheel) updateCurtSlot() { t.curtSlot = (t.curtSlot + 1) % t.slotLen } func (t *TimeWheel) execute(l *list.List) { for e := l.Front(); e != nil; { event := e.Value.(*taskSlice) if event.cycle > 0 { event.cycle-- e = e.Next() continue } go func() { defer func() { if err := recover(); err != nil { } }() event.task() }() if event.mode == TimerTypeOnce { next := e.Next() l.Remove(e) delete(t.taskMap, event.id) e = next } else { e = e.Next() event.cycle = event.rawCycle } } } func (t *TimeWheel) addTask(task *taskSlice) { if _, ok := t.taskMap[task.id]; ok { t.removeTask(task.id) } var list *list.List if task.mode == TimeTypeLoop { task.cycle = task.cycle*t.slotLen + (task.pos-t.curtSlot+t.slotLen)%t.slotLen task.rawCycle = task.cycle list = t.tinyWheel } else { list = t.slots[task.pos] } etask := list.PushBack(task) t.taskMap[task.id] = etask } func (t *TimeWheel) removeTask(key string) { etask, ok := t.taskMap[key] if !ok { return } delete(t.taskMap, key) task := etask.Value.(*taskSlice) if task.mode == TimeTypeLoop { _ = t.tinyWheel.Remove(etask) } else { _ = t.slots[task.pos].Remove(etask) } }