cache sharding, valid jwt caching, add metrics

This commit is contained in:
Sergey Chubaryan 2024-08-28 00:58:19 +03:00
parent 0fdc2400ae
commit ed1f8b8c3f
8 changed files with 137 additions and 34 deletions

View File

@ -114,8 +114,21 @@ func (a *App) Run(p RunParams) {
userRepo = repos.NewUserRepo(sqlDb) userRepo = repos.NewUserRepo(sqlDb)
emailRepo = repos.NewEmailRepo() emailRepo = repos.NewEmailRepo()
actionTokenRepo = repos.NewActionTokenRepo(sqlDb) actionTokenRepo = repos.NewActionTokenRepo(sqlDb)
linksCache = repos.NewCacheInmem[string, string](7 * 24 * 60 * 60) jwtCache = repos.NewCacheInmemSharded[string, string](60, 36, func(key string) int {
userCache = repos.NewCacheInmem[string, models.UserDTO](60 * 60) char := int(key[len(key)-1])
if char >= 0x30 && char <= 0x39 {
return char - 0x30
}
if char >= 0x41 && char <= 0x5A {
return char - 0x41
}
return char - 0x61
}) //repos.NewCacheInmem[string, string](60)
linksCache = repos.NewCacheInmem[string, string](7 * 24 * 60 * 60)
userCache = repos.NewCacheInmemSharded[string, models.UserDTO](60*60, 10, func(key string) int {
char := int(key[len(key)-1])
return char - 0x30
}) //repos.NewCacheInmem[string, models.UserDTO](60 * 60)
) )
// Periodically trigger cache cleanup // Periodically trigger cache cleanup
@ -128,6 +141,7 @@ func (a *App) Run(p RunParams) {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-tmr.C: case <-tmr.C:
jwtCache.CheckExpired()
userCache.CheckExpired() userCache.CheckExpired()
linksCache.CheckExpired() linksCache.CheckExpired()
} }
@ -140,6 +154,7 @@ func (a *App) Run(p RunParams) {
Password: passwordUtil, Password: passwordUtil,
UserRepo: userRepo, UserRepo: userRepo,
UserCache: userCache, UserCache: userCache,
JwtCache: jwtCache,
EmailRepo: emailRepo, EmailRepo: emailRepo,
ActionTokenRepo: actionTokenRepo, ActionTokenRepo: actionTokenRepo,
}, },

View File

@ -0,0 +1,55 @@
package repos
import (
"sync"
)
func NewCacheInmemSharded[K comparable, V any](
ttlSeconds, shards int,
hashFunc func(key K) int,
) Cache[K, V] {
inmems := []*cacheInmem[K, V]{}
for i := 0; i < shards; i++ {
inmems = append(
inmems,
&cacheInmem[K, V]{
m: &sync.Mutex{},
data: map[K]*cacheInmemItem[V]{},
ttlSeconds: ttlSeconds,
},
)
}
return &cacheInmemSharded[K, V]{
shards: inmems,
hashFunc: hashFunc,
}
}
type cacheInmemSharded[K comparable, V any] struct {
hashFunc func(key K) int
shards []*cacheInmem[K, V]
}
func (c *cacheInmemSharded[K, V]) Get(key K) (V, bool) {
return c.getShard(key).Get(key)
}
func (c *cacheInmemSharded[K, V]) Set(key K, value V, ttlSeconds int) {
c.getShard(key).Set(key, value, ttlSeconds)
}
func (c *cacheInmemSharded[K, V]) Del(key K) {
c.getShard(key).Del(key)
}
func (c *cacheInmemSharded[K, V]) CheckExpired() {
for _, shard := range c.shards {
shard.CheckExpired()
}
}
func (c *cacheInmemSharded[K, V]) getShard(key K) *cacheInmem[K, V] {
index := c.hashFunc(key)
return c.shards[index]
}

View File

