commit 2891687b72a204f39fe4895b6e7e03f3f034dcf1
Author: redhat <2292650292@qq.com>
Date: Wed May 7 10:34:10 2025 +0800
first commit
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..35410ca
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# 默认忽略的文件
+/shelf/
+/workspace.xml
+# 基于编辑器的 HTTP 客户端请求
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/consistencehash.iml b/.idea/consistencehash.iml
new file mode 100644
index 0000000..5e764c4
--- /dev/null
+++ b/.idea/consistencehash.iml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..d9ca9e7
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..d843f34
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/consistencehash.go b/consistencehash.go
new file mode 100644
index 0000000..79d4892
--- /dev/null
+++ b/consistencehash.go
@@ -0,0 +1,261 @@
+package consistencehash
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "sync"
+)
+
+type ConsistenceHash struct {
+ hashRing HashRing
+ encrytor Encrytor
+ migrator Migrator
+ opts Options
+}
+
+func NewConsistenceHash(hashRing HashRing, encrytor Encrytor, migrator Migrator, opts ...Option) *ConsistenceHash {
+ chash := &ConsistenceHash{
+ hashRing: hashRing,
+ encrytor: encrytor,
+ migrator: migrator,
+ }
+
+ for _, o := range opts {
+ o(&chash.opts)
+ }
+
+ chash.opts.repire()
+
+ return chash
+}
+
+func (c *ConsistenceHash) getVirNodeId(id string, index int) string {
+ return fmt.Sprintf("%s_%d", id, index)
+}
+
+func (c *ConsistenceHash) getRawNodeId(id string) string {
+ index := strings.LastIndex(id, "_")
+ return id[:index]
+}
+
+func (c *ConsistenceHash) migratorHandle(funs []func()) {
+ wg := sync.WaitGroup{}
+ for _, fn := range funs {
+ wg.Add(1)
+
+ go func(fn func()) {
+ defer func() {
+ if err := recover(); err != nil {
+ fmt.Println("recover")
+ }
+ wg.Done()
+ }()
+ fn()
+ }(fn)
+ }
+ wg.Wait()
+}
+
+func (c *ConsistenceHash) wight(wight int) int {
+ if wight < 1 {
+ return 1
+ }
+
+ if wight > 10 {
+ return 10
+ }
+
+ return wight
+}
+
+func (c *ConsistenceHash) AddNode(ctx context.Context, nodeId string, wight int) error {
+ if err := c.hashRing.Lock(ctx, c.opts.lockExpire); err != nil {
+ return err
+ }
+ defer func() {
+ _ = c.hashRing.UnLock(ctx)
+ }()
+
+ nodes, err := c.hashRing.Nodes(ctx)
+ if err != nil {
+ return err
+ }
+
+ for node := range nodes {
+ if node == nodeId {
+ return errors.New("node already exists")
+ }
+ }
+
+ count := c.wight(wight) * c.opts.replicas
+
+ if err := c.hashRing.AddNodes(ctx, nodeId, count); err != nil {
+ return err
+ }
+
+ var migratorFns []func()
+ for i := 0; i < count; i++ {
+ vNodeId := c.getVirNodeId(nodeId, i)
+ sum := c.encrytor.Hash(vNodeId)
+
+ if err := c.hashRing.Add(ctx, sum, vNodeId); err != nil {
+ return err
+ }
+
+ datas, from, to, err := c.migratorIn(ctx, sum, vNodeId)
+ if err != nil {
+ return err
+ }
+
+ if len(datas) == 0 || from == "" {
+ continue
+ }
+
+ migratorFns = append(migratorFns, func() {
+ _ = c.migrator(ctx, datas, from, to)
+ })
+ }
+
+ c.migratorHandle(migratorFns)
+
+ return nil
+}
+
+func (c *ConsistenceHash) getAllDatas(ctx context.Context, nodeId string, counts int) (map[string]struct{}, error) {
+ var allDatas map[string]struct{}
+
+ for i := 0; i < counts; i++ {
+ vNodeId := c.getVirNodeId(nodeId, i)
+ _datas, err := c.hashRing.Datas(ctx, vNodeId)
+ if err != nil {
+ return allDatas, err
+ }
+
+ if allDatas == nil {
+ allDatas = make(map[string]struct{})
+ }
+
+ for key := range _datas {
+ allDatas[key] = struct{}{}
+ }
+ }
+
+ return allDatas, nil
+}
+
+func (c *ConsistenceHash) DelNode(ctx context.Context, nodeId string) error {
+ if err := c.hashRing.Lock(ctx, c.opts.lockExpire); err != nil {
+ return err
+ }
+ defer func() {
+ _ = c.hashRing.UnLock(ctx)
+ }()
+
+ nodes, err := c.hashRing.Nodes(ctx)
+ if err != nil {
+ return err
+ }
+
+ var (
+ exists bool
+ counts int
+ )
+
+ for node, nums := range nodes {
+ if node == nodeId {
+ exists = true
+ counts = nums
+
+ break
+ }
+ }
+
+ if !exists {
+ return errors.New("node id is not exists")
+ }
+
+ // _datas, err := c.hashRing.Datas(ctx, nodeId)
+ _datas, err := c.getAllDatas(ctx, nodeId, counts)
+ if err != nil {
+ return err
+ }
+
+ if len(_datas) == 0 {
+ return nil
+ }
+
+ var migratorFns []func()
+ for i := 0; i < counts; i++ {
+ vNodeId := c.getVirNodeId(nodeId, i)
+ sum := c.encrytor.Hash(vNodeId)
+
+ datas, from, to, err := c.migratorOut(ctx, sum, vNodeId, &_datas)
+ if err != nil {
+ return err
+ }
+
+ if err := c.hashRing.Del(ctx, sum, vNodeId); err != nil {
+ return err
+ }
+
+ if len(datas) == 0 || to == "" {
+ continue
+ }
+
+ migratorFns = append(migratorFns, func() {
+ _ = c.migrator(ctx, datas, from, to)
+ })
+ }
+
+ if len(_datas) != 0 {
+ fmt.Println(_datas)
+ return errors.New("internal error")
+ }
+
+ if err := c.hashRing.DelNodes(ctx, nodeId); err != nil {
+ return err
+ }
+
+ c.migratorHandle(migratorFns)
+
+ return nil
+}
+
+func (c *ConsistenceHash) GetNode(ctx context.Context, key string) (string, error) {
+ if err := c.hashRing.Lock(ctx, c.opts.lockExpire); err != nil {
+ return "", err
+ }
+ defer func() {
+ _ = c.hashRing.UnLock(ctx)
+ }()
+
+ keysum := c.encrytor.Hash(key)
+
+ nextScore, err := c.hashRing.Ceiling(ctx, keysum)
+ if err != nil {
+ return "", err
+ }
+
+ if nextScore == -1 {
+ return "", errors.New("no avalible node")
+ }
+
+ nextNode, err := c.hashRing.Node(ctx, nextScore)
+ if err != nil {
+ return "", err
+ }
+
+ if len(nextNode) == 0 {
+ return "", errors.New("node id error")
+ }
+
+ if err := c.hashRing.AddDatas(ctx, nextNode[0], map[string]struct{}{
+ key: {},
+ }); err != nil {
+ return "", err
+ }
+
+ return c.getRawNodeId(nextNode[0]), nil
+}
diff --git a/consistencehash_test.go b/consistencehash_test.go
new file mode 100644
index 0000000..03c2c40
--- /dev/null
+++ b/consistencehash_test.go
@@ -0,0 +1,149 @@
+package consistencehash
+
+import (
+ "context"
+ "fmt"
+ "git.zhangshuocauc.cn/redhat/goconsistencehash/local"
+ "git.zhangshuocauc.cn/redhat/goconsistencehash/redis/hashring"
+ "github.com/redis/go-redis/v9"
+ "hash/fnv"
+ "testing"
+)
+
+func migrator(ctx context.Context, datas map[string]struct{}, from, to string) error {
+ fmt.Printf("from: %s, to: %s, datas: %v", from, to, datas)
+ return nil
+}
+
+func TestConHash_Local(t *testing.T) {
+ hashRing := local.NewHashRing(&local.BlockLock{})
+ encrytor := NewEncrytorImp(fnv.New32())
+
+ hash := NewConsistenceHash(
+ hashRing,
+ encrytor,
+ migrator,
+ WithReplicas(10),
+ WithLockExpire(10),
+ )
+
+ test(t, hash)
+}
+
+func TestConHash_Redis(t *testing.T) {
+ client := redis.NewClient(&redis.Options{
+ Addr: "192.168.8.1:6379",
+ Password: "",
+ })
+
+ lockwap := hashring.NewRedisLockWrap(client, "redhat_redis_lock")
+ hashRing := hashring.NewRedisHashRing(client, "redhat_redis_hashring", lockwap)
+ encrytor := NewEncrytorImp(fnv.New32())
+
+ hash := NewConsistenceHash(
+ hashRing,
+ encrytor,
+ migrator,
+ WithReplicas(10),
+ WithLockExpire(10),
+ )
+
+ test(t, hash)
+}
+
+func test(t *testing.T, consistentHash *ConsistenceHash) {
+ ctx := context.Background()
+ nodeA := "node_a"
+ weightNodeA := 2
+ nodeB := "node_b"
+ weightNodeB := 1
+ nodeC := "node_c"
+ weightNodeC := 1
+
+ if err := consistentHash.AddNode(ctx, nodeA, weightNodeA); err != nil {
+ t.Error(err)
+ return
+ }
+
+ if err := consistentHash.AddNode(ctx, nodeB, weightNodeB); err != nil {
+ t.Error(err)
+ return
+ }
+
+ dataKeyA := "data_a"
+ dataKeyB := "data_b"
+ dataKeyC := "data_c"
+ dataKeyD := "data_d"
+ node, err := consistentHash.GetNode(ctx, dataKeyA)
+ fmt.Println(node)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ t.Logf("data: %s belongs to node: %s", dataKeyA, node)
+
+ if node, err = consistentHash.GetNode(ctx, dataKeyB); err != nil {
+ t.Error(err)
+ return
+ }
+ t.Logf("data: %s belongs to node: %s", dataKeyB, node)
+ if node, err = consistentHash.GetNode(ctx, dataKeyC); err != nil {
+ t.Error(err)
+ return
+ }
+ t.Logf("data: %s belongs to node: %s", dataKeyC, node)
+ if node, err = consistentHash.GetNode(ctx, dataKeyD); err != nil {
+ t.Error(err)
+ return
+ }
+ t.Logf("data: %s belongs to node: %s", dataKeyD, node)
+ if err := consistentHash.AddNode(ctx, nodeC, weightNodeC); err != nil {
+ t.Error(err)
+ return
+ }
+ if node, err = consistentHash.GetNode(ctx, dataKeyA); err != nil {
+ t.Error(err)
+ return
+ }
+ t.Logf("data: %s belongs to node: %s", dataKeyA, node)
+ if node, err = consistentHash.GetNode(ctx, dataKeyB); err != nil {
+ t.Error(err)
+ return
+ }
+ t.Logf("data: %s belongs to node: %s", dataKeyB, node)
+ if node, err = consistentHash.GetNode(ctx, dataKeyC); err != nil {
+ t.Error(err)
+ return
+ }
+ t.Logf("data: %s belongs to node: %s", dataKeyC, node)
+ if node, err = consistentHash.GetNode(ctx, dataKeyD); err != nil {
+ t.Error(err)
+ return
+ }
+ t.Logf("data: %s belongs to node: %s", dataKeyD, node)
+ if err = consistentHash.DelNode(ctx, nodeC); err != nil {
+ t.Error(err)
+ return
+ }
+ if node, err = consistentHash.GetNode(ctx, dataKeyA); err != nil {
+ t.Error(err)
+ return
+ }
+ t.Logf("data: %s belongs to node: %s", dataKeyA, node)
+ if node, err = consistentHash.GetNode(ctx, dataKeyB); err != nil {
+ t.Error(err)
+ return
+ }
+ t.Logf("data: %s belongs to node: %s", dataKeyB, node)
+ if node, err = consistentHash.GetNode(ctx, dataKeyC); err != nil {
+ t.Error(err)
+ return
+ }
+ t.Logf("data: %s belongs to node: %s", dataKeyC, node)
+ if node, err = consistentHash.GetNode(ctx, dataKeyD); err != nil {
+ t.Error(err)
+ return
+ }
+ t.Logf("data: %s belongs to node: %s", dataKeyD, node)
+ t.Error("ok")
+}
diff --git a/encrytor.go b/encrytor.go
new file mode 100644
index 0000000..83f45d8
--- /dev/null
+++ b/encrytor.go
@@ -0,0 +1,32 @@
+package consistencehash
+
+import (
+ "hash"
+ "math"
+
+ "github.com/spaolacci/murmur3"
+)
+
+type Encrytor interface {
+ Hash(string) int32
+}
+
+type EncrytorImp struct {
+ hash hash.Hash32
+}
+
+func NewEncrytorImp(hash hash.Hash32) *EncrytorImp {
+ return &EncrytorImp{
+ hash: hash,
+ }
+}
+
+func (e *EncrytorImp) Hash(id string) int32 {
+ e.hash = murmur3.New32()
+ // e.hash.Reset()
+ e.hash.Write([]byte(id))
+
+ return int32(e.hash.Sum32() % math.MaxInt32)
+}
+
+var _ Encrytor = &EncrytorImp{}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..18364b6
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,13 @@
+module git.zhangshuocauc.cn/redhat/goconsistencehash
+
+go 1.24.2
+
+require (
+ github.com/redis/go-redis/v9 v9.8.0
+ github.com/spaolacci/murmur3 v1.1.0
+)
+
+require (
+ github.com/cespare/xxhash/v2 v2.3.0 // indirect
+ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..db0b4ca
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,12 @@
+github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
+github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
+github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
+github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
+github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
+github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
+github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI=
+github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
+github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
+github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
diff --git a/hashring.go b/hashring.go
new file mode 100644
index 0000000..62eb7c5
--- /dev/null
+++ b/hashring.go
@@ -0,0 +1,21 @@
+package consistencehash
+
+import "context"
+
+type HashRing interface {
+ Lock(context.Context, int) error
+ UnLock(context.Context) error
+ Add(context.Context, int32, string) error
+ Del(context.Context, int32, string) error
+ Floor(context.Context, int32) (int32, error)
+ Ceiling(context.Context, int32) (int32, error)
+ Node(context.Context, int32) ([]string, error)
+
+ Nodes(context.Context) (map[string]int, error)
+ AddNodes(context.Context, string, int) error
+ DelNodes(context.Context, string) error
+
+ Datas(context.Context, string) (map[string]struct{}, error)
+ AddDatas(context.Context, string, map[string]struct{}) error
+ DelDatas(context.Context, string, map[string]struct{}) error
+}
diff --git a/local/blocklock.go b/local/blocklock.go
new file mode 100644
index 0000000..9f46b0c
--- /dev/null
+++ b/local/blocklock.go
@@ -0,0 +1,64 @@
+package local
+
+import (
+ "context"
+ "errors"
+ "git.zhangshuocauc.cn/redhat/goconsistencehash/local/os"
+ "sync"
+ "time"
+)
+
+type BlockLock struct {
+ mu sync.Mutex
+ dmu sync.Mutex
+ token string
+ calFn context.CancelFunc
+}
+
+func (b *BlockLock) Lock(ctx context.Context, expire int) error {
+ b.mu.Lock()
+
+ b.dmu.Lock()
+ defer b.dmu.Unlock()
+
+ token := os.GetCurrentProcessAndGoroutineIDStr()
+ b.token = token
+ b.calFn = nil
+
+ var sctx context.Context
+ sctx, b.calFn = context.WithCancel(ctx)
+ go func() {
+ select {
+ case <-sctx.Done():
+ return
+ case <-time.After(time.Duration(expire) * time.Second):
+ _ = b.unLock(token)
+ }
+ }()
+
+ return nil
+}
+
+func (b *BlockLock) unLock(id string) error {
+ b.dmu.Lock()
+ defer b.dmu.Unlock()
+
+ if b.token != id {
+ return errors.New("not my lock")
+ }
+
+ if b.calFn != nil {
+ b.calFn()
+ }
+
+ b.token = ""
+
+ b.mu.Unlock()
+
+ return nil
+}
+
+func (b *BlockLock) UnLock(context.Context) error {
+ id := os.GetCurrentProcessAndGoroutineIDStr()
+ return b.unLock(id)
+}
diff --git a/local/hashring.go b/local/hashring.go
new file mode 100644
index 0000000..dad1b7c
--- /dev/null
+++ b/local/hashring.go
@@ -0,0 +1,298 @@
+package local
+
+import (
+ "context"
+ "errors"
+ "math/rand"
+ "time"
+)
+
+type HashRing struct {
+ Locker
+ head *node
+
+ nodes map[string]int
+ datas map[string]map[string]struct{}
+
+ rand *rand.Rand
+}
+
+type node struct {
+ nexts []*node
+ score int32
+ node []string
+}
+
+func NewHashRing(locker Locker) *HashRing {
+ return &HashRing{
+ Locker: locker,
+ head: &node{},
+ nodes: make(map[string]int),
+ datas: make(map[string]map[string]struct{}),
+ rand: rand.New(rand.NewSource(time.Now().UnixNano())),
+ }
+}
+
+func (h *HashRing) search(score int32) *node {
+ move := h.head
+ for level := len(h.head.nexts) - 1; level >= 0; level-- {
+ for move.nexts[level] != nil && move.nexts[level].score < score {
+ move = move.nexts[level]
+ }
+
+ if move.nexts[level] != nil && move.nexts[level].score == score {
+ return move.nexts[level]
+ }
+ }
+
+ return nil
+}
+
+func (h *HashRing) floor(score int32) (int32, bool) {
+ if len(h.head.nexts) == 0 {
+ return -1, false
+ }
+
+ move := h.head
+ for level := len(move.nexts) - 1; level >= 0; level-- {
+ for move.nexts[level] != nil && move.nexts[level].score < score {
+ move = move.nexts[level]
+ }
+ }
+
+ if move.nexts[0] != nil && move.nexts[0].score == score {
+ return move.nexts[0].score, true
+ }
+
+ if move == h.head {
+ return -1, false
+ }
+
+ return move.score, true
+}
+
+func (h *HashRing) ceiling(score int32) (int32, bool) {
+ if len(h.head.nexts) == 0 {
+ return -1, false
+ }
+
+ move := h.head
+ for level := len(move.nexts) - 1; level >= 0; level-- {
+ for move.nexts[level] != nil && move.nexts[level].score < score {
+ move = move.nexts[level]
+ }
+ }
+
+ if move.nexts[0] == nil {
+ return -1, false
+ }
+
+ return move.nexts[0].score, true
+}
+
+func (h *HashRing) first() (int32, bool) {
+ if len(h.head.nexts) == 0 {
+ return -1, false
+ }
+
+ return h.head.nexts[0].score, true
+}
+
+func (h *HashRing) last() (int32, bool) {
+ if len(h.head.nexts) == 0 {
+ return -1, false
+ }
+
+ move := h.head
+ for level := len(move.nexts) - 1; level >= 0; level-- {
+ for move.nexts[level] != nil {
+ move = move.nexts[level]
+ }
+ }
+
+ if move == h.head {
+ return -1, false
+ }
+
+ return move.score, true
+}
+
+func (h *HashRing) roll() int {
+ level := 0
+ for h.rand.Intn(2) > 0 {
+ level++
+ }
+
+ return level
+}
+
+func (h *HashRing) Add(ctx context.Context, score int32, nodeId string) error {
+ nodes := h.search(score)
+ if nodes != nil {
+ for _, _nodeId := range nodes.node {
+ if _nodeId == nodeId {
+ return errors.New("alrady inset")
+ }
+ }
+
+ nodes.node = append(nodes.node, nodeId)
+ return nil
+ }
+
+ roll := h.roll()
+ if len(h.head.nexts)-1 < roll {
+ diff := make([]*node, roll+1-len(h.head.nexts))
+ h.head.nexts = append(h.head.nexts, diff...)
+ }
+
+ newNode := &node{
+ score: score,
+ node: []string{nodeId},
+ nexts: make([]*node, roll+1),
+ }
+
+ move := h.head
+ for level := len(move.nexts) - 1; level >= 0; level-- {
+ for move.nexts[level] != nil && move.nexts[level].score < score {
+ move = move.nexts[level]
+ }
+
+ if level > roll {
+ continue
+ }
+
+ newNode.nexts[level] = move.nexts[level]
+ move.nexts[level] = newNode
+ }
+
+ return nil
+}
+
+func (h *HashRing) Del(ctx context.Context, score int32, nodeId string) error {
+ nodes := h.search(score)
+ if nodes == nil {
+ return errors.New("no such score node")
+ }
+
+ index := -1
+ for idx, key := range nodes.node {
+ if key == nodeId {
+ index = idx
+ break
+ }
+ }
+
+ if index == -1 {
+ return errors.New("nodes not exists")
+ }
+
+ if len(nodes.node) > 1 {
+ nodes.node = append(nodes.node[:index], nodes.node[index+1:]...)
+ return nil
+ }
+
+ move := h.head
+ for level := len(move.nexts) - 1; level >= 0; level-- {
+ for move.nexts[level] != nil && move.nexts[level].score < score {
+ move = move.nexts[level]
+ }
+
+ if move.nexts[level] == nil || move.nexts[level].score > score {
+ continue
+ }
+
+ move.nexts[level] = move.nexts[level].nexts[level]
+ }
+
+ diff := 0
+ for level := len(h.head.nexts) - 1; level >= 0 && h.head.nexts[level] == nil; level-- {
+ diff++
+ }
+
+ if diff > 0 {
+ h.head.nexts = h.head.nexts[:len(h.head.nexts)-diff]
+ }
+
+ return nil
+}
+
+func (h *HashRing) Floor(ctx context.Context, score int32) (int32, error) {
+ if floorScore, ok := h.floor(score); ok {
+ return floorScore, nil
+ }
+
+ lastScore, _ := h.last()
+
+ return lastScore, nil
+}
+
+func (h *HashRing) Ceiling(ctx context.Context, score int32) (int32, error) {
+ if ceillingScore, ok := h.ceiling(score); ok {
+ return ceillingScore, nil
+ }
+
+ firstScore, _ := h.first()
+
+ return firstScore, nil
+}
+
+func (h *HashRing) Node(ctx context.Context, score int32) ([]string, error) {
+ node := h.search(score)
+ if node == nil {
+ return []string{}, errors.New("no such score node")
+ }
+
+ return node.node, nil
+}
+
+func (h *HashRing) Nodes(context.Context) (map[string]int, error) {
+ return h.nodes, nil
+}
+
+func (h *HashRing) AddNodes(ctx context.Context, nodeId string, num int) error {
+ h.nodes[nodeId] = num
+
+ return nil
+}
+
+func (h *HashRing) DelNodes(ctx context.Context, nodeId string) error {
+ delete(h.nodes, nodeId)
+
+ return nil
+}
+
+func (h *HashRing) Datas(ctx context.Context, nodeId string) (map[string]struct{}, error) {
+ return h.datas[nodeId], nil
+}
+
+func (h *HashRing) AddDatas(ctx context.Context, nodeId string, datas map[string]struct{}) error {
+ datasMap := h.datas[nodeId]
+ if datasMap == nil {
+ datasMap = make(map[string]struct{})
+ }
+
+ for keys := range datas {
+ datasMap[keys] = struct{}{}
+ }
+
+ h.datas[nodeId] = datasMap
+
+ return nil
+}
+
+func (h *HashRing) DelDatas(ctx context.Context, nodeId string, datas map[string]struct{}) error {
+ datasMap := h.datas[nodeId]
+ if datasMap == nil {
+ return nil
+ }
+
+ for keys := range datas {
+ delete(datasMap, keys)
+ }
+
+ if len(datasMap) == 0 {
+ delete(h.datas, nodeId)
+ }
+
+ return nil
+}
diff --git a/local/locker.go b/local/locker.go
new file mode 100644
index 0000000..68c0fc2
--- /dev/null
+++ b/local/locker.go
@@ -0,0 +1,8 @@
+package local
+
+import "context"
+
+type Locker interface{
+ Lock(context.Context,int)error
+ UnLock(context.Context)error
+}
\ No newline at end of file
diff --git a/local/os/os.go b/local/os/os.go
new file mode 100644
index 0000000..58dd80f
--- /dev/null
+++ b/local/os/os.go
@@ -0,0 +1,27 @@
+package os
+
+import (
+ "fmt"
+ "os"
+ "runtime"
+ "strings"
+)
+
+func GetCurrentProcessId() int {
+ return os.Getpid()
+}
+
+func GetCurrentGoroutineId() string {
+ buf := make([]byte, 128)
+ buf = buf[:runtime.Stack(buf, false)]
+ stackInfo := string(buf)
+
+ return strings.TrimSpace(strings.Split(strings.Split(stackInfo, "[running]")[0], "goroutine")[1])
+}
+
+func GetCurrentProcessAndGoroutineIDStr() string {
+ pid := GetCurrentProcessId()
+ goroutineID := GetCurrentGoroutineId()
+
+ return fmt.Sprintf("%d_%s", pid, goroutineID)
+}
diff --git a/local/spinlock.go b/local/spinlock.go
new file mode 100644
index 0000000..fd2e2a0
--- /dev/null
+++ b/local/spinlock.go
@@ -0,0 +1,72 @@
+package local
+
+import (
+ "context"
+ "errors"
+ "git.zhangshuocauc.cn/redhat/goconsistencehash/local/os"
+ "sync"
+ "time"
+)
+
+type SpinLock struct {
+ mu sync.Mutex
+ token string
+ locked bool
+ expire time.Time
+}
+
+func NewSpinLock() *SpinLock {
+ return &SpinLock{}
+}
+
+func (s *SpinLock) tryLock(expire int) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ now := time.Now()
+ if !s.locked || now.After(s.expire) {
+ s.locked = true
+ s.token = os.GetCurrentProcessAndGoroutineIDStr()
+ s.expire = now.Add(time.Duration(expire) * time.Second)
+ return true
+ }
+
+ return false
+}
+
+func (s *SpinLock) Lock(ctx context.Context, expire int) error {
+ ticker := time.NewTicker(100 * time.Millisecond)
+ defer ticker.Stop()
+
+ for range ticker.C {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
+ if s.tryLock(expire) {
+ return nil
+ }
+ }
+
+ return errors.New("try lock error")
+}
+
+func (s *SpinLock) UnLock(ctx context.Context) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.token != os.GetCurrentProcessAndGoroutineIDStr() {
+ return errors.New("not my lock")
+ }
+
+ if !s.locked || time.Now().After(s.expire) {
+ return errors.New("try unlock an unlocked lock")
+ }
+
+ s.locked = false
+ s.token = ""
+
+ return nil
+}
diff --git a/migrator.go b/migrator.go
new file mode 100644
index 0000000..d728b7f
--- /dev/null
+++ b/migrator.go
@@ -0,0 +1,278 @@
+package consistencehash
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math"
+)
+
+type Migrator func(ctx context.Context, datas map[string]struct{}, from, to string) error
+
+func (c *ConsistenceHash) decrScore(score int32) int32 {
+ if score == 0 {
+ return (math.MaxInt32 - 1)
+ }
+
+ return (score - 1)
+}
+
+func (c *ConsistenceHash) incrScore(score int32) int32 {
+ if score == (math.MaxInt32 - 1) {
+ return 0
+ }
+
+ return (score + 1)
+}
+
+func (c *ConsistenceHash) migratorIn(ctx context.Context, score int32, nodeId string) (datas map[string]struct{}, from, to string, _err error) {
+ if c.migrator == nil {
+ return
+ }
+
+ nodeIds, err := c.hashRing.Node(ctx, score)
+ if err != nil {
+ _err = err
+ return
+ }
+
+ if len(nodeIds) > 1 {
+ return
+ }
+
+ lastScore, err := c.hashRing.Floor(ctx, c.decrScore(score))
+ if err != nil {
+ _err = err
+ return
+ }
+
+ if lastScore == -1 || lastScore == score {
+ return
+ }
+
+ nextScore, err := c.hashRing.Ceiling(ctx, c.incrScore(score))
+ if err != nil {
+ _err = err
+ }
+
+ if nextScore == -1 || nextScore == score {
+ return
+ }
+
+ patternOne := lastScore > score
+
+ if patternOne {
+ lastScore -= math.MaxInt32
+ }
+
+ nextNode, err := c.hashRing.Node(ctx, nextScore)
+ if err != nil {
+ _err = err
+ return
+ }
+
+ if len(nextNode) == 0 {
+ return
+ }
+
+ if c.getRawNodeId(nextNode[0]) == c.getRawNodeId(nodeId) {
+ return
+ }
+
+ _datas, err := c.hashRing.Datas(ctx, nextNode[0])
+ if err != nil {
+ _err = err
+ return
+ }
+
+ if len(_datas) == 0 {
+ return
+ }
+
+ for key := range _datas {
+ sum := c.encrytor.Hash(key)
+ if patternOne && (sum > (lastScore + math.MaxInt32)) {
+ sum -= math.MaxInt32
+ }
+
+ if sum < lastScore || sum >= score {
+ continue
+ }
+
+ if datas == nil {
+ datas = make(map[string]struct{})
+ }
+
+ datas[key] = struct{}{}
+ }
+
+ to = c.getRawNodeId(nodeId)
+ from = c.getRawNodeId(nextNode[0])
+
+ if err := c.hashRing.DelDatas(ctx, nextNode[0], datas); err != nil {
+ _err = err
+ return
+ }
+
+ if err := c.hashRing.AddDatas(ctx, nodeId, datas); err != nil {
+ _err = err
+ return
+ }
+
+ return
+}
+
+func (c *ConsistenceHash) migratorOut(ctx context.Context, score int32, nodeId string,
+ allDatas *map[string]struct{}) (datas map[string]struct{}, from, to string, _err error) {
+
+ var rawTo string
+
+ defer func () {
+ if _err !=nil{
+ return
+ }
+
+ if to == "" || len(datas)==0{
+ return
+ }
+
+ if _err = c.hashRing.DelDatas(ctx, nodeId, datas); _err != nil {
+ return
+ }
+
+ _err = c.hashRing.AddDatas(ctx, rawTo, datas)
+ }()
+
+ if c.migrator == nil {
+ return
+ }
+
+ nodes, err := c.hashRing.Node(ctx, score)
+ fmt.Println(nodes,score)
+ if err != nil {
+ _err = err
+ return
+ }
+
+ if len(nodes) == 0 {
+ return
+ }
+
+ if nodes[0] != nodeId {
+ return
+ }
+
+ if len(*allDatas) == 0 {
+ return
+ }
+
+ lastScore, err := c.hashRing.Floor(ctx, c.decrScore(score))
+ fmt.Println(lastScore,err)
+ if err != nil {
+ _err = err
+ return
+ }
+
+ if lastScore == -1 {
+ return
+ }
+
+ fmt.Printf("score: %d, lastscore: %d",score,lastScore)
+
+ if lastScore == score {
+ for _, nodeIds := range nodes {
+ if c.getRawNodeId(nodeIds) != c.getRawNodeId(nodeId) {
+ rawTo = nodeIds
+ }
+ }
+
+ if rawTo == "" {
+ return
+ }
+ }
+
+ patternOne := lastScore > score //这里必须定义模式,如果直接让sum > (lastScore + math.MaxInt32)判断,后面括号可能会溢出
+
+ if patternOne {
+ lastScore -= math.MaxInt32
+ }
+
+ fmt.Println("del lastscore", lastScore, "score", score)
+
+ for key := range *allDatas {
+ sum := c.encrytor.Hash(key)
+ fmt.Println("del sum", sum, "getsum", lastScore+math.MaxInt32)
+ if patternOne && (sum > (lastScore + math.MaxInt32)) {
+ sum -= math.MaxInt32
+ }
+
+ if sum < lastScore || sum >= score {
+ continue
+ }
+
+ if datas == nil {
+ datas = make(map[string]struct{})
+ }
+
+ datas[key] = struct{}{}
+ delete(*allDatas, key)
+ }
+
+ from = c.getRawNodeId(nodeId)
+ if rawTo != "" {
+ to = c.getRawNodeId(rawTo)
+ return
+ }
+
+ rawTo, err = c.getNextNode(ctx, score, nodeId, nil)
+ if err != nil {
+ _err = err
+ return
+ }
+
+ to = c.getRawNodeId(rawTo)
+
+ if to == "" {
+ _err = errors.New("unavailable to node")
+ return
+ }
+
+ return
+}
+
+func (c *ConsistenceHash) getNextNode(ctx context.Context, score int32, nodeId string, dup map[int32]struct{}) (string, error) {
+ nextScore, err := c.hashRing.Ceiling(ctx, c.incrScore(score))
+ if err != nil {
+ return "", err
+ }
+
+ if nextScore == -1 {
+ return "", nil
+ }
+
+ if _, ok := dup[nextScore]; ok {
+ return "", nil
+ }
+
+ nodes, err := c.hashRing.Node(ctx, nextScore)
+ if err != nil {
+ return "", err
+ }
+
+ if len(nodes) == 0 {
+ return "", errors.New("un available next node")
+ }
+
+ for _, nodeIds := range nodes {
+ if c.getRawNodeId(nodeIds) != c.getRawNodeId(nodeId) {
+ return nodeIds, nil
+ }
+ }
+
+ if dup == nil {
+ dup = make(map[int32]struct{})
+ }
+ dup[nextScore] = struct{}{}
+
+ return c.getNextNode(ctx, nextScore, nodeId, dup)
+}
diff --git a/option.go b/option.go
new file mode 100644
index 0000000..2e1a7c4
--- /dev/null
+++ b/option.go
@@ -0,0 +1,26 @@
+package consistencehash
+
+type Options struct {
+ replicas int
+ lockExpire int
+}
+
+type Option func(o *Options)
+
+func (o *Options) repire() {
+ if o.replicas <= 0 {
+ o.replicas = 5
+ }
+}
+
+func WithReplicas(replicas int) Option {
+ return func(o *Options) {
+ o.replicas = replicas
+ }
+}
+
+func WithLockExpire(expire int)Option{
+ return func(o *Options) {
+ o.lockExpire=expire
+ }
+}
diff --git a/redis/hashring/hashring.go b/redis/hashring/hashring.go
new file mode 100644
index 0000000..ff369db
--- /dev/null
+++ b/redis/hashring/hashring.go
@@ -0,0 +1,323 @@
+package hashring
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "git.zhangshuocauc.cn/redhat/goconsistencehash/local"
+ "github.com/redis/go-redis/v9"
+ "strconv"
+)
+
+const (
+ RingPreFix = "redis:consistence_hash:ring"
+ NodePreFix = "redis:consistence_hash:virtual"
+ DataKeysPreFix = "redis:consistence_hash:datakeys"
+)
+
+type RedisHashRing struct {
+ local.Locker
+ client *redis.Client
+ key string
+}
+
+func NewRedisHashRing(client *redis.Client, key string, lock local.Locker) *RedisHashRing {
+ return &RedisHashRing{
+ Locker: lock,
+ client: client,
+ key: key,
+ }
+}
+
+func (h *RedisHashRing) getRingKey() string {
+ return fmt.Sprintf("%s:%s", RingPreFix, h.key)
+}
+
+func (h *RedisHashRing) getNodeKey() string {
+ return fmt.Sprintf("%s:%s", NodePreFix, h.key)
+}
+
+func (h *RedisHashRing) getDataKey(nodeId string) string {
+ return fmt.Sprintf("%s:%s:%s", DataKeysPreFix, h.key, nodeId)
+}
+
+func (h *RedisHashRing) Add(ctx context.Context, score int32, nodeId string) error {
+ res, err := h.client.ZRangeByScore(ctx, h.getRingKey(), &redis.ZRangeBy{
+ Min: strconv.Itoa(int(score)),
+ Max: strconv.Itoa(int(score)),
+ }).Result()
+
+ if err != nil {
+ return err
+ }
+
+ if len(res) > 1 {
+ return errors.New("score to many element")
+ }
+
+ var nodeIds []string
+ if len(res) == 1 {
+ if err := json.Unmarshal([]byte(res[0]), &nodeIds); err != nil {
+ return err
+ }
+
+ for _, id := range nodeIds {
+ if id == nodeId {
+ return nil
+ }
+ }
+
+ if err := h.client.ZRem(ctx, h.getRingKey(), score).Err(); err != nil {
+ return err
+ }
+ }
+
+ nodeIds = append(nodeIds, nodeId)
+ appendNodeIds, _ := json.Marshal(nodeIds)
+
+ if err := h.client.ZAdd(ctx, h.getRingKey(), redis.Z{
+ Score: float64(score),
+ Member: appendNodeIds,
+ }).Err(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (h *RedisHashRing) Del(ctx context.Context, score int32, nodeId string) error {
+ res, err := h.client.ZRangeByScore(ctx, h.getRingKey(), &redis.ZRangeBy{
+ Min: strconv.Itoa(int(score)),
+ Max: strconv.Itoa(int(score)),
+ }).Result()
+
+ if err != nil {
+ return err
+ }
+
+ if len(res) != 1 {
+ return errors.New("score to error element")
+ }
+
+ var nodeIds []string
+ if err := json.Unmarshal([]byte(res[0]), &nodeIds); err != nil {
+ return err
+ }
+
+ var (
+ exists bool
+ index int
+ )
+
+ for idx, num := range nodeIds {
+ if num == nodeId {
+ exists = true
+ index = idx
+ break
+ }
+ }
+
+ if !exists {
+ return errors.New("the score have no such nodeid")
+ }
+
+ if err := h.client.ZRemRangeByScore(ctx, h.getRingKey(), strconv.Itoa(int(score)), strconv.Itoa(int(score))).Err(); err != nil {
+ return err
+ }
+
+ if len(nodeIds) <= 1 {
+ return nil
+ }
+
+ nodeIds = append(nodeIds[:index], nodeIds[index+1:]...)
+
+ if err := h.client.ZAdd(ctx, h.getRingKey(), redis.Z{
+ Score: float64(score),
+ Member: nodeIds,
+ }).Err(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (h *RedisHashRing) Floor(ctx context.Context, score int32) (int32, error) {
+ res, err := h.client.ZRevRangeByScoreWithScores(ctx, h.getRingKey(), &redis.ZRangeBy{
+ Min: "-inf",
+ Max: strconv.Itoa(int(score)),
+ Count: 1,
+ }).Result()
+
+ if err != nil {
+ return -1, err
+ }
+
+ if len(res) == 1 {
+ return int32(res[0].Score), nil
+ }
+
+ res, err = h.client.ZRevRangeByScoreWithScores(ctx, h.getRingKey(), &redis.ZRangeBy{
+ Min: "-inf",
+ Max: "+inf",
+ Count: 1,
+ }).Result()
+
+ if err != nil {
+ return -1, err
+ }
+
+ if len(res) == 1 {
+ return int32(res[0].Score), nil
+ }
+
+ return -1, nil
+}
+
+func (h *RedisHashRing) Ceiling(ctx context.Context, score int32) (int32, error) {
+ res, err := h.client.ZRangeByScoreWithScores(ctx, h.getRingKey(), &redis.ZRangeBy{
+ Min: strconv.Itoa(int(score)),
+ Max: "+inf",
+ Count: 1,
+ }).Result()
+
+ if err != nil {
+ return -1, err
+ }
+
+ if len(res) == 1 {
+ return int32(res[0].Score), nil
+ }
+
+ res, err = h.client.ZRangeByScoreWithScores(ctx, h.getRingKey(), &redis.ZRangeBy{
+ Min: "-inf",
+ Max: "+inf",
+ Count: 1,
+ }).Result()
+
+ if err != nil {
+ return -1, err
+ }
+
+ if len(res) == 1 {
+ return int32(res[0].Score), nil
+ }
+
+ return -1, nil
+}
+
+func (h *RedisHashRing) Node(ctx context.Context, score int32) ([]string, error) {
+ res, err := h.client.ZRangeByScore(ctx, h.getRingKey(), &redis.ZRangeBy{
+ Min: strconv.Itoa(int(score)),
+ Max: strconv.Itoa(int(score)),
+ }).Result()
+
+ if err != nil {
+ return []string{}, err
+ }
+
+ if len(res) != 1 {
+ return []string{}, errors.New("score to error element")
+ }
+
+ var nodeIds []string
+ if err := json.Unmarshal([]byte(res[0]), &nodeIds); err != nil {
+ return []string{}, err
+ }
+
+ return nodeIds, nil
+}
+
+func (h *RedisHashRing) Nodes(ctx context.Context) (map[string]int, error) {
+ res, err := h.client.HGetAll(ctx, h.getNodeKey()).Result()
+ if err != nil {
+ return map[string]int{}, err
+ }
+
+ var keys map[string]int
+ for key, value := range res {
+ count, err := strconv.Atoi(value)
+ if err != nil {
+ continue
+ }
+
+ if keys == nil {
+ keys = make(map[string]int)
+ }
+
+ keys[key] = count
+ }
+
+ return keys, nil
+}
+
+func (h *RedisHashRing) AddNodes(ctx context.Context, nodeId string, num int) error {
+ if err := h.client.HSet(ctx, h.getNodeKey(), map[string]string{
+ nodeId: strconv.Itoa(num),
+ }).Err(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (h *RedisHashRing) DelNodes(ctx context.Context, nodeId string) error {
+ if err := h.client.HDel(ctx, h.getNodeKey(), nodeId).Err(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (h *RedisHashRing) Datas(ctx context.Context, nodeId string) (map[string]struct{}, error) {
+ res, err := h.client.SMembers(ctx, h.getDataKey(nodeId)).Result()
+ if err != nil {
+ return map[string]struct{}{}, err
+ }
+
+ var datas map[string]struct{}
+
+ for _, keys := range res {
+ if datas == nil {
+ datas = make(map[string]struct{})
+ }
+
+ datas[keys] = struct{}{}
+ }
+
+ return datas, nil
+}
+
+func (h *RedisHashRing) AddDatas(ctx context.Context, nodeId string, datas map[string]struct{}) error {
+ var addkeys []string
+ for key := range datas {
+ addkeys = append(addkeys, key)
+ }
+
+ if len(addkeys) == 0 {
+ return nil
+ }
+
+ if err := h.client.SAdd(ctx, h.getDataKey(nodeId), addkeys).Err(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (h *RedisHashRing) DelDatas(ctx context.Context, nodeId string, datas map[string]struct{}) error {
+ var addkeys []string
+ for key := range datas {
+ addkeys = append(addkeys, key)
+ }
+
+ if len(addkeys) == 0 {
+ return nil
+ }
+
+ if err := h.client.SRem(ctx, h.getDataKey(nodeId), addkeys).Err(); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/redis/hashring/lock.go b/redis/hashring/lock.go
new file mode 100644
index 0000000..ada81ec
--- /dev/null
+++ b/redis/hashring/lock.go
@@ -0,0 +1,29 @@
+package hashring
+
+import (
+ "context"
+ "git.zhangshuocauc.cn/redhat/goconsistencehash/redis/lock"
+ "github.com/redis/go-redis/v9"
+)
+
+type RedisLockWrap struct {
+ client *redis.Client
+ key string
+ lock *lock.RedisLock
+}
+
+func NewRedisLockWrap(client *redis.Client, key string) *RedisLockWrap {
+ return &RedisLockWrap{
+ client: client,
+ key: key,
+ }
+}
+
+func (r *RedisLockWrap) Lock(ctx context.Context, expire int) error {
+ r.lock = lock.NewRedisLock(r.client, r.key, lock.WithExpireTime(expire))
+ return r.lock.Lock(ctx)
+}
+
+func (r *RedisLockWrap) UnLock(ctx context.Context) error {
+ return r.lock.UnLock(ctx)
+}
diff --git a/redis/lock/lock.go b/redis/lock/lock.go
new file mode 100644
index 0000000..fdd8f76
--- /dev/null
+++ b/redis/lock/lock.go
@@ -0,0 +1,174 @@
+package lock
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "github.com/redis/go-redis/v9"
+ "sync/atomic"
+ "time"
+)
+
+var ErrorsOtherLock = errors.New("not my lock")
+
+type RedisLock struct {
+ client *redis.Client
+ token string
+ key string
+ opts Options
+
+ isRuning int32
+ calFn context.CancelFunc
+}
+
+func NewRedisLock(client *redis.Client, key string, opts ...Option) *RedisLock {
+ redisLock := &RedisLock{
+ client: client,
+ key: key,
+ token: fmt.Sprintf("%d", time.Now().UnixNano()),
+ }
+
+ for _, o := range opts {
+ o(&redisLock.opts)
+ }
+
+ redisLock.opts.repire()
+
+ return redisLock
+}
+
+func (r *RedisLock) tryLock(ctx context.Context) error {
+ res, err := r.client.SetNX(ctx, r.key, r.token, time.Second*time.Duration(r.opts.expireTime)).Result()
+ if err != nil {
+ return err
+ }
+
+ if !res {
+ return errors.New("not my lock")
+ }
+
+ return nil
+}
+
+func (r *RedisLock) reNew(ctx context.Context) error {
+ ires, err := r.client.Eval(ctx, LuaReNew, []string{r.key}, r.token, r.opts.expireTime).Result()
+ if err != nil {
+ return err
+ }
+
+ if res, ok := ires.(int64); ok && res == 1 {
+ return nil
+ }
+
+ return ErrorsOtherLock
+}
+
+func (r *RedisLock) unLock(ctx context.Context) error {
+ ires, err := r.client.Eval(ctx, LuaUnLock, []string{r.key}, r.token).Result()
+ if err != nil {
+ return err
+ }
+
+ if res, ok := ires.(int64); ok && res == 1 {
+ return nil
+ }
+
+ return ErrorsOtherLock
+}
+
+func (r *RedisLock) block(ctx context.Context) error {
+ var after <-chan time.Time
+ if r.opts.maxWaitTime > 0 {
+ after = time.After(time.Duration(r.opts.maxWaitTime))
+ }
+
+ ticker := time.NewTicker(200 * time.Millisecond)
+ defer ticker.Stop()
+ for range ticker.C {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-after:
+ return errors.New("time out")
+ default:
+ }
+
+ err := r.tryLock(ctx)
+
+ if err == nil {
+ return nil
+ }
+
+ if !errors.Is(err, ErrorsOtherLock) {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (r *RedisLock) watchDog(ctx context.Context) {
+ ticker := time.NewTicker(2 * time.Second)
+ defer ticker.Stop()
+
+ for range ticker.C {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+
+ _ = r.reNew(ctx)
+ }
+}
+
+func (r *RedisLock) watch(ctx context.Context) {
+ if !r.opts.reNew {
+ return
+ }
+
+ for !atomic.CompareAndSwapInt32(&r.isRuning, 0, 1) {
+ }
+
+ var ctxn context.Context
+ ctxn, r.calFn = context.WithCancel(ctx)
+
+ atomic.StoreInt32(&r.isRuning, 1)
+
+ go func() {
+ defer func() {
+ atomic.StoreInt32(&r.isRuning, 0)
+ }()
+ r.watchDog(ctxn)
+ }()
+}
+
+func (r *RedisLock) Lock(ctx context.Context) (err error) {
+ defer func() {
+ if err == nil {
+ r.calFn = nil
+ r.watch(ctx)
+ }
+ }()
+
+ err = r.tryLock(ctx)
+ if err == nil {
+ return nil
+ }
+
+ if !r.opts.block || !errors.Is(err, ErrorsOtherLock) {
+ return err
+ }
+
+ err = r.block(ctx)
+
+ return err
+}
+
+func (r *RedisLock) UnLock(ctx context.Context) error {
+ if r.calFn != nil {
+ r.calFn()
+ }
+
+ return r.unLock(ctx)
+}
diff --git a/redis/lock/lua.go b/redis/lock/lua.go
new file mode 100644
index 0000000..236a3c7
--- /dev/null
+++ b/redis/lock/lua.go
@@ -0,0 +1,17 @@
+package lock
+
+const LuaReNew=`
+ if (redis.call("get",KEYS[1]) == ARGV[1]) then
+ return redis.call("expire",KEYS[1],ARGV[2])
+ else
+ return 0
+ end
+`
+
+const LuaUnLock=`
+ if (redis.call("get",KEYS[1]) == ARGV[1]) then
+ return redis.call("del",KEYS[1])
+ else
+ return 0
+ end
+`
\ No newline at end of file
diff --git a/redis/lock/option.go b/redis/lock/option.go
new file mode 100644
index 0000000..fa1a8b6
--- /dev/null
+++ b/redis/lock/option.go
@@ -0,0 +1,35 @@
+package lock
+
+type Options struct {
+ maxWaitTime int
+ expireTime int
+ block bool
+ reNew bool
+}
+
+type Option func(o *Options)
+
+func (o *Options) repire() {
+ if o.expireTime <= 0 {
+ o.expireTime = 10
+ o.reNew = true
+ }
+}
+
+func WithExpireTime(expire int) Option {
+ return func(o *Options) {
+ o.expireTime = expire
+ }
+}
+
+func WithBlock() Option {
+ return func(o *Options) {
+ o.block = true
+ }
+}
+
+func WithMaxWaitTime(watiTime int) Option {
+ return func(o *Options) {
+ o.maxWaitTime = watiTime
+ }
+}