goredislock/lock.go
2025-04-18 16:01:21 +08:00

209 lines
3.6 KiB
Go

// Package goredislock implements a distributed lock on a Redis key: the key is
// claimed with SET NX and a per-lock token, can optionally be kept alive by a
// background watchdog, and is released through a Lua script run with EVAL.
package goredislock
import (
"context"
"errors"
"fmt"
"github.com/redis/go-redis/v9"
"sync/atomic"
"time"
)
const (
	// RedisLockPrefix is prepended, followed by an underscore, to every
	// caller-supplied key to form the Redis key that is actually locked.
	RedisLockPrefix = "goredislock"
)

var (
	// ErrorLockByOthers is returned when the lock key exists but is not held
	// by this RedisLock (or a renewal/unlock script did not report success).
	ErrorLockByOthers = errors.New("lock by other")
)
// RedisLock is a distributed lock on a single Redis key, owned by whichever
// RedisLock value wrote its token into that key.
type RedisLock struct {
	// opts holds the expiry, blocking, wait-time and renewal settings.
	opts *Options
	// client is the Redis connection every command is sent through.
	client *redis.Client
	// key is the caller's lock name (stored in Redis under
	// getRedisLockKey(key)); token is the value written to that key to mark
	// this instance as the owner.
	key, token string
	// running is 1 while a watchdog goroutine is active; it is only read and
	// written through sync/atomic.
	running int32
	// celFu cancels the context of the current watchdog; nil until one has
	// been started.
	celFu context.CancelFunc
}
// NewRedisLock creates a RedisLock for key on client. The given options are
// applied in order to a fresh Options value, which is then normalised by its
// repair method. The owner token is the creation time in nanoseconds.
func NewRedisLock(client *redis.Client, key string, opts ...Option) *RedisLock {
	token := fmt.Sprintf("%d", time.Now().UnixNano())

	options := new(Options)
	for i := range opts {
		opts[i](options)
	}
	options.repair()

	return &RedisLock{
		opts:   options,
		client: client,
		key:    key,
		token:  token,
	}
}
// getRedisLockKey maps a caller-supplied key to the Redis key that is actually
// locked: RedisLockPrefix, an underscore, then key. The receiver is unused.
func (r *RedisLock) getRedisLockKey(key string) string {
	return RedisLockPrefix + "_" + key
}
// isMyLock reports whether this RedisLock currently owns the lock, i.e. the
// Redis key holds exactly this instance's token (the value tryLock writes).
// Any error from GET, including the missing-key case (go-redis reports it as
// redis.Nil), is treated as "not mine".
func (r *RedisLock) isMyLock(ctx context.Context) bool {
	value, err := r.client.Get(ctx, r.getRedisLockKey(r.key)).Result()
	if err != nil {
		return false
	}
	// The key's value is the owner's token, so compare against r.token. The
	// previous comparison against the key name could never match.
	return value == r.token
}

// tryLock makes a single, non-blocking attempt to claim the key with SET NX,
// storing this instance's token with an expiry of opts.expireTime seconds.
// It returns nil on success, ErrorLockByOthers when the key already exists,
// or the Redis error if the command itself failed.
func (r *RedisLock) tryLock(ctx context.Context) error {
	result, err := r.client.SetNX(ctx, r.getRedisLockKey(r.key), r.token,
		time.Duration(r.opts.expireTime)*time.Second).Result()
	if err != nil {
		return err
	}
	if !result {
		return ErrorLockByOthers
	}
	return nil
}

// watchDog periodically calls renewal until ctx is cancelled or a renewal
// returns ErrorLockByOthers.
func (r *RedisLock) watchDog(ctx context.Context) {
	// Renew every 3s, or every third of the expiry when that is shorter, so a
	// short expiry cannot lapse between two renewals.
	interval := 3 * time.Second
	if third := time.Duration(r.opts.expireTime) * time.Second / 3; third > 0 && third < interval {
		interval = third
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// Transient Redis errors are ignored so the next tick can retry.
			// ErrorLockByOthers means the renewal script did not succeed for
			// our token, i.e. ownership is lost, so further renewals are
			// pointless and the watchdog stops instead of running forever.
			if err := r.renewal(ctx); errors.Is(err, ErrorLockByOthers) {
				return
			}
		}
	}
}

// renewal runs the LuaRenewal script (defined elsewhere) with the lock key,
// this instance's token and opts.expireTime. Presumably the script extends
// the TTL only while the key still holds our token — confirm against
// LuaRenewal. It returns nil when the script returns 1, the Redis error if
// the call failed, and ErrorLockByOthers otherwise.
func (r *RedisLock) renewal(ctx context.Context) error {
	result, err := r.client.Eval(ctx, LuaRenewal,
		[]string{r.getRedisLockKey(r.key)}, r.token,
		r.opts.expireTime).Result()
	if err != nil {
		return err
	}
	if res, ok := result.(int64); ok && res == 1 {
		return nil
	}
	return ErrorLockByOthers
}