@ -34,6 +34,7 @@ type UserServiceDeps struct {
Password utils.PasswordUtil Password utils.PasswordUtil
UserRepo repos.UserRepo UserRepo repos.UserRepo
UserCache repos.Cache[string, models.UserDTO] UserCache repos.Cache[string, models.UserDTO]
JwtCache repos.Cache[string, string]
EmailRepo repos.EmailRepo EmailRepo repos.EmailRepo
ActionTokenRepo repos.ActionTokenRepo ActionTokenRepo repos.ActionTokenRepo
} }
@ -195,6 +196,10 @@ func (u *userService) getUserById(ctx context.Context, userId string) (*models.U
} }
func (u *userService) ValidateToken(ctx context.Context, tokenStr string) (*models.UserDTO, error) { func (u *userService) ValidateToken(ctx context.Context, tokenStr string) (*models.UserDTO, error) {
if userId, ok := u.deps.JwtCache.Get(tokenStr); ok {
return u.getUserById(ctx, userId)
}
payload, err := u.deps.Jwt.Parse(tokenStr) payload, err := u.deps.Jwt.Parse(tokenStr)
if err != nil { if err != nil {
return nil, ErrUserWrongToken return nil, ErrUserWrongToken
@ -205,5 +210,7 @@ func (u *userService) ValidateToken(ctx context.Context, tokenStr string) (*mode
return nil, err return nil, err
} }
u.deps.JwtCache.Set(tokenStr, payload.UserId, -1)
return user, nil return user, nil
} }

View File

@ -11,14 +11,14 @@ type JwtPayload struct {
UserId string `json:"userId"` UserId string `json:"userId"`
} }
type jwtClaims struct { type JwtClaims struct {
jwt.RegisteredClaims jwt.RegisteredClaims
JwtPayload JwtPayload
} }
type JwtUtil interface { type JwtUtil interface {
Create(payload JwtPayload) (string, error) Create(payload JwtPayload) (string, error)
Parse(tokenStr string) (JwtPayload, error) Parse(tokenStr string) (JwtClaims, error)
} }
func NewJwtUtil(privateKey *rsa.PrivateKey) JwtUtil { func NewJwtUtil(privateKey *rsa.PrivateKey) JwtUtil {
@ -32,7 +32,7 @@ type jwtUtil struct {
} }
func (j *jwtUtil) Create(payload JwtPayload) (string, error) { func (j *jwtUtil) Create(payload JwtPayload) (string, error) {
claims := &jwtClaims{JwtPayload: payload} claims := &JwtClaims{JwtPayload: payload}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tokenStr, err := token.SignedString(j.privateKey) tokenStr, err := token.SignedString(j.privateKey)
if err != nil { if err != nil {
@ -41,17 +41,17 @@ func (j *jwtUtil) Create(payload JwtPayload) (string, error) {
return tokenStr, nil return tokenStr, nil
} }
func (j *jwtUtil) Parse(tokenStr string) (JwtPayload, error) { func (j *jwtUtil) Parse(tokenStr string) (JwtClaims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &jwtClaims{}, func(t *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenStr, &JwtClaims{}, func(t *jwt.Token) (interface{}, error) {
return &j.privateKey.PublicKey, nil return &j.privateKey.PublicKey, nil
}) })
if err != nil { if err != nil {
return JwtPayload{}, err return JwtClaims{}, err
} }
if claims, ok := token.Claims.(*jwtClaims); ok { if claims, ok := token.Claims.(*JwtClaims); ok {
return claims.JwtPayload, nil return *claims, nil
} }
return JwtPayload{}, fmt.Errorf("cant get payload") return JwtClaims{}, fmt.Errorf("cant get payload")
} }

View File

@ -25,7 +25,7 @@ type passwordUtil struct {
} }
func (b *passwordUtil) Hash(password string) (string, error) { func (b *passwordUtil) Hash(password string) (string, error) {
bytes, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) bytes, _ := bcrypt.GenerateFromPassword([]byte(password), 8) //bcrypt.DefaultCost)
return string(bytes), nil return string(bytes), nil
} }

View File

