Initial commit

This commit is contained in:
2023-08-24 23:40:31 +03:00
commit 49c962e13c
32 changed files with 1360 additions and 0 deletions

View File

@@ -0,0 +1,23 @@
package repository
import (
"database/sql"
"fmt"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
)
func NewDB(dsn string) (*bun.DB, error) {
connector := pgdriver.NewConnector(pgdriver.WithDSN(dsn))
sqlDB := sql.OpenDB(connector)
sqlDB.SetMaxOpenConns(10)
db := bun.NewDB(sqlDB, pgdialect.New())
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping database: %w", err)
}
return db, nil
}

View File

@@ -0,0 +1,142 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/uptrace/bun"
"go.uber.org/zap"
)
type Cache interface {
Get(id uint64) []string
Append(id uint64, ip string)
}
type ConnLog struct {
bun.BaseModel `bun:"table:conn_log"`
UserID uint64 `bun:"user_id"`
IP string `bun:"ip_addr"`
TS time.Time `bun:"ts"`
}
func NewConnLogs(ctx context.Context, db *bun.DB, cache Cache, logger *zap.Logger, updateInterval time.Duration) (*ConnLogs, error) {
connLogs := &ConnLogs{db: db, cache: cache}
logger.Info("filling initial cache")
err := connLogs.fillCache(ctx)
if err != nil {
return nil, err
}
logger.Info("initial cache filled")
go func() {
ticker := time.NewTicker(updateInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
var err error
err = connLogs.fillCache(ctx)
if err != nil {
logger.Error("update cache", zap.Error(err))
}
}
}
}()
return connLogs, nil
}
type ConnLogs struct {
db *bun.DB
mu sync.RWMutex
cache Cache
lastTS time.Time
}
func (l *ConnLogs) fillCache(ctx context.Context) error {
var entity []ConnLog
l.mu.Lock()
defer l.mu.Unlock()
query := l.db.NewSelect().Model(&entity).
Order("ts").
Group("user_id", "ip_addr").
Column("user_id", "ip_addr").
ColumnExpr(`max("ts") as ts`)
if !l.lastTS.IsZero() {
query.Where(`"ts" > ? `, l.lastTS)
}
if err := query.Scan(ctx); err != nil {
return fmt.Errorf("select: %w", err)
}
loop:
for i := range entity {
item := &entity[i]
if ips := l.cache.Get(item.UserID); len(ips) == 0 {
l.cache.Append(item.UserID, item.IP)
continue
}
for _, s := range l.cache.Get(item.UserID) {
if s == item.IP {
continue loop
}
}
l.cache.Append(item.UserID, item.IP)
}
if len(entity) > 0 {
l.lastTS = entity[len(entity)-1].TS
}
return nil
}
func (l *ConnLogs) Get(_ context.Context, first, second uint64) (bool, error) {
ips1 := l.cache.Get(first)
ips2 := l.cache.Get(second)
if len(ips1) == 0 || len(ips2) == 0 {
return false, nil
}
for i := range ips1 {
for j := range ips2 {
if ips1[i] == ips2[j] {
return true, nil
}
}
}
return false, nil
}
func (l *ConnLogs) List(ctx context.Context) (string, error) {
var entity []ConnLog
if err := l.db.NewSelect().Model(&entity).Scan(ctx); err != nil {
return "", fmt.Errorf("select: %w", err)
}
marshal, err := json.Marshal(entity)
if err != nil {
return "", fmt.Errorf("marshal: %w", err)
}
return string(marshal), nil
}

View File

@@ -0,0 +1,151 @@
package repository_test
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
"go.uber.org/zap/zaptest"
"git.derfenix.pro/fenix/protect_trans_info/adapters/inmemorycache"
"git.derfenix.pro/fenix/protect_trans_info/application"
. "git.derfenix.pro/fenix/protect_trans_info/application/repository"
)
func TestLogs_Get(t *testing.T) {
if testing.Short() {
t.Skip("skip long test")
}
t.Parallel()
dockerPool, err := dockertest.NewPool("")
dockerPool.MaxWait = time.Second * 10
require.NoError(t, err)
resource, err := dockerPool.RunWithOptions(&dockertest.RunOptions{
Repository: "postgres",
Tag: "15",
Env: []string{
"POSTGRES_USER=test",
"POSTGRES_PASSWORD=test",
"POSTGRES_DB=test",
"POSTGRES_HOST_AUTH_METHOD=md5",
"POSTGRES_INITDB_ARGS=--auth-host=md5",
},
PortBindings: map[docker.Port][]docker.PortBinding{"5432/tcp": {{HostIP: "0.0.0.0", HostPort: "55432"}}},
}, func(config *docker.HostConfig) {
config.AutoRemove = true
})
require.NoError(t, err)
t.Cleanup(func() {
err := dockerPool.Purge(resource)
assert.NoError(t, err)
})
connector := pgdriver.NewConnector(pgdriver.WithDSN("postgresql://test:test@localhost:55432/test?sslmode=disable"))
sqlDB := sql.OpenDB(connector)
sqlDB.SetMaxOpenConns(10)
db := bun.NewDB(sqlDB, pgdialect.New())
err = dockerPool.Retry(func() error {
pingCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := db.PingContext(pingCtx); err != nil {
return fmt.Errorf("ping database: %w", err)
}
return nil
})
require.NoError(t, err)
ctx := context.Background()
logger := zaptest.NewLogger(t)
require.NoError(t, application.Migrate(ctx, db, logger))
testData := []ConnLog{
{
UserID: 1,
IP: "123.123.123.123",
TS: time.Now(),
},
{
UserID: 2,
IP: "123.123.123.123",
TS: time.Now().Add(time.Hour),
},
{
UserID: 3,
IP: "124.123.123.123",
TS: time.Now().Add(time.Hour * 2),
},
}
_, err = db.NewInsert().Model(&testData).Exec(ctx)
require.NoError(t, err)
repo, err := NewConnLogs(ctx, db, make(inmemorycache.Cache), logger, time.Millisecond*100)
require.NoError(t, err)
t.Run("found dup", func(t *testing.T) {
t.Parallel()
get, err := repo.Get(ctx, 1, 2)
require.NoError(t, err)
require.True(t, get)
})
t.Run("no dup 1", func(t *testing.T) {
t.Parallel()
get, err := repo.Get(ctx, 1, 3)
require.NoError(t, err)
require.False(t, get)
})
t.Run("no dup 2", func(t *testing.T) {
t.Parallel()
get, err := repo.Get(ctx, 2, 3)
require.NoError(t, err)
require.False(t, get)
})
t.Run("added item", func(t *testing.T) {
t.Parallel()
get, err := repo.Get(ctx, 4, 3)
require.NoError(t, err)
require.False(t, get)
_, err = db.NewInsert().Model(&ConnLog{
BaseModel: bun.BaseModel{},
UserID: 4,
IP: "124.123.123.123",
TS: time.Now().Add(time.Hour * 3),
}).Exec(ctx)
require.NoError(t, err)
get, err = repo.Get(ctx, 4, 3)
require.NoError(t, err)
assert.False(t, get)
time.Sleep(time.Millisecond * 200)
get, err = repo.Get(ctx, 4, 3)
require.NoError(t, err)
require.True(t, get)
})
}