goconsistencehash/consistencehash.go
2025-05-07 10:34:10 +08:00

262 lines
4.6 KiB
Go

package consistencehash
import (
"context"
"errors"
"fmt"
"strings"
"sync"
)
// ConsistenceHash is a consistent-hashing ring with virtual nodes built on a
// pluggable HashRing store. Adding or removing a physical node plans key
// migrations and hands them to the configured Migrator.
type ConsistenceHash struct {
// hashRing stores virtual-node scores (Add/Del/Ceiling/Node), the per-node
// virtual-node counts (AddNodes/Nodes/DelNodes) and the keys recorded on each
// virtual node (Datas/AddDatas). Its Lock/UnLock pair guards every public
// operation of this type.
hashRing HashRing
// encrytor hashes virtual-node ids and keys to positions (scores) on the ring.
encrytor Encrytor
// migrator is invoked as migrator(ctx, datas, from, to) after a topology
// change; presumably it moves the listed keys from one node to another.
migrator Migrator
// opts holds tuning options; replicas and lockExpire are read in this file.
opts Options
}
// NewConsistenceHash builds a ConsistenceHash backed by hashRing. encrytor
// places virtual-node ids and keys on the ring, and migrator is called to move
// data when nodes join or leave. Each Option is applied in order to a
// zero-valued Options, and then repire is called on the result (presumably
// to fill in defaults or fix invalid values; see Options).
func NewConsistenceHash(hashRing HashRing, encrytor Encrytor, migrator Migrator, opts ...Option) *ConsistenceHash {
	c := new(ConsistenceHash)
	c.hashRing, c.encrytor, c.migrator = hashRing, encrytor, migrator
	for i := range opts {
		opts[i](&c.opts)
	}
	c.opts.repire()
	return c
}
// getVirNodeId returns the id of the index-th virtual node of physical node
// id, in the form "<id>_<index>" with index written in decimal.
func (c *ConsistenceHash) getVirNodeId(id string, index int) string {
	return id + "_" + fmt.Sprint(index)
}
// getRawNodeId maps a virtual-node id of the form "<id>_<index>" back to its
// physical node id by stripping everything from the last "_" onward. Only the
// last "_" counts as the separator, so physical ids may contain underscores.
// An id with no "_" at all is returned unchanged. Slicing with the -1 index
// from LastIndex would panic instead.
func (c *ConsistenceHash) getRawNodeId(id string) string {
	if i := strings.LastIndexByte(id, '_'); i >= 0 {
		return id[:i]
	}
	return id
}
// migratorHandle runs every function in funs concurrently, one goroutine per
// function, and blocks until all of them have returned. A panic in one
// function is recovered and reported on stdout together with its value, so
// one failing migration neither crashes the process nor stops the others.
func (c *ConsistenceHash) migratorHandle(funs []func()) {
	var wg sync.WaitGroup
	wg.Add(len(funs))
	for _, fn := range funs {
		go func(fn func()) {
			// Deferred calls run LIFO: the recover below fires first, then Done.
			defer wg.Done()
			defer func() {
				if r := recover(); r != nil {
					// Include the panic value; without it the failure cannot be diagnosed.
					fmt.Println("recover:", r)
				}
			}()
			fn()
		}(fn)
	}
	wg.Wait()
}
// wight clamps a caller-supplied node weight into the inclusive range [1, 10].
// AddNode multiplies the clamped value by opts.replicas to get the number of
// virtual nodes.
func (c *ConsistenceHash) wight(wight int) int {
	switch {
	case wight < 1:
		return 1
	case wight > 10:
		return 10
	default:
		return wight
	}
}
// AddNode inserts physical node nodeId into the ring with the given weight.
//
// The weight is clamped by wight and multiplied by opts.replicas to get the
// number of virtual nodes. The node record is created with that count. Then,
// for each virtual node "<nodeId>_<i>", its hash is added to the ring and
// migratorIn is consulted for keys that should move onto it. The planned
// migrations are handed to migratorHandle once every virtual node is placed.
// Everything, including the migrations, runs while the ring lock is held.
//
// It returns an error if the lock cannot be taken, if nodeId is already
// present, or if any ring or migratorIn call fails. Errors returned by the
// individual migrator calls are ignored.
func (c *ConsistenceHash) AddNode(ctx context.Context, nodeId string, wight int) error {
	if err := c.hashRing.Lock(ctx, c.opts.lockExpire); err != nil {
		return err
	}
	defer func() { _ = c.hashRing.UnLock(ctx) }()

	nodes, err := c.hashRing.Nodes(ctx)
	if err != nil {
		return err
	}
	if _, exists := nodes[nodeId]; exists {
		return errors.New("node already exists")
	}

	total := c.wight(wight) * c.opts.replicas
	if err = c.hashRing.AddNodes(ctx, nodeId, total); err != nil {
		return err
	}

	var pending []func()
	for i := 0; i < total; i++ {
		vid := c.getVirNodeId(nodeId, i)
		score := c.encrytor.Hash(vid)
		if err := c.hashRing.Add(ctx, score, vid); err != nil {
			return err
		}
		keys, src, dst, err := c.migratorIn(ctx, score, vid)
		if err != nil {
			return err
		}
		// A migration is only worth scheduling when there are keys to move
		// and a source node to take them from.
		if len(keys) > 0 && src != "" {
			pending = append(pending, func() {
				_ = c.migrator(ctx, keys, src, dst)
			})
		}
	}
	c.migratorHandle(pending)
	return nil
}
// getAllDatas returns the union of the keys recorded on the first counts
// virtual nodes of physical node nodeId, i.e. "<nodeId>_0" through
// "<nodeId>_<counts-1>".
//
// On success the returned set is never nil, even when counts <= 0. If
// hashRing.Datas fails, the result is nil together with that error; a
// partially filled set is never returned.
func (c *ConsistenceHash) getAllDatas(ctx context.Context, nodeId string, counts int) (map[string]struct{}, error) {
	allDatas := make(map[string]struct{})
	for i := 0; i < counts; i++ {
		vNodeDatas, err := c.hashRing.Datas(ctx, c.getVirNodeId(nodeId, i))
		if err != nil {
			return nil, err
		}
		for key := range vNodeDatas {
			allDatas[key] = struct{}{}
		}
	}
	return allDatas, nil
}
// DelNode removes physical node nodeId and all of its virtual nodes from the
// ring, planning migrations for the keys recorded on it.
//
// While holding the ring lock it reads the node's virtual-node count and
// collects every key recorded on those virtual nodes. For each virtual node it
// then asks migratorOut where that node's keys should go and deletes the
// virtual node from the ring. If keys remain in the collected set afterwards,
// an "internal error" is returned before the node record is deleted.
// Otherwise the node record is deleted and the planned migrations are handed
// to migratorHandle. Errors returned by the individual migrator calls are
// ignored.
//
// A node that currently holds no keys is still fully removed from the ring.
func (c *ConsistenceHash) DelNode(ctx context.Context, nodeId string) error {
	if err := c.hashRing.Lock(ctx, c.opts.lockExpire); err != nil {
		return err
	}
	defer func() {
		_ = c.hashRing.UnLock(ctx)
	}()

	nodes, err := c.hashRing.Nodes(ctx)
	if err != nil {
		return err
	}
	counts, exists := nodes[nodeId]
	if !exists {
		return errors.New("node id is not exists")
	}

	// All keys recorded on the node's virtual nodes. migratorOut presumably
	// deletes the keys it schedules from this set, so whatever is left after
	// the loop never got a new owner.
	remaining, err := c.getAllDatas(ctx, nodeId, counts)
	if err != nil {
		return err
	}
	// Deliberately no early return when remaining is empty. An empty node must
	// still have its virtual nodes and node record removed. Returning here
	// would report success while leaving the node on the ring. migratorOut
	// already copes with an empty set: it sees one for later virtual nodes
	// whenever earlier ones have claimed every key.

	var migratorFns []func()
	for i := 0; i < counts; i++ {
		vNodeId := c.getVirNodeId(nodeId, i)
		sum := c.encrytor.Hash(vNodeId)
		datas, from, to, err := c.migratorOut(ctx, sum, vNodeId, &remaining)
		if err != nil {
			return err
		}
		if err := c.hashRing.Del(ctx, sum, vNodeId); err != nil {
			return err
		}
		if len(datas) == 0 || to == "" {
			continue
		}
		migratorFns = append(migratorFns, func() {
			_ = c.migrator(ctx, datas, from, to)
		})
	}
	if len(remaining) != 0 {
		// Report how many keys were stranded instead of dumping every key to stdout.
		return fmt.Errorf("internal error: %d key(s) of node %q have no migration target", len(remaining), nodeId)
	}
	if err := c.hashRing.DelNodes(ctx, nodeId); err != nil {
		return err
	}
	c.migratorHandle(migratorFns)
	return nil
}
// GetNode returns the physical node responsible for key, and records key on
// the chosen virtual node via AddDatas.
//
// While holding the ring lock, key is hashed and hashRing.Ceiling picks a
// ring score (presumably the first virtual-node score at or after the hash;
// see HashRing). A score of -1 means no node is available. The first
// virtual-node id stored at that score receives the key, and its "_<index>"
// suffix is stripped to obtain the physical node id.
func (c *ConsistenceHash) GetNode(ctx context.Context, key string) (string, error) {
	if err := c.hashRing.Lock(ctx, c.opts.lockExpire); err != nil {
		return "", err
	}
	defer func() { _ = c.hashRing.UnLock(ctx) }()

	score, err := c.hashRing.Ceiling(ctx, c.encrytor.Hash(key))
	switch {
	case err != nil:
		return "", err
	case score == -1:
		return "", errors.New("no avalible node")
	}

	owners, err := c.hashRing.Node(ctx, score)
	switch {
	case err != nil:
		return "", err
	case len(owners) == 0:
		return "", errors.New("node id error")
	}

	owner := owners[0]
	if err := c.hashRing.AddDatas(ctx, owner, map[string]struct{}{key: {}}); err != nil {
		return "", err
	}
	return c.getRawNodeId(owner), nil
}