package goredislock import ( "context" "errors" "fmt" "github.com/redis/go-redis/v9" "sync/atomic" "time" ) const RedisLockPrefix = "goredislock" var ErrorLockByOthers = errors.New("lock by other") type RedisLock struct { opts *Options client *redis.Client key string token string running int32 celFu context.CancelFunc } func NewRedisLock(client *redis.Client, key string, opts ...Option) *RedisLock { redisLock := &RedisLock{ client: client, key: key, opts: &Options{}, token: fmt.Sprintf("%d", time.Now().UnixNano()), } for _, opt := range opts { opt(redisLock.opts) } redisLock.opts.repair() return redisLock } func (r *RedisLock) getRedisLockKey(key string) string { return fmt.Sprintf("%s_%s", RedisLockPrefix, key) } func (r *RedisLock) isMyLock(ctx context.Context) bool { s, err := r.client.Get(ctx, r.getRedisLockKey(r.key)).Result() if err != nil { return false } if s == r.getRedisLockKey(r.key) { return true } return false } func (r *RedisLock) tryLock(ctx context.Context) error { result, err := r.client.SetNX(ctx, r.getRedisLockKey(r.key), r.token, time.Duration(r.opts.expireTime)*time.Second).Result() if err != nil { return err } if !result { return ErrorLockByOthers } return nil } func (r *RedisLock) watchDog(ctx context.Context) { ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: // 执行续期 _ = r.renewal(ctx) } } } func (r *RedisLock) renewal(ctx context.Context) error { result, err := r.client.Eval(ctx, LuaRenewal, []string{r.getRedisLockKey(r.key)}, r.token, r.opts.expireTime).Result() fmt.Println(result, err) if err != nil { return err } if res, ok := result.(int64); ok && res == 1 { return nil } return ErrorLockByOthers } func (r *RedisLock) unlock(ctx context.Context) error { result, err := r.client.Eval(ctx, LuaUnlock, []string{r.getRedisLockKey(r.key)}, r.token).Result() if err != nil { return err } if res, ok := result.(int64); ok && res == 1 { return nil } return ErrorLockByOthers } func (r *RedisLock) watch(ctx context.Context) { // 如果用户指定过期时间则不进行续期 if !r.opts.reNewLock { return } for !atomic.CompareAndSwapInt32(&r.running, 0, 1) { } var nctx context.Context nctx, r.celFu = context.WithCancel(ctx) go func() { defer func() { atomic.StoreInt32(&r.running, 0) }() r.watchDog(nctx) }() } func (r *RedisLock) block(ctx context.Context) error { var after <-chan time.Time if r.opts.waitTime > 0 { after = time.After(time.Duration(r.opts.waitTime) * time.Second) } ticker := time.NewTicker(200 * time.Millisecond) defer ticker.Stop() for { select { case <-ctx.Done(): return ctx.Err() case <-after: return errors.New("timeout") case <-ticker.C: err := r.tryLock(ctx) if err == nil { return nil } if !errors.Is(err, ErrorLockByOthers) { return err } } } } func (r *RedisLock) Lock(ctx context.Context) (err error) { defer func() { // 如果正常情况需要考虑是否启动续期 if err == nil { r.watch(ctx) } }() if r.isMyLock(ctx) { return nil } err = r.tryLock(ctx) // 正常得到锁则返回即可 if err == nil { return nil } // 如果不是被占用错误,或者是非阻塞模式 则需要返回 if !errors.Is(err, ErrorLockByOthers) || !r.opts.block { return err } // 进入阻塞模式不停尝试获取锁 return r.block(ctx) } func (r *RedisLock) Unlock(ctx context.Context) (err error) { err = r.unlock(ctx) if r.celFu != nil { r.celFu() } return err }