@ -9,10 +9,12 @@ import (
) )
type Prometheus struct { type Prometheus struct {
reg *prometheus.Registry reg *prometheus.Registry
rpsCounter prometheus.Counter rpsCounter prometheus.Counter
avgReqTimeHist prometheus.Histogram avgReqTimeHist prometheus.Histogram
panicsHist prometheus.Histogram panicsHist prometheus.Histogram
errors4xxCounter prometheus.Counter
errors5xxCounter prometheus.Counter
} }
func NewPrometheus() *Prometheus { func NewPrometheus() *Prometheus {
@ -24,12 +26,18 @@ func NewPrometheus() *Prometheus {
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}), collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
) )
// errorsCounter := prometheus.NewCounter( errors5xxCounter := prometheus.NewCounter(
// prometheus.CounterOpts{ prometheus.CounterOpts{
// Name: "backend_errors_count", Name: "backend_errors_count_5xx",
// Help: "Summary errors count", Help: "5xx errors count",
// }, },
// ) )
errors4xxCounter := prometheus.NewCounter(
prometheus.CounterOpts{
Name: "backend_errors_count_4xx",
Help: "4xx errors count",
},
)
rpsCounter := prometheus.NewCounter( rpsCounter := prometheus.NewCounter(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "backend_requests_per_second", Name: "backend_requests_per_second",
@ -48,13 +56,15 @@ func NewPrometheus() *Prometheus {
Help: "Panics histogram metric", Help: "Panics histogram metric",
}, },
) )
reg.MustRegister(rpsCounter, avgReqTimeHist, panicsHist) reg.MustRegister(rpsCounter, avgReqTimeHist, panicsHist, errors4xxCounter, errors5xxCounter)
return &Prometheus{ return &Prometheus{
panicsHist: panicsHist, panicsHist: panicsHist,
avgReqTimeHist: avgReqTimeHist, avgReqTimeHist: avgReqTimeHist,
rpsCounter: rpsCounter, rpsCounter: rpsCounter,
reg: reg, errors4xxCounter: errors4xxCounter,
errors5xxCounter: errors5xxCounter,
reg: reg,
} }
} }
@ -77,3 +87,11 @@ func (p *Prometheus) AddRequestTime(reqTime float64) {
func (p *Prometheus) AddPanic() { func (p *Prometheus) AddPanic() {
p.panicsHist.Observe(1) p.panicsHist.Observe(1)
} }
func (p *Prometheus) Add4xxError() {
p.errors4xxCounter.Inc()
}
func (p *Prometheus) Add5xxError() {
p.errors5xxCounter.Inc()
}

View File

@ -21,7 +21,7 @@ func newWrapper(writer io.Writer) *bufioWrapper {
ticker.Stop() ticker.Stop()
return &bufioWrapper{ return &bufioWrapper{
writer: bufio.NewWriterSize(writer, 128*1024), writer: bufio.NewWriterSize(writer, 512*1024),
mutex: &sync.RWMutex{}, mutex: &sync.RWMutex{},
ticker: ticker, ticker: ticker,
} }
@ -47,8 +47,8 @@ func (b *bufioWrapper) FlushRoutine(ctx context.Context) {
func (b *bufioWrapper) Write(p []byte) (nn int, err error) { func (b *bufioWrapper) Write(p []byte) (nn int, err error) {
// TODO: try replace mutex, improve logging perfomance // TODO: try replace mutex, improve logging perfomance
b.mutex.RLock() b.mutex.Lock()
defer b.mutex.RUnlock() defer b.mutex.Unlock()
if len(p) > b.writer.Available() { if len(p) > b.writer.Available() {
b.ticker.Reset(FlushInterval) b.ticker.Reset(FlushInterval)

View File

@ -34,12 +34,20 @@ func NewRequestLogMiddleware(logger log.Logger, prometheus *integrations.Prometh
method := c.Request.Method method := c.Request.Method
statusCode := c.Writer.Status() statusCode := c.Writer.Status()
clientIP := c.ClientIP()
if statusCode >= 200 && statusCode < 400 {
return
}
ctxLogger := logger.WithContext(c) ctxLogger := logger.WithContext(c)
e := ctxLogger.Log() if statusCode >= 400 && statusCode < 500 {
e.Str("ip", clientIP) prometheus.Add4xxError()
e.Msgf("Request %s %s %d %v", method, path, statusCode, latency) ctxLogger.Warning().Msgf("Request %s %s %d %v", method, path, statusCode, latency)
return
}
prometheus.Add5xxError()
ctxLogger.Error().Msgf("Request %s %s %d %v", method, path, statusCode, latency)
} }
} }