From 0e44d3aa4494e609cacdcf374dd9e8d01d597a05 Mon Sep 17 00:00:00 2001 From: redhat <2292650292@qq.com> Date: Tue, 10 Jun 2025 19:53:35 +0800 Subject: [PATCH] first commit --- README.md | 0 fsm.go | 284 ++++++++++++++++++++++++++++++++++++++++++++++++++++ fsm_test.go | 95 ++++++++++++++++++ go.mod | 3 + 4 files changed, 382 insertions(+) create mode 100644 README.md create mode 100644 fsm.go create mode 100644 fsm_test.go create mode 100644 go.mod diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/fsm.go b/fsm.go new file mode 100644 index 0000000..59832e0 --- /dev/null +++ b/fsm.go @@ -0,0 +1,284 @@ +package fsm + +import ( + "context" + "errors" + "fmt" + "log" + "sync" + "time" + + "git.zhangshuocauc.cn/redhat/timewheel" +) + +type EventName string + +func (e EventName) Error() string { + return string(e) +} + +const ( + EventEntry EventName = "EventEntry" + EventExit EventName = "EventExit" + EventTimeOut EventName = "EventTimeOut" + + EventOK EventName = "EventOK" + EventNoProc EventName = "EventNoProc" +) + +var internalEvents = Events{ + EventEntry, + EventExit, + EventTimeOut, + EventOK, + EventNoProc, +} + +type StateFunc func(context.Context, *Event) error + +type StateRule struct { + Name string + Parent string + InitState string + Processor StateFunc + Dst []string // Define rules that allow conversion, allowing null values +} + +type Event struct { + FSM *FSM + Event EventName + Src string + Dst string + Args []interface{} +} + +type FSM struct { + eventMu sync.Mutex + log *log.Logger + name string + currentState *StateRule + eventNames Events + fsmStates StateRules + timer *timewheel.TimeWheel // Use a time wheel timer with a default slot count of 10 and a cycle time of 100ms +} + +type StateRules []*StateRule + +func (s StateRules) findByName(name string) *StateRule { + if name == "" { + return nil + } + + for _, sub := range s { + if sub.Name == name { + return sub + } + } + + return nil +} + +func (s StateRules) checkRule() bool { + for _, state := range s { + for _, des := range state.Dst { + if des != "" && s.findByName(des) == nil { + return false + } + } + + if state.Parent != "" && s.findByName(state.Parent) == nil { + return false + } + + if state.InitState != "" && s.findByName(state.InitState) == nil { + return false + } + } + + return true +} + +type Events []EventName + +func (e Events) findByName(name EventName) EventName { + for _, sub := range e { + if sub == name { + return sub + } + } + + return "" +} + +func NewFsm(name string, initState string, events Events, states StateRules, logs ...*log.Logger) (*FSM, error) { + if len(events) == 0 || len(states) == 0 { + return nil, fmt.Errorf("initial state cannot be nil") + } + + // Verify if all conversion rules are correct + if !states.checkRule() { + return nil, fmt.Errorf("state rules error") + } + + fsm := &FSM{ + name: name, + currentState: states.findByName(initState), + eventNames: events, + fsmStates: states, + timer: timewheel.New(10, time.Millisecond*100), + } + + fsm.eventNames = append(fsm.eventNames, internalEvents...) + if len(logs) == 1 { + fsm.log = logs[0] + } + if fsm.log == nil { + fsm.log = log.Default() + } + + // Execute ENTRY event for initial state and its sub-states + for tempState := fsm.currentState; tempState != nil; tempState = fsm.fsmStates.findByName(tempState.InitState) { + fsm.currentState = tempState + if err := fsm.ExecuteEvent(context.TODO(), EventEntry, nil); err != nil { + return nil, err + } + } + + return fsm, nil +} + +func (fsm *FSM) StartEventTimer(ctx context.Context, typeT timewheel.TimeMode, cycle time.Duration, arg interface{}) { + oldState := fsm.currentState.Name + fsm.log.Printf("Start event timer state: %s\n", oldState) + fsm.timer.AddTask(fsm.currentState.Name, typeT, func() { + // Verify whether the current state has changed + // If there is a change, timeout events cannot be executed anymore and the timer for the status needs to be deleted + if fsm.currentState.Name != oldState { + fsm.log.Printf("Event timer but state alrady change (%s -> %s)\n", oldState, fsm.currentState.Name) + fsm.timer.RemoveTask(oldState) + + return + } + fsm.ExecuteEvent(ctx, EventTimeOut, arg) + }, cycle) +} + +func (fsm *FSM) StopEventTimer() { + fsm.timer.RemoveTask(fsm.currentState.Name) +} + +func (fsm *FSM) StopTimer() { + fsm.timer.Stop() +} + +// Because the timer timeout event is triggered by other goroutines, the execution of the event requires a lock +func (fsm *FSM) ExecuteEvent(ctx context.Context, event EventName, arg ...interface{}) error { + fsm.eventMu.Lock() + defer fsm.eventMu.Unlock() + + if fsm.eventNames.findByName(event) == "" { + return fmt.Errorf("new state cannot be nil") + } + + fsm.log.Printf("%s %-23s STATE:%-23s EVENT:%s\n", time.Now().Format(".999"), fsm.name, fsm.currentState.Name, event) + + tmpState := fsm.currentState + if tmpState == nil { + return fmt.Errorf("current state is nil") + } + + if tmpState.Processor != nil { + e := &Event{FSM: fsm, Event: event, Src: tmpState.Name, Dst: tmpState.Name, Args: arg} + result := tmpState.Processor(ctx, e) + + for errors.Is(result, EventNoProc) && fsm.fsmStates.findByName(tmpState.Parent) != nil { + tmpState = fsm.fsmStates.findByName(tmpState.Parent) + result = tmpState.Processor(ctx, e) + } + + if result != nil && !errors.Is(result, EventNoProc) && !errors.Is(result, EventOK) { + return result + } + } + + return nil +} + +func (fsm *FSM) StateChange(ctx context.Context, newState string, arg interface{}) error { + if newState == "" || fsm.fsmStates.findByName(newState) == nil { + return fmt.Errorf("new state cannot be nil") + } + + // If there are defined rules, they need to be checked. If they are empty, it means that any transition between states is allowed + if !fsm.checkDstRule(newState) { + return fmt.Errorf("the new state is not in dst rule") + } + + for tempState := fsm.currentState; tempState != nil; tempState = fsm.fsmStates.findByName(tempState.Parent) { + fsm.currentState = tempState + if err := fsm.ExecuteEvent(ctx, EventExit, arg); err != nil { + return err + } + fsm.StopEventTimer() + + if newState == tempState.Name { + if err := fsm.ExecuteEvent(ctx, EventEntry, arg); err != nil { + return err + } + break + } + + if found, ok := fsm.findState(fsm.fsmStates.findByName(newState), tempState); ok { + for end := len(found) - 1; end >= 0; end-- { + fsm.currentState = found[end] + if err := fsm.ExecuteEvent(ctx, EventEntry, arg); err != nil { + return err + } + } + break + } + } + + // Enter sub-states of the new state + for tempState := fsm.fsmStates.findByName(newState).InitState; fsm.fsmStates.findByName(tempState) != nil; tempState = fsm.fsmStates.findByName(tempState).InitState { + fsm.currentState = fsm.fsmStates.findByName(tempState) + if err := fsm.ExecuteEvent(ctx, EventEntry, arg); err != nil { + return err + } + } + + return nil +} + +func (fsm *FSM) findState(targetState, findState *StateRule) ([]*StateRule, bool) { + if targetState == nil || findState == nil { + return nil, false + } + + var fsmStateList []*StateRule + + for tempState := targetState; tempState != nil; tempState = fsm.fsmStates.findByName(tempState.Parent) { + fsmStateList = append(fsmStateList, tempState) + if findState.Parent == tempState.Parent { + return fsmStateList, true + } + } + + return fsmStateList, false +} + +func (fsm *FSM) checkDstRule(newState string) bool { + curState := fsm.currentState + + if len(curState.Dst) > 0 { + for _, dst := range curState.Dst { + if dst == "" || dst == newState { + return true + } + } + + return false + } + + return true +} diff --git a/fsm_test.go b/fsm_test.go new file mode 100644 index 0000000..2c3e293 --- /dev/null +++ b/fsm_test.go @@ -0,0 +1,95 @@ +package fsm + +import ( + "context" + "fmt" + "log" + "testing" + "time" + + "git.zhangshuocauc.cn/redhat/timewheel" +) + +func state1(ctx context.Context, e *Event) error { + switch e.Event { + case EventEntry: + log.Println("state1", *e) + case EventExit: + log.Println("state1", *e) + default: + log.Println("state1", *e) + return EventNoProc + } + + return EventOK +} + +func state2(ctx context.Context, e *Event) error { + switch e.Event { + case EventEntry: + log.Println("state2", *e) + case EventExit: + log.Println("state2", *e) + default: + log.Println("state2", *e) + return EventNoProc + } + + return EventOK +} + +func state3(ctx context.Context, e *Event) error { + switch e.Event { + case EventEntry: + log.Println("state3", *e) + case EventExit: + log.Println("state3", *e) + case "assign": + log.Println("state3", *e) + default: + log.Println("state3", *e) + return EventNoProc + } + + return EventOK +} + +var events = Events{ + "assign", + "accept", +} + +func Test_fsm(t *testing.T) { + fsm, err := NewFsm("MyFsm", "state1", events, []*StateRule{ + {Name: "state1", Parent: "", InitState: "state2", Processor: state1, Dst: []string{}}, + {Name: "state2", Parent: "state1", InitState: "", Processor: state2, Dst: []string{}}, + {Name: "state3", Parent: "state1", InitState: "", Processor: state3, Dst: []string{}}, + {}, + }) + + if err != nil { + fmt.Println(err) + } + + defer fsm.StopTimer() + + ctx := context.Background() + + fmt.Println(fsm) + + fsm.ExecuteEvent(ctx, EventName("assign"), nil) + + fmt.Println(fsm.StateChange(ctx, "state3", nil)) + + fsm.ExecuteEvent(ctx, EventName("accept"), nil) + + fsm.ExecuteEvent(ctx, EventName("accept1"), nil) + + fsm.StartEventTimer(ctx, timewheel.TimeTypeLoop, time.Millisecond*100, nil) + + <-time.After(time.Millisecond*100) + + fmt.Println(fsm.StateChange(ctx, "state2", nil)) + + <-time.After(3 * time.Second) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c49483e --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.zhangshuocauc.cn/redhat/fsm + +go 1.24.3