first commit
This commit is contained in:
commit
992d0e39f2
9
go.mod
Normal file
9
go.mod
Normal file
@ -0,0 +1,9 @@
|
||||
module git.zhangshuocauc.cn/redhat/goredislock
|
||||
|
||||
go 1.24
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/redis/go-redis/v9 v9.7.3 // indirect
|
||||
)
|
6
go.sum
Normal file
6
go.sum
Normal file
@ -0,0 +1,6 @@
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
208
lock.go
Normal file
208
lock.go
Normal file
@ -0,0 +1,208 @@
|
||||
package goredislock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const RedisLockPrefix = "goredislock"
|
||||
|
||||
var ErrorLockByOthers = errors.New("lock by other")
|
||||
|
||||
type RedisLock struct {
|
||||
opts *Options
|
||||
client *redis.Client
|
||||
key string
|
||||
token string
|
||||
|
||||
running int32
|
||||
celFu context.CancelFunc
|
||||
}
|
||||
|
||||
func NewRedisLock(client *redis.Client, key string, opts ...Option) *RedisLock {
|
||||
redisLock := &RedisLock{
|
||||
client: client,
|
||||
key: key,
|
||||
opts: &Options{},
|
||||
token: fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(redisLock.opts)
|
||||
}
|
||||
|
||||
redisLock.opts.repair()
|
||||
|
||||
return redisLock
|
||||
}
|
||||
|
||||
func (r *RedisLock) getRedisLockKey(key string) string {
|
||||
return fmt.Sprintf("%s_%s", RedisLockPrefix, key)
|
||||
}
|
||||
|
||||
func (r *RedisLock) isMyLock(ctx context.Context) bool {
|
||||
s, err := r.client.Get(ctx, r.getRedisLockKey(r.key)).Result()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if s == r.getRedisLockKey(r.key) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *RedisLock) tryLock(ctx context.Context) error {
|
||||
result, err := r.client.SetNX(ctx, r.getRedisLockKey(r.key), r.token,
|
||||
time.Duration(r.opts.expireTime)*time.Second).Result()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !result {
|
||||
return ErrorLockByOthers
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisLock) watchDog(ctx context.Context) {
|
||||
ticker := time.NewTicker(3 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
// 执行续期
|
||||
_ = r.renewal(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RedisLock) renewal(ctx context.Context) error {
|
||||
result, err := r.client.Eval(ctx, LuaRenewal,
|
||||
[]string{r.getRedisLockKey(r.key)}, r.token,
|
||||
r.opts.expireTime).Result()
|
||||
|
||||
fmt.Println(result, err)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res, ok := result.(int64); ok && res == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrorLockByOthers
|
||||
}
|
||||
|
||||
func (r *RedisLock) unlock(ctx context.Context) error {
|
||||
result, err := r.client.Eval(ctx, LuaUnlock,
|
||||
[]string{r.getRedisLockKey(r.key)}, r.token).Result()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res, ok := result.(int64); ok && res == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrorLockByOthers
|
||||
}
|
||||
|
||||
func (r *RedisLock) watch(ctx context.Context) {
|
||||
// 如果用户指定过期时间则不进行续期
|
||||
if !r.opts.reNewLock {
|
||||
return
|
||||
}
|
||||
|
||||
for !atomic.CompareAndSwapInt32(&r.running, 0, 1) {
|
||||
}
|
||||
|
||||
var nctx context.Context
|
||||
nctx, r.celFu = context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
atomic.StoreInt32(&r.running, 0)
|
||||
}()
|
||||
r.watchDog(nctx)
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *RedisLock) block(ctx context.Context) error {
|
||||
var after <-chan time.Time
|
||||
if r.opts.waitTime > 0 {
|
||||
after = time.After(time.Duration(r.opts.waitTime) * time.Second)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-after:
|
||||
return errors.New("timeout")
|
||||
case <-ticker.C:
|
||||
err := r.tryLock(ctx)
|
||||
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !errors.Is(err, ErrorLockByOthers) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RedisLock) Lock(ctx context.Context) (err error) {
|
||||
defer func() {
|
||||
// 如果正常情况需要考虑是否启动续期
|
||||
if err == nil {
|
||||
r.watch(ctx)
|
||||
}
|
||||
}()
|
||||
|
||||
if r.isMyLock(ctx) {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = r.tryLock(ctx)
|
||||
|
||||
// 正常得到锁则返回即可
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果不是被占用错误,或者是非阻塞模式 则需要返回
|
||||
if !errors.Is(err, ErrorLockByOthers) || !r.opts.block {
|
||||
return err
|
||||
}
|
||||
|
||||
// 进入阻塞模式不停尝试获取锁
|
||||
return r.block(ctx)
|
||||
}
|
||||
|
||||
func (r *RedisLock) Unlock(ctx context.Context) (err error) {
|
||||
err = r.unlock(ctx)
|
||||
|
||||
if r.celFu != nil {
|
||||
r.celFu()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
50
lock_test.go
Normal file
50
lock_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
package goredislock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_blockingLock(t *testing.T) {
|
||||
// 请输入 redis 节点的地址和密码
|
||||
addr := "192.168.8.1:6379"
|
||||
passwd := ""
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: addr,
|
||||
Password: passwd,
|
||||
})
|
||||
lock1 := NewRedisLock(client, "mylock", WithExpireTime(3))
|
||||
lock2 := NewRedisLock(client, "mylock", WithBlock())
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := lock1.Lock(ctx); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := lock2.Lock(ctx); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Second)
|
||||
if err := lock2.Unlock(ctx); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
t.Log("success")
|
||||
}
|
17
lua.go
Normal file
17
lua.go
Normal file
@ -0,0 +1,17 @@
|
||||
package goredislock
|
||||
|
||||
const LuaRenewal = `
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call('expire', KEYS[1], ARGV[2])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`
|
||||
|
||||
const LuaUnlock = `
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("del", KEYS[1])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`
|
37
option.go
Normal file
37
option.go
Normal file
@ -0,0 +1,37 @@
|
||||
package goredislock
|
||||
|
||||
const DefaultExpireTime = 10
|
||||
|
||||
type Options struct {
|
||||
block bool
|
||||
waitTime int
|
||||
expireTime int
|
||||
reNewLock bool
|
||||
}
|
||||
|
||||
func (o *Options) repair() {
|
||||
if o.expireTime == 0 {
|
||||
o.expireTime = DefaultExpireTime
|
||||
o.reNewLock = true
|
||||
}
|
||||
}
|
||||
|
||||
type Option func(o *Options)
|
||||
|
||||
func WithBlock() Option {
|
||||
return func(o *Options) {
|
||||
o.block = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithWaitTime(waitTime int) Option {
|
||||
return func(o *Options) {
|
||||
o.waitTime = waitTime
|
||||
}
|
||||
}
|
||||
|
||||
func WithExpireTime(expireTime int) Option {
|
||||
return func(o *Options) {
|
||||
o.expireTime = expireTime
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user