package fsm import ( "context" "errors" "fmt" "log" "sync" "time" "git.zhangshuocauc.cn/redhat/timewheel" ) type EventName string func (e EventName) Error() string { return string(e) } const ( EventEntry EventName = "EventEntry" EventExit EventName = "EventExit" EventTimeOut EventName = "EventTimeOut" EventOK EventName = "EventOK" EventNoProc EventName = "EventNoProc" ) var internalEvents = Events{ EventEntry, EventExit, EventTimeOut, EventOK, EventNoProc, } type StateFunc func(context.Context, *Event) error type StateRule struct { Name string Parent string InitState string Processor StateFunc Dst []string // Define rules that allow conversion, allowing null values } type Event struct { FSM *FSM Event EventName Src string Dst string Args []interface{} } type FSM struct { eventMu sync.Mutex log *log.Logger name string currentState *StateRule eventNames Events fsmStates StateRules timer *timewheel.TimeWheel // Use a time wheel timer with a default slot count of 10 and a cycle time of 100ms } type StateRules []*StateRule func (s StateRules) findByName(name string) *StateRule { if name == "" { return nil } for _, sub := range s { if sub.Name == name { return sub } } return nil } func (s StateRules) checkRule() bool { for _, state := range s { for _, des := range state.Dst { if des != "" && s.findByName(des) == nil { return false } } if state.Parent != "" && s.findByName(state.Parent) == nil { return false } if state.InitState != "" && s.findByName(state.InitState) == nil { return false } } return true } type Events []EventName func (e Events) findByName(name EventName) EventName { for _, sub := range e { if sub == name { return sub } } return "" } func NewFsm(name string, initState string, events Events, states StateRules, logs ...*log.Logger) (*FSM, error) { if len(events) == 0 || len(states) == 0 { return nil, fmt.Errorf("initial state cannot be nil") } // Verify if all conversion rules are correct if !states.checkRule() { return nil, fmt.Errorf("state rules error") } fsm := &FSM{ name: name, currentState: states.findByName(initState), eventNames: events, fsmStates: states, timer: timewheel.New(10, time.Millisecond*100), } fsm.eventNames = append(fsm.eventNames, internalEvents...) if len(logs) == 1 { fsm.log = logs[0] } if fsm.log == nil { fsm.log = log.Default() } // Execute ENTRY event for initial state and its sub-states for tempState := fsm.currentState; tempState != nil; tempState = fsm.fsmStates.findByName(tempState.InitState) { fsm.currentState = tempState if err := fsm.ExecuteEvent(context.TODO(), EventEntry, nil); err != nil { return nil, err } } return fsm, nil } func (fsm *FSM) StartEventTimer(ctx context.Context, typeT timewheel.TimeMode, cycle time.Duration, arg interface{}) { oldState := fsm.currentState.Name fsm.log.Printf("Start event timer state: %s\n", oldState) fsm.timer.AddTask(fsm.currentState.Name, typeT, func() { // Verify whether the current state has changed // If there is a change, timeout events cannot be executed anymore and the timer for the status needs to be deleted if fsm.currentState.Name != oldState { fsm.log.Printf("Event timer but state alrady change (%s -> %s)\n", oldState, fsm.currentState.Name) fsm.timer.RemoveTask(oldState) return } fsm.ExecuteEvent(ctx, EventTimeOut, arg) }, cycle) } func (fsm *FSM) StopEventTimer() { fsm.timer.RemoveTask(fsm.currentState.Name) } func (fsm *FSM) StopTimer() { fsm.timer.Stop() } // Because the timer timeout event is triggered by other goroutines, the execution of the event requires a lock func (fsm *FSM) ExecuteEvent(ctx context.Context, event EventName, arg ...interface{}) error { fsm.eventMu.Lock() defer fsm.eventMu.Unlock() if fsm.eventNames.findByName(event) == "" { return fmt.Errorf("new state cannot be nil") } fsm.log.Printf("%s %-23s STATE:%-23s EVENT:%s\n", time.Now().Format(".999"), fsm.name, fsm.currentState.Name, event) tmpState := fsm.currentState if tmpState == nil { return fmt.Errorf("current state is nil") } if tmpState.Processor != nil { e := &Event{FSM: fsm, Event: event, Src: tmpState.Name, Dst: tmpState.Name, Args: arg} result := tmpState.Processor(ctx, e) for errors.Is(result, EventNoProc) && fsm.fsmStates.findByName(tmpState.Parent) != nil { tmpState = fsm.fsmStates.findByName(tmpState.Parent) result = tmpState.Processor(ctx, e) } if result != nil && !errors.Is(result, EventNoProc) && !errors.Is(result, EventOK) { return result } } return nil } func (fsm *FSM) StateChange(ctx context.Context, newState string, arg interface{}) error { if newState == "" || fsm.fsmStates.findByName(newState) == nil { return fmt.Errorf("new state cannot be nil") } // If there are defined rules, they need to be checked. If they are empty, it means that any transition between states is allowed if !fsm.checkDstRule(newState) { return fmt.Errorf("the new state is not in dst rule") } for tempState := fsm.currentState; tempState != nil; tempState = fsm.fsmStates.findByName(tempState.Parent) { fsm.currentState = tempState if err := fsm.ExecuteEvent(ctx, EventExit, arg); err != nil { return err } fsm.StopEventTimer() if newState == tempState.Name { if err := fsm.ExecuteEvent(ctx, EventEntry, arg); err != nil { return err } break } if found, ok := fsm.findState(fsm.fsmStates.findByName(newState), tempState); ok { for end := len(found) - 1; end >= 0; end-- { fsm.currentState = found[end] if err := fsm.ExecuteEvent(ctx, EventEntry, arg); err != nil { return err } } break } } // Enter sub-states of the new state for tempState := fsm.fsmStates.findByName(newState).InitState; fsm.fsmStates.findByName(tempState) != nil; tempState = fsm.fsmStates.findByName(tempState).InitState { fsm.currentState = fsm.fsmStates.findByName(tempState) if err := fsm.ExecuteEvent(ctx, EventEntry, arg); err != nil { return err } } return nil } func (fsm *FSM) findState(targetState, findState *StateRule) ([]*StateRule, bool) { if targetState == nil || findState == nil { return nil, false } var fsmStateList []*StateRule for tempState := targetState; tempState != nil; tempState = fsm.fsmStates.findByName(tempState.Parent) { fsmStateList = append(fsmStateList, tempState) if findState.Parent == tempState.Parent { return fsmStateList, true } } return fsmStateList, false } func (fsm *FSM) checkDstRule(newState string) bool { curState := fsm.currentState if len(curState.Dst) > 0 { for _, dst := range curState.Dst { if dst == "" || dst == newState { return true } } return false } return true }