package consistencehash import ( "context" "errors" "fmt" "strings" "sync" ) type ConsistenceHash struct { hashRing HashRing encrytor Encrytor migrator Migrator opts Options } func NewConsistenceHash(hashRing HashRing, encrytor Encrytor, migrator Migrator, opts ...Option) *ConsistenceHash { chash := &ConsistenceHash{ hashRing: hashRing, encrytor: encrytor, migrator: migrator, } for _, o := range opts { o(&chash.opts) } chash.opts.repire() return chash } func (c *ConsistenceHash) getVirNodeId(id string, index int) string { return fmt.Sprintf("%s_%d", id, index) } func (c *ConsistenceHash) getRawNodeId(id string) string { index := strings.LastIndex(id, "_") return id[:index] } func (c *ConsistenceHash) migratorHandle(funs []func()) { wg := sync.WaitGroup{} for _, fn := range funs { wg.Add(1) go func(fn func()) { defer func() { if err := recover(); err != nil { fmt.Println("recover") } wg.Done() }() fn() }(fn) } wg.Wait() } func (c *ConsistenceHash) wight(wight int) int { if wight < 1 { return 1 } if wight > 10 { return 10 } return wight } func (c *ConsistenceHash) AddNode(ctx context.Context, nodeId string, wight int) error { if err := c.hashRing.Lock(ctx, c.opts.lockExpire); err != nil { return err } defer func() { _ = c.hashRing.UnLock(ctx) }() nodes, err := c.hashRing.Nodes(ctx) if err != nil { return err } for node := range nodes { if node == nodeId { return errors.New("node already exists") } } count := c.wight(wight) * c.opts.replicas if err := c.hashRing.AddNodes(ctx, nodeId, count); err != nil { return err } var migratorFns []func() for i := 0; i < count; i++ { vNodeId := c.getVirNodeId(nodeId, i) sum := c.encrytor.Hash(vNodeId) if err := c.hashRing.Add(ctx, sum, vNodeId); err != nil { return err } datas, from, to, err := c.migratorIn(ctx, sum, vNodeId) if err != nil { return err } if len(datas) == 0 || from == "" { continue } migratorFns = append(migratorFns, func() { _ = c.migrator(ctx, datas, from, to) }) } c.migratorHandle(migratorFns) return nil } func (c *ConsistenceHash) getAllDatas(ctx context.Context, nodeId string, counts int) (map[string]struct{}, error) { var allDatas map[string]struct{} for i := 0; i < counts; i++ { vNodeId := c.getVirNodeId(nodeId, i) _datas, err := c.hashRing.Datas(ctx, vNodeId) if err != nil { return allDatas, err } if allDatas == nil { allDatas = make(map[string]struct{}) } for key := range _datas { allDatas[key] = struct{}{} } } return allDatas, nil } func (c *ConsistenceHash) DelNode(ctx context.Context, nodeId string) error { if err := c.hashRing.Lock(ctx, c.opts.lockExpire); err != nil { return err } defer func() { _ = c.hashRing.UnLock(ctx) }() nodes, err := c.hashRing.Nodes(ctx) if err != nil { return err } var ( exists bool counts int ) for node, nums := range nodes { if node == nodeId { exists = true counts = nums break } } if !exists { return errors.New("node id is not exists") } // _datas, err := c.hashRing.Datas(ctx, nodeId) _datas, err := c.getAllDatas(ctx, nodeId, counts) if err != nil { return err } if len(_datas) == 0 { return nil } var migratorFns []func() for i := 0; i < counts; i++ { vNodeId := c.getVirNodeId(nodeId, i) sum := c.encrytor.Hash(vNodeId) datas, from, to, err := c.migratorOut(ctx, sum, vNodeId, &_datas) if err != nil { return err } if err := c.hashRing.Del(ctx, sum, vNodeId); err != nil { return err } if len(datas) == 0 || to == "" { continue } migratorFns = append(migratorFns, func() { _ = c.migrator(ctx, datas, from, to) }) } if len(_datas) != 0 { fmt.Println(_datas) return errors.New("internal error") } if err := c.hashRing.DelNodes(ctx, nodeId); err != nil { return err } c.migratorHandle(migratorFns) return nil } func (c *ConsistenceHash) GetNode(ctx context.Context, key string) (string, error) { if err := c.hashRing.Lock(ctx, c.opts.lockExpire); err != nil { return "", err } defer func() { _ = c.hashRing.UnLock(ctx) }() keysum := c.encrytor.Hash(key) nextScore, err := c.hashRing.Ceiling(ctx, keysum) if err != nil { return "", err } if nextScore == -1 { return "", errors.New("no avalible node") } nextNode, err := c.hashRing.Node(ctx, nextScore) if err != nil { return "", err } if len(nextNode) == 0 { return "", errors.New("node id error") } if err := c.hashRing.AddDatas(ctx, nextNode[0], map[string]struct{}{ key: {}, }); err != nil { return "", err } return c.getRawNodeId(nextNode[0]), nil }