package commander import ( "bytes" "context" "errors" "fmt" "sync" "time" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/trace" ) const tracerName = "git.derfenix.pro/fenix/commander" var ErrNoErrorChannel = errors.New("no error channel provided") type Command interface { Execute(ctx context.Context) error Rollback(ctx context.Context) error } type CorrelatedCommand interface { Command CorrelationID() string } type Eventer interface { EmitEvent(ctx context.Context) error } type CommandCache interface { CommandCacheGet(context.Context, string, ...Command) bool CommandCacheStore(context.Context, string, ...Command) } func mustNewMetrics() metrics { m, err := newMetrics() if err != nil { panic(err) } return m } func newMetrics() (metrics, error) { meter := otel.GetMeterProvider().Meter("executor") commandsCount, err := meter.Int64Histogram("commands", metric.WithDescription("Count of executed commands (can be rolled back)")) if err != nil { return metrics{}, fmt.Errorf("commands count histogram: %w", err) } commandsRollbackCount, err := meter.Int64Histogram("commands.rollback", metric.WithDescription("Count of commands rolled back")) if err != nil { return metrics{}, fmt.Errorf("commands count histogram: %w", err) } commandsFailedCount, err := meter.Int64Histogram("commands.failed", metric.WithDescription("Count of failed commands")) if err != nil { return metrics{}, fmt.Errorf("commands failed count histogram: %w", err) } commandsRollbackFailedCount, err := meter.Int64Histogram("commands.rollback.failed", metric.WithDescription("Count of commands fail to roll back")) if err != nil { return metrics{}, fmt.Errorf("commands count histogram: %w", err) } monitorCommands, err := meter.Int64UpDownCounter("commands.running", metric.WithDescription("Command set in progress")) if err != nil { return metrics{}, fmt.Errorf("monitor commands counter: %w", err) } cacheHit, err := meter.Int64Counter("cache.hit", metric.WithDescription("Cache hit")) if err != nil { return metrics{}, fmt.Errorf("cache hit: %w", err) } cacheMiss, err := meter.Int64Counter("cache.miss", metric.WithDescription("Cache miss")) if err != nil { return metrics{}, fmt.Errorf("cache miss: %w", err) } return metrics{ commandsCount: commandsCount, commandsRollbackCount: commandsRollbackCount, commandsFailedCount: commandsFailedCount, commandsRollbackFailedCount: commandsRollbackFailedCount, monitorCommands: monitorCommands, cacheHit: cacheHit, cacheMiss: cacheMiss, }, nil } type metrics struct { commandsCount metric.Int64Histogram commandsRollbackCount metric.Int64Histogram commandsFailedCount metric.Int64Histogram commandsRollbackFailedCount metric.Int64Histogram monitorCommands metric.Int64UpDownCounter cacheHit metric.Int64Counter cacheMiss metric.Int64Counter } func New(commandsLimit int) *Commander { tracer := otel.Tracer(tracerName) return &Commander{ metrics: mustNewMetrics(), tracer: tracer, correlationIDBuffer: sync.Pool{New: func() any { return bytes.NewBuffer(nil) }}, commandsSemaphore: NewSemaphore(commandsLimit), } } type Commander struct { metrics tracer trace.Tracer cache CommandCache correlationIDBuffer sync.Pool commandsSemaphore Semaphore } func (c *Commander) WithCache(cache Cache) *Commander { c.cache = newCommandCache(cache, cacheTTL) return c } func (c *Commander) WithCacheTTL(cache Cache, ttl time.Duration) *Commander { c.cache = newCommandCache(cache, ttl) return c } func (c *Commander) ExecuteAsync(ctx context.Context, errCh chan<- error, correlationID string, commands ...Command) { ctx, span := c.tracer.Start(ctx, "executor.execute_async") defer span.End() if errCh == nil { span.RecordError(ErrNoErrorChannel) } defer func() { if errCh != nil { close(errCh) } }() if err := c.Execute(ctx, correlationID, commands...); err != nil { if errCh != nil { errCh <- err } } } func (c *Commander) Execute(ctx context.Context, correlationID string, commands ...Command) error { ctx, span := c.tracer.Start( ctx, "executor.execute", trace.WithAttributes( attribute.Int("commands.count", len(commands)), ), ) defer span.End() if c.cache != nil && correlationID == "" { correlationID = c.tryGetCorrelationID(commands) } shouldCache := c.cache != nil && correlationID != "" c.commandsSemaphore.Acquire() c.monitorCommands.Add(ctx, 1) defer func() { c.commandsSemaphore.Release() c.monitorCommands.Add(ctx, -1) }() if shouldCache { if c.cache.CommandCacheGet(ctx, correlationID, commands...) { c.cacheHit.Add(ctx, 1) return nil } else { c.cacheMiss.Add(ctx, 1) } } var actionsCount int64 defer func() { if actionsCount > 0 { c.commandsCount.Record(ctx, actionsCount) } }() eventEmitters := make([]Eventer, 0, len(commands)) for idx, command := range commands { if err := command.Execute(ctx); err != nil { c.commandsFailedCount.Record(ctx, 1) span.RecordError( err, trace.WithAttributes(attribute.String("step", "execute")), trace.WithAttributes(attribute.String("command", fmt.Sprintf("%T", command))), ) span.SetStatus(codes.Error, "failed to execute command") if idx > 0 { c.Rollback(ctx, commands[:idx]...) } return fmt.Errorf("execute command %v: %w", command, err) } if eventer, ok := command.(Eventer); ok { eventEmitters = append(eventEmitters, eventer) } actionsCount++ } for _, emitter := range eventEmitters { if err := emitter.EmitEvent(ctx); err != nil { span.RecordError( err, trace.WithAttributes(attribute.String("step", "send events")), trace.WithAttributes(attribute.String("command", fmt.Sprintf("%T", emitter))), ) } } if shouldCache { c.cache.CommandCacheStore(ctx, correlationID, commands...) } return nil } func (c *Commander) tryGetCorrelationID(commands []Command) string { newCorrelationID := c.correlationIDBuffer.Get().(*bytes.Buffer) defer func() { newCorrelationID.Reset() c.correlationIDBuffer.Put(newCorrelationID) }() for _, command := range commands { correlationIDer, ok := command.(CorrelatedCommand) if !ok { return "" } newCorrelationID.WriteString(correlationIDer.CorrelationID()) } return newCorrelationID.String() } func (c *Commander) Rollback(ctx context.Context, commands ...Command) { ctx, span := c.tracer.Start(ctx, "executor.rollback") defer span.End() for _, command := range commands { if err := command.Rollback(ctx); err != nil { c.commandsRollbackFailedCount.Record(ctx, 1) span.RecordError( err, trace.WithAttributes(attribute.String("step", "rollback")), trace.WithAttributes(attribute.String("command", fmt.Sprintf("%T", command))), ) continue } c.commandsRollbackCount.Record(ctx, 1) } }