209 lines
3.6 KiB
Go
209 lines
3.6 KiB
Go
package goredislock
|
|
|
|
import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9"
)
|
|
|
|
// RedisLockPrefix namespaces every key this package writes: the Redis key for
// a lock named k is RedisLockPrefix + "_" + k.
const (
	RedisLockPrefix = "goredislock"
)
|
|
|
|
var ErrorLockByOthers = errors.New("lock by other")
|
|
|
|
type RedisLock struct {
|
|
opts *Options
|
|
client *redis.Client
|
|
key string
|
|
token string
|
|
|
|
running int32
|
|
celFu context.CancelFunc
|
|
}
|
|
|
|
func NewRedisLock(client *redis.Client, key string, opts ...Option) *RedisLock {
|
|
redisLock := &RedisLock{
|
|
client: client,
|
|
key: key,
|
|
opts: &Options{},
|
|
token: fmt.Sprintf("%d", time.Now().UnixNano()),
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(redisLock.opts)
|
|
}
|
|
|
|
redisLock.opts.repair()
|
|
|
|
return redisLock
|
|
}
|
|
|
|
func (r *RedisLock) getRedisLockKey(key string) string {
|
|
return fmt.Sprintf("%s_%s", RedisLockPrefix, key)
|
|
}
|
|
|
|
func (r *RedisLock) isMyLock(ctx context.Context) bool {
|
|
s, err := r.client.Get(ctx, r.getRedisLockKey(r.key)).Result()
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
if s == r.getRedisLockKey(r.key) {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (r *RedisLock) tryLock(ctx context.Context) error {
|
|
result, err := r.client.SetNX(ctx, r.getRedisLockKey(r.key), r.token,
|
|
time.Duration(r.opts.expireTime)*time.Second).Result()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !result {
|
|
return ErrorLockByOthers
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *RedisLock) watchDog(ctx context.Context) {
|
|
ticker := time.NewTicker(3 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
// 执行续期
|
|
_ = r.renewal(ctx)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *RedisLock) renewal(ctx context.Context) error {
|
|
result, err := r.client.Eval(ctx, LuaRenewal,
|
|
[]string{r.getRedisLockKey(r.key)}, r.token,
|
|
r.opts.expireTime).Result()
|
|
|
|
fmt.Println(result, err)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if res, ok := result.(int64); ok && res == 1 {
|
|
return nil
|
|
}
|
|
|
|
return ErrorLockByOthers
|
|
}
|
|
|
|
func (r *RedisLock) unlock(ctx context.Context) error {
|
|
result, err := r.client.Eval(ctx, LuaUnlock,
|
|
[]string{r.getRedisLockKey(r.key)}, r.token).Result()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if res, ok := result.(int64); ok && res == 1 {
|
|
return nil
|
|
}
|
|
|
|
return ErrorLockByOthers
|
|
}
|
|
|
|
func (r *RedisLock) watch(ctx context.Context) {
|
|
// 如果用户指定过期时间则不进行续期
|
|
if !r.opts.reNewLock {
|
|
return
|
|
}
|
|
|
|
for !atomic.CompareAndSwapInt32(&r.running, 0, 1) {
|
|
}
|
|
|
|
var nctx context.Context
|
|
nctx, r.celFu = context.WithCancel(ctx)
|
|
|
|
go func() {
|
|
defer func() {
|
|
atomic.StoreInt32(&r.running, 0)
|
|
}()
|
|
r.watchDog(nctx)
|
|
}()
|
|
}
|
|
|
|
func (r *RedisLock) block(ctx context.Context) error {
|
|
var after <-chan time.Time
|
|
if r.opts.waitTime > 0 {
|
|
after = time.After(time.Duration(r.opts.waitTime) * time.Second)
|
|
}
|
|
|
|
ticker := time.NewTicker(200 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-after:
|
|
return errors.New("timeout")
|
|
case <-ticker.C:
|
|
err := r.tryLock(ctx)
|
|
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
if !errors.Is(err, ErrorLockByOthers) {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *RedisLock) Lock(ctx context.Context) (err error) {
|
|
defer func() {
|
|
// 如果正常情况需要考虑是否启动续期
|
|
if err == nil {
|
|
r.watch(ctx)
|
|
}
|
|
}()
|
|
|
|
if r.isMyLock(ctx) {
|
|
return nil
|
|
}
|
|
|
|
err = r.tryLock(ctx)
|
|
|
|
// 正常得到锁则返回即可
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
// 如果不是被占用错误,或者是非阻塞模式 则需要返回
|
|
if !errors.Is(err, ErrorLockByOthers) || !r.opts.block {
|
|
return err
|
|
}
|
|
|
|
// 进入阻塞模式不停尝试获取锁
|
|
return r.block(ctx)
|
|
}
|
|
|
|
func (r *RedisLock) Unlock(ctx context.Context) (err error) {
|
|
err = r.unlock(ctx)
|
|
|
|
if r.celFu != nil {
|
|
r.celFu()
|
|
}
|
|
|
|
return err
|
|
}
|