cache sharding, valid jwt caching, add metrics

This commit is contained in:
Sergey Chubaryan 2024-08-28 00:58:19 +03:00
parent 0fdc2400ae
commit ed1f8b8c3f
8 changed files with 137 additions and 34 deletions

View File

@ -114,8 +114,21 @@ func (a *App) Run(p RunParams) {
userRepo = repos.NewUserRepo(sqlDb)
emailRepo = repos.NewEmailRepo()
actionTokenRepo = repos.NewActionTokenRepo(sqlDb)
linksCache = repos.NewCacheInmem[string, string](7 * 24 * 60 * 60)
userCache = repos.NewCacheInmem[string, models.UserDTO](60 * 60)
jwtCache = repos.NewCacheInmemSharded[string, string](60, 36, func(key string) int {
char := int(key[len(key)-1])
if char >= 0x30 && char <= 0x39 {
return char - 0x30
}
if char >= 0x41 && char <= 0x5A {
return char - 0x41
}
return char - 0x61
}) //repos.NewCacheInmem[string, string](60)
linksCache = repos.NewCacheInmem[string, string](7 * 24 * 60 * 60)
userCache = repos.NewCacheInmemSharded[string, models.UserDTO](60*60, 10, func(key string) int {
char := int(key[len(key)-1])
return char - 0x30
}) //repos.NewCacheInmem[string, models.UserDTO](60 * 60)
)
// Periodically trigger cache cleanup
@ -128,6 +141,7 @@ func (a *App) Run(p RunParams) {
case <-ctx.Done():
return
case <-tmr.C:
jwtCache.CheckExpired()
userCache.CheckExpired()
linksCache.CheckExpired()
}
@ -140,6 +154,7 @@ func (a *App) Run(p RunParams) {
Password: passwordUtil,
UserRepo: userRepo,
UserCache: userCache,
JwtCache: jwtCache,
EmailRepo: emailRepo,
ActionTokenRepo: actionTokenRepo,
},

View File

@ -0,0 +1,55 @@
package repos
import (
"sync"
)
func NewCacheInmemSharded[K comparable, V any](
ttlSeconds, shards int,
hashFunc func(key K) int,
) Cache[K, V] {
inmems := []*cacheInmem[K, V]{}
for i := 0; i < shards; i++ {
inmems = append(
inmems,
&cacheInmem[K, V]{
m: &sync.Mutex{},
data: map[K]*cacheInmemItem[V]{},
ttlSeconds: ttlSeconds,
},
)
}
return &cacheInmemSharded[K, V]{
shards: inmems,
hashFunc: hashFunc,
}
}
type cacheInmemSharded[K comparable, V any] struct {
hashFunc func(key K) int
shards []*cacheInmem[K, V]
}
func (c *cacheInmemSharded[K, V]) Get(key K) (V, bool) {
return c.getShard(key).Get(key)
}
func (c *cacheInmemSharded[K, V]) Set(key K, value V, ttlSeconds int) {
c.getShard(key).Set(key, value, ttlSeconds)
}
func (c *cacheInmemSharded[K, V]) Del(key K) {
c.getShard(key).Del(key)
}
func (c *cacheInmemSharded[K, V]) CheckExpired() {
for _, shard := range c.shards {
shard.CheckExpired()
}
}
func (c *cacheInmemSharded[K, V]) getShard(key K) *cacheInmem[K, V] {
index := c.hashFunc(key)
return c.shards[index]
}

View File