// unlock runs the LuaUnlock script (defined elsewhere) with the lock key and
// this instance's token. Presumably the script deletes the key only when it
// holds our token — confirm against LuaUnlock. It returns nil when the
// script returns 1, the Redis error if the call failed, and
// ErrorLockByOthers otherwise.
func (r *RedisLock) unlock(ctx context.Context) error {
	result, err := r.client.Eval(ctx, LuaUnlock,
		[]string{r.getRedisLockKey(r.key)}, r.token).Result()
	if err != nil {
		return err
	}
	if res, ok := result.(int64); ok && res == 1 {
		return nil
	}
	return ErrorLockByOthers
}

// watch starts a background watchdog that keeps renewing the lock, unless
// renewal is disabled (opts.reNewLock is false; per the original author's
// note this is the case when the user set an explicit expiry time).
//
// At most one watchdog runs per RedisLock. If one is still active (on a
// re-entrant Lock, or one that was cancelled but has not exited yet), it is
// cancelled and watch waits for its goroutine to finish before starting the
// new one. The new watchdog's context derives from ctx, so it stops when ctx
// is cancelled, when Unlock calls celFu, or when it detects the lock is lost.
func (r *RedisLock) watch(ctx context.Context) {
	if !r.opts.reNewLock {
		return
	}
	for !atomic.CompareAndSwapInt32(&r.running, 0, 1) {
		// A previous watchdog still owns the slot: ask it to stop and back off
		// briefly instead of hot-spinning. Waiting without cancelling would
		// never end, because a watchdog only exits when cancelled or when the
		// lock is lost.
		if r.celFu != nil {
			r.celFu()
		}
		time.Sleep(time.Millisecond)
	}
	var nctx context.Context
	nctx, r.celFu = context.WithCancel(ctx)
	go func() {
		// Free the slot once the watchdog exits so a later watch can start.
		defer atomic.StoreInt32(&r.running, 0)
		r.watchDog(nctx)
	}()
}
// block polls tryLock every 200ms until the lock is acquired. It gives up in
// any of these cases:
//   - with ctx.Err() when ctx is done;
//   - with a "timeout" error after opts.waitTime seconds, but only when
//     waitTime > 0 (otherwise it waits indefinitely);
//   - immediately, with any tryLock error other than ErrorLockByOthers.
func (r *RedisLock) block(ctx context.Context) error {
	// Left nil when no wait time is configured: receiving from a nil channel
	// blocks forever, which disables the deadline case below.
	var deadline <-chan time.Time
	if r.opts.waitTime > 0 {
		deadline = time.After(time.Duration(r.opts.waitTime) * time.Second)
	}
	poll := time.NewTicker(200 * time.Millisecond)
	defer poll.Stop()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-deadline:
			return errors.New("timeout")
		case <-poll.C:
			// Both nil (acquired) and real failures end the wait; only
			// "held by someone else" keeps polling.
			if err := r.tryLock(ctx); !errors.Is(err, ErrorLockByOthers) {
				return err
			}
		}
	}
}
// Lock acquires the distributed lock.
//
// If isMyLock reports that this instance already holds the key, it returns nil
// at once. Otherwise it makes one tryLock attempt. If that fails because
// another holder has the key and blocking mode (opts.block) is on, it keeps
// polling via block. Any other failure, or ErrorLockByOthers in non-blocking
// mode, is returned as is.
//
// Whenever Lock returns nil it also calls watch, which starts the renewal
// watchdog if renewal is enabled.
func (r *RedisLock) Lock(ctx context.Context) (err error) {
	defer func() {
		if err != nil {
			return
		}
		r.watch(ctx)
	}()

	if r.isMyLock(ctx) {
		return nil
	}

	err = r.tryLock(ctx)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, ErrorLockByOthers) && r.opts.block:
		// Held by someone else and blocking mode is on: keep trying.
		return r.block(ctx)
	default:
		return err
	}
}
// Unlock releases the lock via unlock and then cancels the renewal watchdog,
// if one was ever started. The watchdog is cancelled even when releasing
// fails. The result of unlock is returned: nil, ErrorLockByOthers, or a
// Redis error.
func (r *RedisLock) Unlock(ctx context.Context) (err error) {
	releaseErr := r.unlock(ctx)
	if cancel := r.celFu; cancel != nil {
		cancel()
	}
	return releaseErr
}