209 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			209 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package goredislock
 | |
| 
 | |
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9"
)
 | |
| 
 | |
// RedisLockPrefix is the namespace joined (with an underscore) in front of
// every caller-supplied lock name to build the Redis key.
const RedisLockPrefix = "goredislock"

// ErrorLockByOthers is returned when the lock cannot be taken, renewed or
// released because the Redis key is not held by this instance.
var ErrorLockByOthers = errors.New("lock by other")
 | |
| 
 | |
// RedisLock is a lock stored as a single Redis key whose value is a
// per-instance token.
//
// NOTE(review): celFu is a plain field written by watch and read by Unlock
// without synchronisation, so one instance does not look safe for concurrent
// Lock/Unlock calls from several goroutines — confirm intended usage.
type RedisLock struct {
	// opts holds the settings built from the Option values passed to
	// NewRedisLock (expireTime, waitTime, block, reNewLock are read here).
	opts *Options
	// client is the Redis connection every command is sent through.
	client *redis.Client
	// key is the caller-supplied lock name; getRedisLockKey namespaces it.
	key string
	// token is the value written under the key to mark this instance as holder.
	token string

	// running is 1 while the renewal goroutine is alive; accessed via sync/atomic.
	running int32
	// celFu cancels the renewal goroutine's context; nil until watch starts one.
	celFu context.CancelFunc
}
 | |
| 
 | |
| func NewRedisLock(client *redis.Client, key string, opts ...Option) *RedisLock {
 | |
| 	redisLock := &RedisLock{
 | |
| 		client: client,
 | |
| 		key:    key,
 | |
| 		opts:   &Options{},
 | |
| 		token:  fmt.Sprintf("%d", time.Now().UnixNano()),
 | |
| 	}
 | |
| 
 | |
| 	for _, opt := range opts {
 | |
| 		opt(redisLock.opts)
 | |
| 	}
 | |
| 
 | |
| 	redisLock.opts.repair()
 | |
| 
 | |
| 	return redisLock
 | |
| }
 | |
| 
 | |
| func (r *RedisLock) getRedisLockKey(key string) string {
 | |
| 	return fmt.Sprintf("%s_%s", RedisLockPrefix, key)
 | |
| }
 | |
| 
 | |
| func (r *RedisLock) isMyLock(ctx context.Context) bool {
 | |
| 	s, err := r.client.Get(ctx, r.getRedisLockKey(r.key)).Result()
 | |
| 	if err != nil {
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	if s == r.getRedisLockKey(r.key) {
 | |
| 		return true
 | |
| 	}
 | |
| 
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| func (r *RedisLock) tryLock(ctx context.Context) error {
 | |
| 	result, err := r.client.SetNX(ctx, r.getRedisLockKey(r.key), r.token,
 | |
| 		time.Duration(r.opts.expireTime)*time.Second).Result()
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if !result {
 | |
| 		return ErrorLockByOthers
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (r *RedisLock) watchDog(ctx context.Context) {
 | |
| 	ticker := time.NewTicker(3 * time.Second)
 | |
| 	defer ticker.Stop()
 | |
| 
 | |
| 	for {
 | |
| 		select {
 | |
| 		case <-ctx.Done():
 | |
| 			return
 | |
| 		case <-ticker.C:
 | |
| 			// 执行续期
 | |
| 			_ = r.renewal(ctx)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (r *RedisLock) renewal(ctx context.Context) error {
 | |
| 	result, err := r.client.Eval(ctx, LuaRenewal,
 | |
| 		[]string{r.getRedisLockKey(r.key)}, r.token,
 | |
| 		r.opts.expireTime).Result()
 | |
| 
 | |
| 	fmt.Println(result, err)
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if res, ok := result.(int64); ok && res == 1 {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	return ErrorLockByOthers
 | |
| }
 | |
| 
 | |
| func (r *RedisLock) unlock(ctx context.Context) error {
 | |
| 	result, err := r.client.Eval(ctx, LuaUnlock,
 | |
| 		[]string{r.getRedisLockKey(r.key)}, r.token).Result()
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if res, ok := result.(int64); ok && res == 1 {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	return ErrorLockByOthers
 | |
| }
 | |
| 
 | |
| func (r *RedisLock) watch(ctx context.Context) {
 | |
| 	// 如果用户指定过期时间则不进行续期
 | |
| 	if !r.opts.reNewLock {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	for !atomic.CompareAndSwapInt32(&r.running, 0, 1) {
 | |
| 	}
 | |
| 
 | |
| 	var nctx context.Context
 | |
| 	nctx, r.celFu = context.WithCancel(ctx)
 | |
| 
 | |
| 	go func() {
 | |
| 		defer func() {
 | |
| 			atomic.StoreInt32(&r.running, 0)
 | |
| 		}()
 | |
| 		r.watchDog(nctx)
 | |
| 	}()
 | |
| }
 | |
| 
 | |
| func (r *RedisLock) block(ctx context.Context) error {
 | |
| 	var after <-chan time.Time
 | |
| 	if r.opts.waitTime > 0 {
 | |
| 		after = time.After(time.Duration(r.opts.waitTime) * time.Second)
 | |
| 	}
 | |
| 
 | |
| 	ticker := time.NewTicker(200 * time.Millisecond)
 | |
| 	defer ticker.Stop()
 | |
| 
 | |
| 	for {
 | |
| 		select {
 | |
| 		case <-ctx.Done():
 | |
| 			return ctx.Err()
 | |
| 		case <-after:
 | |
| 			return errors.New("timeout")
 | |
| 		case <-ticker.C:
 | |
| 			err := r.tryLock(ctx)
 | |
| 
 | |
| 			if err == nil {
 | |
| 				return nil
 | |
| 			}
 | |
| 
 | |
| 			if !errors.Is(err, ErrorLockByOthers) {
 | |
| 				return err
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (r *RedisLock) Lock(ctx context.Context) (err error) {
 | |
| 	defer func() {
 | |
| 		// 如果正常情况需要考虑是否启动续期
 | |
| 		if err == nil {
 | |
| 			r.watch(ctx)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	if r.isMyLock(ctx) {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	err = r.tryLock(ctx)
 | |
| 
 | |
| 	// 正常得到锁则返回即可
 | |
| 	if err == nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	// 如果不是被占用错误,或者是非阻塞模式 则需要返回
 | |
| 	if !errors.Is(err, ErrorLockByOthers) || !r.opts.block {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// 进入阻塞模式不停尝试获取锁
 | |
| 	return r.block(ctx)
 | |
| }
 | |
| 
 | |
| func (r *RedisLock) Unlock(ctx context.Context) (err error) {
 | |
| 	err = r.unlock(ctx)
 | |
| 
 | |
| 	if r.celFu != nil {
 | |
| 		r.celFu()
 | |
| 	}
 | |
| 
 | |
| 	return err
 | |
| }
 |