利用Redis?lua实现高效读写锁的代码实例
目录
- 前言
- 一、为什么使用Lua
- 二、执行流程
- 三、代码详解
- lua\lock.lua
- lua\refresh.lua
- lua\rlock.lua
- lua\unlock.lua
- 写优先还是读优先?
- 写锁是如何阻塞写锁的?
- 读锁与读锁之间互斥吗?
- 写锁会有被饿死的情况吗?
- 抽象lock类
- Options
- redismutex
- 测试用例
前言
读写锁的好处就是能帮助客户读到的数据一定是最新的,写锁是排他锁,而读锁是一个共享锁,如果写锁一直存在,那么读取数据就要一直等待,直到写入数据完成才能看到,保证了数据的一致性
一、为什么使用Lua
Lua脚本是高并发、高性能的必备脚本语言, 大部分的开源框架(如:redission)中的分布式锁组件,都是用纯lua脚本实现的。
那么,为什么要使用Lua语言来实现分布式锁呢?我们从一个案例看起:
所以,只有确保判断锁和删除锁是一步操作时,才能避免上面的问题,才能确保原子性。
其实很简单,首先获取锁对应的value值,检查是否与requestId相等,如果相等则删除锁(解锁)。虽然看似做了两件事,但是却只有一个完整的原子操作。
第一行代码,我们写了一个简单的 Lua 脚本代码; 第二行代码,我们将Lua代码传到 edis.eval()方法里,并使参数 KEYS[1] 赋值为 lockKey,ARGV[1] 赋值为 requestId,eval() 方法是将Lua代码交给 Redis 服务端执行。
二、执行流程
加锁和删除锁的操作,使用纯 Lua 进行封装,保障其执行时候的原子性。
基于纯Lua脚本实现分布式锁的执行流程,大致如下:
三、代码详解
lua\lock.lua
-- KEYS = [LOCK_KEY, LOCK_INTENT] -- ARGV = [LOCK_ID, TTL] local t = redis.call('TYPE', KEYS[1])["ok"] if t == "string" then return redis.call('PTTL', KEYS[1]) end if redis.call("EXISTS", KEYS[2]) == 1 then return redis.call('PTTL', KEYS[2]) end redis.call('SADD', KEYS[1], ARGV[1]) redis.call('PEXPIRE', KEYS[1], ARGV[2]) return nil
-- KEYS = [LOCK_KEY, LOCK_INTENT]
和-- ARGV = [LOCK_ID, TTL, ENABLE_LOCK_INTENT]
if not redis.call("SET", KEYS[1], ARGV[1], "PX", ARGV[2], "NX") then
行首使用了redis.call
函数调用,将 LOCK_KEY 和 LOCK_ID 存储到 Redis 中,并设置过期时间为 TTL。如果设置失败,则进入条件内部。- 在条件内部,判断 ENABLE_LOCK_INTENT 的值。如果为 1,则执行
redis.call("SET", KEYS[2], 1, "PX", ARGV[2])
,将 LOCK_INTENT 键设置为 1,并设置与 LOCK_KEY 相同的过期时间。这是为了表示锁被占用的意图。 - 返回
redis.call("PTTL", KEYS[1])
,即 LOCK_KEY 的剩余过期时间,以毫秒为单位。这是为了告知调用方锁已被占用,返回锁的剩余过期时间。 - 若上述条件都不满足,则执行
redis.call("DEL", KEYS[2])
,删除 LOCK_INTENT 键。 - 返回
nil
,表示锁已成功获取。
它首先尝试通过 SET 命令将 LOCK_KEY 存储到 Redis 中,如果设置失败,则表示锁已被其他进程占用,返回锁的剩余过期时间。如果设置成功,则删除 LOCK_INTENT 键,表示锁已成功获取
lua\refresh.lua
-- KEYS = [LOCK_KEY] -- ARGV = [LOCK_ID, TTL] local t = redis.call('TYPE', KEYS[1])["ok"] if (t == "string" and redis.call('GET', KEYS[1]) ~= ARGV[1]) or (t == "set" and redis.call('SISMEMBER', KEYS[1], ARGV[1]) == 0) or (t == "none") then return 0 end return redis.call('PEXPIRE', KEYS[1], ARGV[2])
- 延长锁的时间
lua\rlock.lua
-- KEYS = [LOCK_KEY, LOCK_INTENT] -- ARGV = [LOCK_ID, TTL] local t = redis.call('TYPE', KEYS[1])["ok"] if t == "string" then return redis.call('PTTL', KEYS[1]) end if redis.call("EXISTS", KEYS[2]) == 1 then return redis.call('PTTL', KEYS[2]) end redis.call('SADD', KEYS[1], ARGV[1]) redis.call('PEXPIRE', KEYS[1], ARGV[2]) return nil
local t = redis.call('TYPE', KEYS[1])["ok"]
通过TYPE
命令获取键的类型,并将结果存储在变量t
中。
使用条件逻辑判断锁的状态:
- 如果
t
是字符串,则返回PTTL
命令的结果,即锁的剩余过期时间。 - 如果
LOCK_INTENT
键存在,则返回PTTL
命令的结果,即锁占用意图的剩余过期时间。
- 如果
由于以上条件都不满足,即锁未被占用,将锁 ID (ARGV[1]
) 添加到LOCK_KEY
集合中。
使用PEXPIRE
命令设置LOCK_KEY
的过期时间为ARGV[2]
(以毫秒为单位)。
返回nil
,表示锁已成功获取。
lua\unlock.lua
-- KEYS = [LOCK_KEY] -- ARGV = [LOCK_ID] local t = redis.call('TYPE', KEYS[1])["ok"] if t == "string" and redis.call('GET', KEYS[1]) == ARGV[1] then return redis.call('DEL', KEYS[1]) elseif t == "set" and redis.call('SISMEMBER', KEYS[1], ARGV[1]) == 1 then redis.call('SREM', KEYS[1], ARGV[1]) if redis.call('SCARD', KEYS[1]) == 0 then return redis.call('DEL', KEYS[1]) end end return 1
- 检查指定键的类型,如果是字符串并且键的值等于给定的
ARGV
值,则删除该键。 - 如果指定键的类型是集合,并且集合中包含给定的
ARGV
值,则将该值从集合中移除。随后,如果集合中不再包含任何元素,则删除该键。
写优先还是读优先?
写锁会阻塞读锁,所以是写优先
写锁是如何阻塞写锁的?
如果当前的写锁已经被占用,其他写锁的获取请求会被阻塞,因为在释放锁的逻辑中,会先判断锁的类型,如果是写锁,则会判断当前锁的值是否符合预期,从而判断能否删除该锁。
读锁与读锁之间互斥吗?
对于读锁而言,多个读锁之间是可以并发持有的,因此读锁之间默认是不会互斥的,可以同时执行读操作。
写锁会有被饿死的情况吗?
写优先锁可以保证写线程不会饿死,但是如果一直有写线程获取写锁,读线程也会被「饿死」。
既然不管优先读锁还是写锁,对方可能会出现饿死问题,那么我们就不偏袒任何一方,搞个「公平读写锁」。
公平读写锁比较简单的一种方式是:用队列把获取锁的线程排队,不管是写线程还是读线程都按照先进先出的原则加锁即可,这样读线程仍然可以并发,也不会出现「饥饿」的现象。
抽象lock类
import ( "context" "errors" "time" "github.com/redis/go-redis/v9" ) var _ context.Context = (*Lock)(nil) // Lock represents a lock with context. type Lock struct { redis redis.Scripter id string ttl time.Duration key string log LogFunc ctx context.Context cancel context.CancelFunc } // ID returns the id value set by the lock. func (l *Lock) ID() string { return l.id } // Key returns the key value set by the lock. func (l *Lock) Key() string { return l.key } func (l *Lock) Deadline() (deadline time.Time, ok bool) { return l.ctx.Deadline() } func (l *Lock) Done() <-chan struct{} { return l.ctx.Done() } func (l *Lock) Err() error { return l.ctx.Err() } func (l *Lock) Value(key any) any { return l.ctx.Value(key) } // Unlock unlocks. func (l *Lock) Unlock() { l.cancel() _, err := scriptUnlock.Run(context.Background(), l.redis, []string{l.key}, l.id).Result() if err != nil { l.log("[ERROR] unlock %q %s: %v", l.key, l.id, err) } } func (l *Lock) refreshTTL(left time.Time) { defer l.cancel() refresh := l.updateTTL() for { diff := time.Since(left) select { case <-l.ctx.Done(): return case <-time.After(-diff): // cant refresh return case <-time.After(refresh): status, err := scriptRefresh.Run(l.ctx, l.redis, []string{l.key}, l.id, l.ttl.Milliseconds()).Int() if err != nil { if errors.Is(err, context.Canceled) { return } refresh = refreshTimeout l.log("[ERROR] refresh key %q %s: %v", l.key, l.id, err) continue } left = l.leftTTL() refresh = l.updateTTL() if status == 0 { l.log("[ERROR] refresh key %q %s already expired", l.key, l.id) return } } } } func (l *Lock) leftTTL() time.Time { return time.Now().Add(l.ttl) } func (l *Lock) updateTTL() time.Duration { return l.ttl / 2 }
ID()
:返回锁的ID。Key()
:返回锁的键名。Deadline()
:返回锁的截止时间和标志,如果没有设置则返回零值。Done()
:返回一个通道,在锁的上下文被取消或者锁过期后会被关闭。Err()
:返回锁的错误状态。Value(key any) any
:返回一个键关联的值,用于传递上下文相关的数据。Unlock()
:解锁操作,会取消锁的上下文,并调用Redis的脚本解锁操作。refreshTTL(left time.Time)
:刷新锁的过期时间,定期更新Redis中锁的过期时间,直到锁的上下文被取消、锁过期或无法继续刷新为止。leftTTL()
:返回锁的剩余过期时间。updateTTL()
:更新刷新锁的间隔时间。每次减少一半
为什么需要为什么l.ttl / 2
这是为了实现锁的自动续约。通过定期刷新锁的过期时间,可以确保锁在使用过程中不会过期而被意外释放。
这种做法可以在以下情况下带来一些好处:
- 减少锁的续约操作对Redis的压力:由于续约操作是相对较昂贵的,通过将过期时间缩短为原来的一半,可以降低续约的频率,从而减少对Redis的请求,减少了网络和计算资源的消耗。
- 避免长时间持有锁带来的问题:如果某个持有锁的进程/线程发生故障或延迟,导致无法及时释放锁,那么其他进程可能会长时间等待获取该锁,造成资源浪费。通过定期刷新锁的过期时间,可以在锁即将过期之前及时释放锁,降低该问题的风险。
Options
package redismutex import ( "context" "log" "os" "sync" "time" ) const ( lenBytesID = 16 refreshTimeout = time.Millisecond * 500 defaultKeyTTL = time.Second * 4 ) var ( globalMx sync.RWMutex globalLog = func() LogFunc { l := log.New(os.Stderr, "redismutex: ", log.LstdFlags) return func(format string, v ...any) { l.Printf(format, v...) } }() ) // LogFunc type is an adapter to allow the use of ordinary functions as LogFunc. type LogFunc func(format string, v ...any) // NopLog logger does nothing var NopLog = LogFunc(func(string, ...any) {}) // SetLog sets the logger. func SetLog(l LogFunc) { globalMx.Lock() defer globalMx.Unlock() if l != nil { globalLog = l } } // MutexOption is the option for the mutex. type MutexOption func(*mutexOptions) type mutexOptions struct { name string ttl time.Duration lockIntent bool log LogFunc } // WithTTL sets the TTL of the mutex. func WithTTL(ttl time.Duration) MutexOption { return func(o *mutexOptions) { if ttl >= time.Second*2 { o.ttl = ttl } } } // WithLockIntent sets the lock intent. func WithLockIntent() MutexOption { return func(o *mutexOptions) { o.lockIntent = true } } // LockOption is the option for the lock. type LockOption func(*lockOptions) type lockOptions struct { ctx context.Context key string lockIntentKey string enableLockIntent int ttl time.Duration log LogFunc } func newLockOptions(m mutexOptions, opt ...LockOption) lockOptions { opts := lockOptions{ ctx: context.Background(), key: m.name, enableLockIntent: boolToInt(m.lockIntent), ttl: m.ttl, log: m.log, } for _, o := range opt { o(&opts) } opts.lockIntentKey = lockIntentKey(opts.key) return opts } // WithKey sets the key of the lock. func WithKey(key string) LockOption { return func(o *lockOptions) { if key != "" { o.key += ":" + key } } } // WithContext sets the context of the lock. func WithContext(ctx context.Context) LockOption { return func(o *lockOptions) { if ctx != nil { o.ctx = ctx } } } func boolToInt(b bool) int { if b { return 1 } return 0 } func lockIntentKey(key string) string { return key + ":lock-intent" }
SetLog(l LogFunc)
:设置日志记录器。WithTTL(ttl time.Duration)
:设置互斥锁的生存时间(TTL)选项。WithLockIntent()
:设置锁意图选项。newLockOptions(m mutexOptions, opt ...LockOption)
:创建锁的选项。WithKey(key string)
:设置锁的键选项。WithContext(ctx context.Context)
:设置锁的上下文选项。lockIntentKey(key string)
:为给定的锁键生成锁意图键。
可以通过设置选项来控制互斥锁的行为和属性,如生存时间、锁意图、上下文等。还提供了一些实用函数和类型,用于管理互斥锁和生成选项
redismutex
// Package redismutex provides a distributed rw mutex. package redismutex import ( "context" "crypto/rand" "embed" "encoding/hex" "errors" "sync" "time" "github.com/redis/go-redis/v9" ) var ErrLock = errors.New("redismutex: lock not obtained") var ( //go:embed lua lua embed.FS scriptRLock *redis.Script scriptLock *redis.Script scriptRefresh *redis.Script scriptUnlock *redis.Script ) func init() { scriptRLock = redis.NewScript(mustReadFile("rlock.lua")) scriptLock = redis.NewScript(mustReadFile("lock.lua")) scriptRefresh = redis.NewScript(mustReadFile("refresh.lua")) scriptUnlock = redis.NewScript(mustReadFile("unlock.lua")) } // A RWMutex is a distributed mutual exclusion lock. type RWMutex struct { redis redis.Scripter opts mutexOptions id struct { sync.Mutex buf []byte } } // NewMutex creates a new distributed mutex. func NewMutex(rc redis.Scripter, name string, opt ...MutexOption) *RWMutex { globalMx.RLock() defer globalMx.RUnlock() opts := mutexOptions{ name: name, ttl: defaultKeyTTL, log: globalLog, } for _, o := range opt { o(&opts) } rw := &RWMutex{ redis: rc, opts: opts, } rw.id.buf = make([]byte, lenBytesID) return rw } // TryRLock tries to lock for reading and reports whether it succeeded. func (m *RWMutex) TryRLock(opt ...LockOption) (*Lock, bool) { opts := newLockOptions(m.opts, opt...) ctx, _, err := m.rlock(opts) if err != nil { if !errors.Is(err, ErrLock) { m.opts.log("[ERROR] try-read-lock key %q: %v", opts.key, err) } return nil, false } return ctx, true } // RLock locks for reading. func (m *RWMutex) RLock(opt ...LockOption) (*Lock, bool) { opts := newLockOptions(m.opts, opt...) ctx, ttl, err := m.rlock(opts) if err == nil { return ctx, true } if !errors.Is(err, ErrLock) { m.opts.log("[ERROR] read-lock key %q: %v", opts.key, err) return nil, false } for { select { case <-opts.ctx.Done(): m.opts.log("[ERROR] read-lock key %q: %v", opts.key, opts.ctx.Err()) return nil, false case <-time.After(ttl): ctx, ttl, err = m.rlock(opts) if err == nil { return ctx, true } if !errors.Is(err, ErrLock) { m.opts.log("[ERROR] read-lock key %q: %v", opts.key, err) return nil, false } continue } } } // TryLock tries to lock for writing and reports whether it succeeded. func (m *RWMutex) TryLock(opt ...LockOption) (*Lock, bool) { opts := newLockOptions(m.opts, opt...) opts.enableLockIntent = 0 // force disable lock intent ctx, _, err := m.lock(opts) if err != nil { if !errors.Is(err, ErrLock) { m.opts.log("[ERROR] try-lock key %q: %v", opts.key, err) } return nil, false } return ctx, true } // Lock locks for writing. func (m *RWMutex) Lock(opt ...LockOption) (*Lock, bool) { opts := newLockOptions(m.opts, opt...) ctx, ttl, err := m.lock(opts) if err == nil { return ctx, true } if !errors.Is(err, ErrLock) { m.opts.log("[ERROR] lock key %q: %v", opts.key, err) return nil, false } for { select { case <-opts.ctx.Done(): m.opts.log("[ERROR] lock key %q: %v", opts.key, opts.ctx.Err()) return nil, false case <-time.After(ttl): ctx, ttl, err = m.lock(opts) if err == nil { return ctx, true } if !errors.Is(err, ErrLock) { m.opts.log("[ERROR] lock key %q: %v", opts.key, err) return nil, false } continue } } } func (m *RWMutex) lock(opts lockOptions) (*Lock, time.Duration, error) { id, err := m.randomID() if err != nil { return nil, 0, err } pTTL, err := scriptLock.Run(opts.ctx, m.redis, []string{opts.key, opts.lockIntentKey}, id, opts.ttl.Milliseconds(), opts.enableLockIntent).Result() leftTTL := time.Now().Add(opts.ttl) if err == nil { return nil, time.Duration(pTTL.(int64)) * time.Millisecond, ErrLock } if err != redis.Nil { return nil, 0, err } ctx, cancel := context.WithCancel(opts.ctx) lock := &Lock{ redis: m.redis, id: id, ttl: opts.ttl, key: opts.key, log: opts.log, ctx: ctx, cancel: cancel, } go lock.refreshTTL(leftTTL) return lock, 0, nil } func (m *RWMutex) rlock(opts lockOptions) (*Lock, time.Duration, error) { id, err := m.randomID() if err != nil { return nil, 0, err } pTTL, err := scriptRLock.Run(opts.ctx, m.redis, []string{opts.key, opts.lockIntentKey}, id, opts.ttl.Milliseconds()).Result() leftTTL := time.Now().Add(opts.ttl) if err == nil { return nil, time.Duration(pTTL.(int64)) * time.Millisecond, ErrLock } if err != redis.Nil { return nil, 0, err } ctx, cancel := context.WithCancel(opts.ctx) lock := &Lock{ redis: m.redis, id: id, ttl: opts.ttl, key: opts.key, log: opts.log, ctx: ctx, cancel: cancel, } go lock.refreshTTL(leftTTL) return lock, 0, nil } // randomID generates a random hex string with 16 bytes. func (m *RWMutex) randomID() (string, error) { m.id.Lock() defer m.id.Unlock() _, err := rand.Read(m.id.buf) if err != nil { return "", err } return hex.EncodeToString(m.id.buf), nil } func mustReadFile(filename string) string { b, err := lua.ReadFile("lua/" + filename) if err != nil { panic(err) } return string(b) }
- 通过
NewMutex
函数创建一个新的分布式互斥锁。该函数接受 Redis 客户端、锁的名称和一系列选项作为参数,返回一个 RWMutex 结构体实例。 - 通过
RLock
和Lock
方法来获取读锁和写锁。如果无法立即获取锁,则会阻塞等待,直到获取成功或者上下文取消。 - 通过
TryRLock
和TryLock
方法来尝试获取读锁和写锁,如果无法立即获取锁则立即返回失败,不会阻塞。 - 该包实现了一个
Lock
结构体,包含了锁相关的信息和操作方法,比如刷新锁的过期时间。 - 使用
redis.Script
来执行 Lua 脚本,通过 Redis 客户端执行相应的 Redis 命令。 - 使用了
crypto/rand
包来生成随机的锁标识符。 - 最终的
mustReadFile
函数用于读取嵌入的 Lua 脚本文件。
测试用例
package redismutex import ( "context" "errors" "log" "strings" "testing" "time" "github.com/redis/go-redis/v9" ) func init() { SetLog(func(format string, a ...any) { if strings.HasPrefix(format, "[ERROR]") { log.Fatalf(format, a...) } }) } func TestMutex(t *testing.T) { t.Parallel() const lockKey = "mutex" rc := redis.NewClient(redisOpts()) prep(t, rc, lockKey) mx := NewMutex(rc, lockKey) lock, ok := mx.Lock() if exp, got := true, ok; exp != got { t.Fatalf("exp %v, got %v", exp, got) } defer lock.Unlock() assertTTL(t, rc, lockKey, defaultKeyTTL) // try again _, ok = mx.TryLock() if exp, got := false, ok; exp != got { t.Fatalf("exp %v, got %v", exp, got) } _, ok = mx.TryRLock() if exp, got := false, ok; exp != got { t.Fatalf("exp %v, got %v", exp, got) } // manually unlock lock.Unlock() // lock again lock, ok = mx.Lock() if exp, got := true, ok; exp != got { t.Fatalf("exp %v, got %v", exp, got) } defer lock.Unlock() } func TestRWMutex(t *testing.T) { t.Parallel() const lockKey = "rw_mutex" rc := redis.NewClient(redisOpts()) prep(t, rc, lockKey) mx := NewMutex(rc, lockKey) lock, ok := mx.RLock() if exp, got := true, ok; exp != got { t.Fatalf("exp %v, got %v", exp, got) } defer lock.Unlock() assertTTL(t, rc, lockKey, defaultKeyTTL) // try again _, ok = mx.TryLock() if exp, got := false, ok; exp != got { t.Fatalf("exp %v, got %v", exp, got) } // try rlock rlock, ok := mx.TryRLock() if exp, got := true, ok; exp != got { t.Fatalf("exp %v, got %v", exp, got) } rlock.Unlock() // manually unlock lock.Unlock() // lock again lock, ok = mx.Lock() if exp, got := true, ok; exp != got { t.Fatalf("exp %v, got %v", exp, got) } defer lock.Unlock() } func TestRWMutex_LockIntent(t *testing.T) { t.Parallel() const lockKey = "lock_intent_mutex" rc := redis.NewClient(redisOpts()) prep(t, rc, lockKey) mx := NewMutex(rc, lockKey, WithLockIntent()) lock, ok := mx.RLock() if exp, got := true, ok; exp != got { t.Fatalf("exp %v, got %v", exp, got) } defer lock.Unlock() // mark lock intent _, _, err := mx.lock(newLockOptions(mx.opts)) if exp, got := ErrLock, err; !errors.Is(got, exp) { t.Fatalf("exp %v, got %v", exp, got) } // try rlock _, ok = mx.TryRLock() if exp, got := false, ok; exp != got { t.Fatalf("exp %v, got %v", exp, got) } // manually unlock lock.Unlock() // lock write lock, ok = mx.Lock() if exp, got := true, ok; exp != got { t.Fatalf("exp %v, got %v", exp, got) } lock.Unlock() // remove lock intent // lock again lock, ok = mx.RLock() if exp, got := true, ok; exp != got { t.Fatalf("exp %v, got %v", exp, got) } defer lock.Unlock() } func TestRWMutex_ID(t *testing.T) { t.Parallel() rw := &RWMutex{} rw.id.buf = make([]byte, lenBytesID) id, _ := rw.randomID() if exp, got := 32, len(id); exp != got { t.Fatalf("exp %v, got %v", exp, got) } } func prep(t *testing.T, rc *redis.Client, key string) { t.Cleanup(func() { for _, v := range []string{key, lockIntentKey(key)} { if err := rc.Del(context.Background(), v).Err(); err != nil { t.Fatal(err) } } if err := rc.Close(); err != nil { t.Fatal(err) } }) } func assertTTL(t *testing.T, rc *redis.Client, key string, exp time.Duration) { t.Helper() got, err := rc.TTL(context.Background(), key).Result() if exp, got := (any)(nil), err; exp != got { t.Fatalf("exp %v, got %v", exp, got) } delta := got - exp if delta < 0 { delta = 1 - delta } if delta > time.Second { t.Fatalf("exp ~%v, got %v", exp, got) } } func redisOpts() *redis.Options { return &redis.Options{ Network: "tcp", Addr: "0.0.0.0:6379", DB: 9, } }