@ -34,6 +34,7 @@ type UserServiceDeps struct {
Password utils.PasswordUtil
UserRepo repos.UserRepo
UserCache repos.Cache[string, models.UserDTO]
JwtCache repos.Cache[string, string]
EmailRepo repos.EmailRepo
ActionTokenRepo repos.ActionTokenRepo
}
@ -195,6 +196,10 @@ func (u *userService) getUserById(ctx context.Context, userId string) (*models.U
}
func (u *userService) ValidateToken(ctx context.Context, tokenStr string) (*models.UserDTO, error) {
if userId, ok := u.deps.JwtCache.Get(tokenStr); ok {
return u.getUserById(ctx, userId)
}
payload, err := u.deps.Jwt.Parse(tokenStr)
if err != nil {
return nil, ErrUserWrongToken
@ -205,5 +210,7 @@ func (u *userService) ValidateToken(ctx context.Context, tokenStr string) (*mode
return nil, err
}
u.deps.JwtCache.Set(tokenStr, payload.UserId, -1)
return user, nil
}

View File

@ -11,14 +11,14 @@ type JwtPayload struct {
UserId string `json:"userId"`
}
type jwtClaims struct {
type JwtClaims struct {
jwt.RegisteredClaims
JwtPayload
}
type JwtUtil interface {
Create(payload JwtPayload) (string, error)
Parse(tokenStr string) (JwtPayload, error)
Parse(tokenStr string) (JwtClaims, error)
}
func NewJwtUtil(privateKey *rsa.PrivateKey) JwtUtil {
@ -32,7 +32,7 @@ type jwtUtil struct {
}
func (j *jwtUtil) Create(payload JwtPayload) (string, error) {
claims := &jwtClaims{JwtPayload: payload}
claims := &JwtClaims{JwtPayload: payload}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
tokenStr, err := token.SignedString(j.privateKey)
if err != nil {
@ -41,17 +41,17 @@ func (j *jwtUtil) Create(payload JwtPayload) (string, error) {
return tokenStr, nil
}
func (j *jwtUtil) Parse(tokenStr string) (JwtPayload, error) {
token, err := jwt.ParseWithClaims(tokenStr, &jwtClaims{}, func(t *jwt.Token) (interface{}, error) {
func (j *jwtUtil) Parse(tokenStr string) (JwtClaims, error) {
token, err := jwt.ParseWithClaims(tokenStr, &JwtClaims{}, func(t *jwt.Token) (interface{}, error) {
return &j.privateKey.PublicKey, nil
})
if err != nil {
return JwtPayload{}, err
return JwtClaims{}, err
}
if claims, ok := token.Claims.(*jwtClaims); ok {
return claims.JwtPayload, nil
if claims, ok := token.Claims.(*JwtClaims); ok {
return *claims, nil
}
return JwtPayload{}, fmt.Errorf("cant get payload")
return JwtClaims{}, fmt.Errorf("cant get payload")
}

View File

@ -25,7 +25,7 @@ type passwordUtil struct {
}
func (b *passwordUtil) Hash(password string) (string, error) {
bytes, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
bytes, _ := bcrypt.GenerateFromPassword([]byte(password), 8) //bcrypt.DefaultCost)
return string(bytes), nil
}

View File

@ -9,10 +9,12 @@ import (
)
type Prometheus struct {
reg *prometheus.Registry
rpsCounter prometheus.Counter
avgReqTimeHist prometheus.Histogram
panicsHist prometheus.Histogram
reg *prometheus.Registry
rpsCounter prometheus.Counter
avgReqTimeHist prometheus.Histogram
panicsHist prometheus.Histogram
errors4xxCounter prometheus.Counter
errors5xxCounter prometheus.Counter
}
func NewPrometheus() *Prometheus {
@ -24,12 +26,18 @@ func NewPrometheus() *Prometheus {
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
)
// errorsCounter := prometheus.NewCounter(
// prometheus.CounterOpts{
// Name: "backend_errors_count",
// Help: "Summary errors count",
// },
// )
errors5xxCounter := prometheus.NewCounter(
prometheus.CounterOpts{
Name: "backend_errors_count_5xx",
Help: "5xx errors count",
},
)
errors4xxCounter := prometheus.NewCounter(
prometheus.CounterOpts{
Name: "backend_errors_count_4xx",
Help: "4xx errors count",
},
)
rpsCounter := prometheus.NewCounter(
prometheus.CounterOpts{
Name: "backend_requests_per_second",
@ -48,13 +56,15 @@ func NewPrometheus() *Prometheus {
Help: "Panics histogram metric",
},
)
reg.MustRegister(rpsCounter, avgReqTimeHist, panicsHist)
reg.MustRegister(rpsCounter, avgReqTimeHist, panicsHist, errors4xxCounter, errors5xxCounter)
return &Prometheus{
panicsHist: panicsHist,
avgReqTimeHist: avgReqTimeHist,
rpsCounter: rpsCounter,
reg: reg,
panicsHist: panicsHist,
avgReqTimeHist: avgReqTimeHist,
rpsCounter: rpsCounter,
errors4xxCounter: errors4xxCounter,
errors5xxCounter: errors5xxCounter,
reg: reg,
}
}
@ -77,3 +87,11 @@ func (p *Prometheus) AddRequestTime(reqTime float64) {
func (p *Prometheus) AddPanic() {
p.panicsHist.Observe(1)
}
func (p *Prometheus) Add4xxError() {
p.errors4xxCounter.Inc()
}
func (p *Prometheus) Add5xxError() {
p.errors5xxCounter.Inc()
}

View File

@ -21,7 +21,7 @@ func newWrapper(writer io.Writer) *bufioWrapper {
ticker.Stop()
return &bufioWrapper{
writer: bufio.NewWriterSize(writer, 128*1024),
writer: bufio.NewWriterSize(writer, 512*1024),
mutex: &sync.RWMutex{},
ticker: ticker,
}
@ -47,8 +47,8 @@ func (b *bufioWrapper) FlushRoutine(ctx context.Context) {
func (b *bufioWrapper) Write(p []byte) (nn int, err error) {
// TODO: try replace mutex, improve logging perfomance
b.mutex.RLock()
defer b.mutex.RUnlock()
b.mutex.Lock()
defer b.mutex.Unlock()
if len(p) > b.writer.Available() {
b.ticker.Reset(FlushInterval)

View File

@ -34,12 +34,20 @@ func NewRequestLogMiddleware(logger log.Logger, prometheus *integrations.Prometh
method := c.Request.Method
statusCode := c.Writer.Status()
clientIP := c.ClientIP()
if statusCode >= 200 && statusCode < 400 {
return
}
ctxLogger := logger.WithContext(c)
e := ctxLogger.Log()
e.Str("ip", clientIP)
e.Msgf("Request %s %s %d %v", method, path, statusCode, latency)
if statusCode >= 400 && statusCode < 500 {
prometheus.Add4xxError()
ctxLogger.Warning().Msgf("Request %s %s %d %v", method, path, statusCode, latency)
return
}
prometheus.Add5xxError()
ctxLogger.Error().Msgf("Request %s %s %d %v", method, path, statusCode, latency)
}
}