From 2c33db0b7829587f957a61222baae7e1f772192b Mon Sep 17 00:00:00 2001 From: Sergey Chubaryan Date: Wed, 28 Aug 2024 20:47:19 +0300 Subject: [PATCH] rework cache, improve setting ttl --- src/app.go | 16 +++-- src/cache/cache_inmem.go | 99 ++++++++++++++++++++++++++ src/cache/cache_inmem_shard.go | 58 +++++++++++++++ src/cache/interface.go | 25 +++++++ src/cache/sharding_info.go | 47 ++++++++++++ src/charsets/charsets.go | 36 +++++----- src/core/repos/cache_inmem.go | 99 -------------------------- src/core/repos/cache_inmem_shard.go | 88 ----------------------- src/core/services/shortlink_service.go | 8 +-- src/core/services/user_service.go | 20 ++++-- src/core/utils/jwt.go | 8 ++- 11 files changed, 280 insertions(+), 224 deletions(-) create mode 100644 src/cache/cache_inmem.go create mode 100644 src/cache/cache_inmem_shard.go create mode 100644 src/cache/interface.go create mode 100644 src/cache/sharding_info.go delete mode 100644 src/core/repos/cache_inmem.go delete mode 100644 src/core/repos/cache_inmem_shard.go diff --git a/src/app.go b/src/app.go index 269755b..6c0e4f2 100644 --- a/src/app.go +++ b/src/app.go @@ -2,6 +2,7 @@ package src import ( "backend/src/args_parser" + "backend/src/cache" "backend/src/client_notifier" "backend/src/config" "backend/src/core/models" @@ -114,9 +115,10 @@ func (a *App) Run(p RunParams) { userRepo = repos.NewUserRepo(sqlDb) emailRepo = repos.NewEmailRepo() actionTokenRepo = repos.NewActionTokenRepo(sqlDb) - userCache = repos.NewCacheInmemSharded[models.UserDTO](60*60, repos.ShardingTypeInteger) - jwtCache = repos.NewCacheInmemSharded[string](60, repos.ShardingTypeJWT) - linksCache = repos.NewCacheInmem[string, string](7 * 24 * 60 * 60) + + userCache = cache.NewCacheInmemSharded[models.UserDTO](cache.ShardingTypeInteger) + jwtCache = cache.NewCacheInmemSharded[string](cache.ShardingTypeJWT) + linksCache = cache.NewCacheInmem[string, string]() ) // Periodically trigger cache cleanup @@ -124,14 +126,16 @@ func (a *App) Run(p RunParams) { tmr := time.NewTicker(5 * time.Minute) defer tmr.Stop() + batchSize := 100 + for { select { case <-ctx.Done(): return case <-tmr.C: - userCache.CheckExpired() - jwtCache.CheckExpired() - linksCache.CheckExpired() + userCache.CheckExpired(batchSize) + jwtCache.CheckExpired(batchSize) + linksCache.CheckExpired(batchSize) } } }() diff --git a/src/cache/cache_inmem.go b/src/cache/cache_inmem.go new file mode 100644 index 0000000..ea39127 --- /dev/null +++ b/src/cache/cache_inmem.go @@ -0,0 +1,99 @@ +package cache + +import ( + "sync" + "time" +) + +func NewCacheInmem[K comparable, V any]() Cache[K, V] { + return &cacheInmem[K, V]{ + m: &sync.RWMutex{}, + data: map[K]*cacheInmemItem[V]{}, + } +} + +type cacheInmemItem[T any] struct { + Value T + Expiration time.Time +} + +type cacheInmem[K comparable, V any] struct { + m *sync.RWMutex + data map[K]*cacheInmemItem[V] +} + +func (c *cacheInmem[K, V]) Get(key K) (V, bool) { + c.m.RLock() + defer c.m.RUnlock() + + var v V + + item, ok := c.data[key] + if !ok { + return v, false + } + if time.Now().Before(item.Expiration) { + return item.Value, true + } + + return v, false +} + +func (c *cacheInmem[K, V]) GetEx(key K, exp Expiration) (V, bool) { + c.m.Lock() + defer c.m.Unlock() + + item, ok := c.data[key] + if !ok { + var v V + return v, false + } + + if time.Now().Before(item.Expiration) { + c.data[key].Expiration = exp.Get() + return item.Value, true + } + + delete(c.data, key) + + var v V + return v, false +} + +func (c *cacheInmem[K, V]) Set(key K, value V, exp Expiration) { + c.m.Lock() + defer c.m.Unlock() + + item := &cacheInmemItem[V]{ + Value: value, + Expiration: exp.Get(), + } + c.data[key] = item +} + +func (c *cacheInmem[K, V]) Del(key K) { + c.m.Lock() + defer c.m.Unlock() + + delete(c.data, key) +} + +func (c *cacheInmem[K, V]) CheckExpired(batchSize int) { + if len(c.data) == 0 { + return + } + + c.m.Lock() + defer c.m.Unlock() + + for key, item := range c.data { + if time.Now().After(item.Expiration) { + delete(c.data, key) + } + + batchSize-- + if batchSize <= 0 { + return + } + } +} diff --git a/src/cache/cache_inmem_shard.go b/src/cache/cache_inmem_shard.go new file mode 100644 index 0000000..72b8f36 --- /dev/null +++ b/src/cache/cache_inmem_shard.go @@ -0,0 +1,58 @@ +package cache + +import ( + "sync" +) + +func NewCacheInmemSharded[V any](shardingType ShardingType) Cache[string, V] { + info := getShardingInfo(shardingType) + + shards := []*cacheInmem[string, V]{} + for i := 0; i < info.Shards; i++ { + shards = append( + shards, + &cacheInmem[string, V]{ + m: &sync.RWMutex{}, + data: map[string]*cacheInmemItem[V]{}, + }, + ) + } + + return &cacheInmemSharded[V]{ + info: info, + shards: shards, + } +} + +type cacheInmemSharded[V any] struct { + info ShardingInfo + shards []*cacheInmem[string, V] +} + +func (c *cacheInmemSharded[V]) Get(key string) (V, bool) { + return c.getShard(key).Get(key) +} + +func (c *cacheInmemSharded[V]) GetEx(key string, exp Expiration) (V, bool) { + return c.getShard(key).GetEx(key, exp) +} + +func (c *cacheInmemSharded[V]) Set(key string, value V, exp Expiration) { + c.getShard(key).Set(key, value, exp) +} + +func (c *cacheInmemSharded[V]) Del(key string) { + c.getShard(key).Del(key) +} + +func (c *cacheInmemSharded[V]) CheckExpired(batchSize int) { + size := batchSize / c.info.Shards + for _, shard := range c.shards { + shard.CheckExpired(size) + } +} + +func (c *cacheInmemSharded[V]) getShard(key string) *cacheInmem[string, V] { + index := c.info.HashFunc(key) + return c.shards[index] +} diff --git a/src/cache/interface.go b/src/cache/interface.go new file mode 100644 index 0000000..75841a3 --- /dev/null +++ b/src/cache/interface.go @@ -0,0 +1,25 @@ +package cache + +import "time" + +type Expiration struct { + Ttl time.Duration + ExpiresAt time.Time +} + +func (e Expiration) Get() time.Time { + if e.Ttl != 0 { + return time.Now().Add(e.Ttl) + } + return e.ExpiresAt +} + +type Cache[K comparable, V any] interface { + Get(key K) (V, bool) + GetEx(key K, exp Expiration) (V, bool) + + Set(key K, value V, exp Expiration) + + Del(key K) + CheckExpired(batchSize int) +} diff --git a/src/cache/sharding_info.go b/src/cache/sharding_info.go new file mode 100644 index 0000000..73cc701 --- /dev/null +++ b/src/cache/sharding_info.go @@ -0,0 +1,47 @@ +package cache + +type ShardingType int + +const ( + ShardingTypeJWT ShardingType = iota + ShardingTypeInteger +) + +type ShardingInfo struct { + Shards int + HashFunc func(key string) int +} + +func getShardingInfo(shardingType ShardingType) ShardingInfo { + switch shardingType { + case ShardingTypeInteger: + return ShardingInfo{ + Shards: 10, + HashFunc: func(key string) int { + char := int(key[len(key)-1]) + return char - 0x30 + }, + } + case ShardingTypeJWT: + return ShardingInfo{ + Shards: 36, + HashFunc: func(key string) int { + char := int(key[len(key)-1]) + if char >= 0x30 && char <= 0x39 { + return char - 0x30 + } + if char >= 0x41 && char <= 0x5A { + return char - 0x41 + } + return char - 0x61 + }, + } + } + + return ShardingInfo{ + Shards: 1, + HashFunc: func(key string) int { + return 0 + }, + } +} diff --git a/src/charsets/charsets.go b/src/charsets/charsets.go index 5d70492..e2ad43f 100644 --- a/src/charsets/charsets.go +++ b/src/charsets/charsets.go @@ -10,8 +10,6 @@ type Charset interface { TestRune(char rune) bool RandomRune(r RandInt) rune RandomString(r RandInt, size int) string - - String() string } func NewCharsetFromASCII(offset, size int) Charset { @@ -40,13 +38,13 @@ func (c charsetASCII) RandomString(r RandInt, size int) string { return builder.String() } -func (c charsetASCII) String() string { - builder := strings.Builder{} - for i := 0; i < c.size; i++ { - builder.WriteRune(rune(c.offset + i)) - } - return builder.String() -} +// func (c charsetASCII) String() string { +// builder := strings.Builder{} +// for i := 0; i < c.size; i++ { +// builder.WriteRune(rune(c.offset + i)) +// } +// return builder.String() +// } func NewCharsetFromString(s string) Charset { charsArray := make([]rune, len(s)) @@ -84,13 +82,13 @@ func (c charsetFromString) RandomString(r RandInt, size int) string { return builder.String() } -func (c charsetFromString) String() string { - builder := strings.Builder{} - for _, v := range c.charsArray { - builder.WriteRune(v) - } - return builder.String() -} +// func (c charsetFromString) String() string { +// builder := strings.Builder{} +// for _, v := range c.charsArray { +// builder.WriteRune(v) +// } +// return builder.String() +// } func NewCharsetUnion(opts ...Charset) Charset { charsets := []Charset{} @@ -130,6 +128,6 @@ func (c charsetUnion) RandomString(r RandInt, size int) string { return builder.String() } -func (c charsetUnion) String() string { - return "" -} +// func (c charsetUnion) String() string { +// return "" +// } diff --git a/src/core/repos/cache_inmem.go b/src/core/repos/cache_inmem.go deleted file mode 100644 index 6503c3e..0000000 --- a/src/core/repos/cache_inmem.go +++ /dev/null @@ -1,99 +0,0 @@ -package repos - -import ( - "sync" - "time" -) - -type Cache[K comparable, V any] interface { - Get(key K) (V, bool) - Set(key K, value V, ttlSeconds int) - Del(key K) - CheckExpired() -} - -func NewCacheInmem[K comparable, V any](ttlSeconds int) Cache[K, V] { - return &cacheInmem[K, V]{ - m: &sync.Mutex{}, - data: map[K]*cacheInmemItem[V]{}, - ttlSeconds: ttlSeconds, - } -} - -type cacheInmemItem[T any] struct { - Value T - Ttl int64 - Expiration int64 -} - -type cacheInmem[K comparable, V any] struct { - m *sync.Mutex - data map[K]*cacheInmemItem[V] - ttlSeconds int -} - -func (c *cacheInmem[K, V]) Get(key K) (V, bool) { - c.m.Lock() - defer c.m.Unlock() - - item, ok := c.data[key] - if !ok { - var v V - return v, false - } - - timestamp := time.Now().Unix() - if item.Expiration > timestamp { - item.Expiration = timestamp + item.Ttl - return item.Value, true - } - - delete(c.data, key) - - var v V - return v, false -} - -func (c *cacheInmem[K, V]) Set(key K, value V, ttlSeconds int) { - c.m.Lock() - defer c.m.Unlock() - - ttl := int64(c.ttlSeconds) - - expiration := time.Now().Unix() + ttl - item := &cacheInmemItem[V]{ - Value: value, - Ttl: ttl, - Expiration: expiration, - } - c.data[key] = item -} - -func (c *cacheInmem[K, V]) Del(key K) { - c.m.Lock() - defer c.m.Unlock() - - delete(c.data, key) -} - -func (c *cacheInmem[K, V]) CheckExpired() { - if len(c.data) == 0 { - return - } - - c.m.Lock() - defer c.m.Unlock() - - itemsToProcess := 1000 - for key, item := range c.data { - timestamp := time.Now().Unix() - if item.Expiration <= timestamp { - delete(c.data, key) - } - - itemsToProcess-- - if itemsToProcess <= 0 { - return - } - } -} diff --git a/src/core/repos/cache_inmem_shard.go b/src/core/repos/cache_inmem_shard.go deleted file mode 100644 index 3d17ea2..0000000 --- a/src/core/repos/cache_inmem_shard.go +++ /dev/null @@ -1,88 +0,0 @@ -package repos - -import ( - "sync" -) - -type ShardingType int - -const ( - ShardingTypeJWT ShardingType = iota - ShardingTypeInteger -) - -type shardingHashFunc func(key string) int - -func getShardingInfo(shardingType ShardingType) (int, shardingHashFunc) { - switch shardingType { - case ShardingTypeInteger: - return 10, func(key string) int { - char := int(key[len(key)-1]) - return char - 0x30 - } - case ShardingTypeJWT: - return 36, func(key string) int { - char := int(key[len(key)-1]) - if char >= 0x30 && char <= 0x39 { - return char - 0x30 - } - if char >= 0x41 && char <= 0x5A { - return char - 0x41 - } - return char - 0x61 - } - } - - return 1, func(key string) int { - return 0 - } -} - -func NewCacheInmemSharded[V any](defaultTtlSeconds int, shardingType ShardingType) Cache[string, V] { - shards, hashFunc := getShardingInfo(shardingType) - - inmems := []*cacheInmem[string, V]{} - for i := 0; i < shards; i++ { - inmems = append( - inmems, - &cacheInmem[string, V]{ - m: &sync.Mutex{}, - data: map[string]*cacheInmemItem[V]{}, - ttlSeconds: defaultTtlSeconds, - }, - ) - } - - return &cacheInmemSharded[V]{ - shards: inmems, - hashFunc: hashFunc, - } -} - -type cacheInmemSharded[V any] struct { - hashFunc shardingHashFunc - shards []*cacheInmem[string, V] -} - -func (c *cacheInmemSharded[V]) Get(key string) (V, bool) { - return c.getShard(key).Get(key) -} - -func (c *cacheInmemSharded[V]) Set(key string, value V, ttlSeconds int) { - c.getShard(key).Set(key, value, ttlSeconds) -} - -func (c *cacheInmemSharded[V]) Del(key string) { - c.getShard(key).Del(key) -} - -func (c *cacheInmemSharded[V]) CheckExpired() { - for _, shard := range c.shards { - shard.CheckExpired() - } -} - -func (c *cacheInmemSharded[V]) getShard(key string) *cacheInmem[string, V] { - index := c.hashFunc(key) - return c.shards[index] -} diff --git a/src/core/services/shortlink_service.go b/src/core/services/shortlink_service.go index 866d27a..ad487be 100644 --- a/src/core/services/shortlink_service.go +++ b/src/core/services/shortlink_service.go @@ -1,8 +1,8 @@ package services import ( + "backend/src/cache" "backend/src/charsets" - "backend/src/core/repos" "fmt" "math/rand" "time" @@ -15,7 +15,7 @@ type ShortlinkService interface { type NewShortlinkServiceParams struct { Endpoint string - Cache repos.Cache[string, string] + Cache cache.Cache[string, string] } func NewShortlinkSevice(params NewShortlinkServiceParams) ShortlinkService { @@ -25,7 +25,7 @@ func NewShortlinkSevice(params NewShortlinkServiceParams) ShortlinkService { } type shortlinkService struct { - cache repos.Cache[string, string] + cache cache.Cache[string, string] } func (s *shortlinkService) CreateLink(in string) (string, error) { @@ -35,7 +35,7 @@ func (s *shortlinkService) CreateLink(in string) (string, error) { randGen := rand.New(src) str := charset.RandomString(randGen, 10) - s.cache.Set(str, in, 7*24*60*60) + s.cache.Set(str, in, cache.Expiration{Ttl: 7 * 24 * time.Hour}) return str, nil } diff --git a/src/core/services/user_service.go b/src/core/services/user_service.go index d841959..e909a57 100644 --- a/src/core/services/user_service.go +++ b/src/core/services/user_service.go @@ -1,11 +1,13 @@ package services import ( + "backend/src/cache" "backend/src/core/models" "backend/src/core/repos" "backend/src/core/utils" "context" "fmt" + "time" "github.com/google/uuid" ) @@ -19,6 +21,10 @@ var ( // ErrUserInternal = fmt.Errorf("unexpected error. contact tech support") ) +const ( + userCacheTtl = time.Hour +) + type UserService interface { CreateUser(ctx context.Context, params UserCreateParams) (*models.UserDTO, error) AuthenticateUser(ctx context.Context, login, password string) (string, error) @@ -33,8 +39,8 @@ type UserServiceDeps struct { Jwt utils.JwtUtil Password utils.PasswordUtil UserRepo repos.UserRepo - UserCache repos.Cache[string, models.UserDTO] - JwtCache repos.Cache[string, string] + UserCache cache.Cache[string, models.UserDTO] + JwtCache cache.Cache[string, string] EmailRepo repos.EmailRepo ActionTokenRepo repos.ActionTokenRepo } @@ -78,7 +84,7 @@ func (u *userService) CreateUser(ctx context.Context, params UserCreateParams) ( return nil, err } - u.deps.UserCache.Set(result.Id, *result, -1) + u.deps.UserCache.Set(result.Id, *result, cache.Expiration{Ttl: userCacheTtl}) return result, nil } @@ -102,7 +108,7 @@ func (u *userService) AuthenticateUser(ctx context.Context, email, password stri return "", err } - u.deps.UserCache.Set(user.Id, *user, -1) + u.deps.UserCache.Set(user.Id, *user, cache.Expiration{Ttl: userCacheTtl}) return jwt, nil } @@ -178,7 +184,7 @@ func (u *userService) updatePassword(ctx context.Context, user models.UserDTO, n } func (u *userService) getUserById(ctx context.Context, userId string) (*models.UserDTO, error) { - if user, ok := u.deps.UserCache.Get(userId); ok { + if user, ok := u.deps.UserCache.GetEx(userId, cache.Expiration{Ttl: userCacheTtl}); ok { return &user, nil } @@ -190,7 +196,7 @@ func (u *userService) getUserById(ctx context.Context, userId string) (*models.U return nil, ErrUserNotExists } - u.deps.UserCache.Set(user.Id, *user, -1) + u.deps.UserCache.Set(user.Id, *user, cache.Expiration{Ttl: userCacheTtl}) return user, nil } @@ -210,7 +216,7 @@ func (u *userService) ValidateToken(ctx context.Context, tokenStr string) (*mode return nil, err } - u.deps.JwtCache.Set(tokenStr, payload.UserId, -1) + u.deps.JwtCache.Set(tokenStr, payload.UserId, cache.Expiration{ExpiresAt: payload.ExpiresAt.Time}) return user, nil } diff --git a/src/core/utils/jwt.go b/src/core/utils/jwt.go index 89d9bed..431bc5c 100644 --- a/src/core/utils/jwt.go +++ b/src/core/utils/jwt.go @@ -3,6 +3,7 @@ package utils import ( "crypto/rsa" "fmt" + "time" "github.com/golang-jwt/jwt/v5" ) @@ -32,7 +33,12 @@ type jwtUtil struct { } func (j *jwtUtil) Create(payload JwtPayload) (string, error) { - claims := &JwtClaims{JwtPayload: payload} + claims := &JwtClaims{ + JwtPayload: payload, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + }, + } token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) tokenStr, err := token.SignedString(j.privateKey) if err != nil {