package discovery import ( "context" "errors" "fmt" "net" "strconv" "strings" "sync" "time" "go.uber.org/zap" ) const defaultBroadcastInterval = time.Second type ( KnownNodes map[string]struct{} NewNodes <-chan string ) // NewDiscoverySet returns a set of discovery services for all running and non-loopback network interfaces. func NewDiscoverySet(log *zap.Logger, discoverPort uint16, opts ...Option) (DiscoverySet, error) { iFaces, err := net.Interfaces() if err != nil { return nil, fmt.Errorf("list interfaces: %w", err) } set := make(DiscoverySet, 0, len(iFaces)) var errs []error for _, iFace := range iFaces { if iFace.Flags&net.FlagLoopback == net.FlagLoopback { continue } if iFace.Flags&net.FlagRunning != net.FlagRunning { continue } discover, err := NewDiscovery(iFace, log, discoverPort, opts...) if err != nil { errs = append(errs, err) continue } set = append(set, discover) } if len(set) == 0 { return nil, errors.Join(errs...) } return set, nil } // NewDiscovery returns new initialized discovery service for specified network interface. func NewDiscovery(iFace net.Interface, log *zap.Logger, discoverPort uint16, opts ...Option) (*Discovery, error) { addrs, err := iFace.Addrs() if err != nil { return nil, fmt.Errorf("get interface address: %w", err) } broadcastIP := net.IP(make([]byte, 4)) var ownIP net.IP for _, addr := range addrs { if ipnet, ok := addr.(*net.IPNet); ok { ip4 := ipnet.IP.To4() if ip4 == nil { continue } for i := range ip4 { broadcastIP[i] = ip4[i] | ^ipnet.Mask[i] } ownIP = ip4 break } } if broadcastIP.To4() == nil { return nil, fmt.Errorf("no broadcast address") } if ownIP == nil { return nil, fmt.Errorf("no own address") } broadcastAddrString := broadcastIP.String() + ":" + strconv.Itoa(int(discoverPort)) conn, err := net.ListenPacket("udp4", broadcastAddrString) if err != nil { return nil, fmt.Errorf("listen packet: %w", err) } udpAddr, err := net.ResolveUDPAddr("udp4", broadcastAddrString) if err != nil { return nil, fmt.Errorf("resolve udp address: %w", err) } d := Discovery{ log: log.With(zap.Stringer("broadcast_address", broadcastIP)), knownNodes: make(KnownNodes), newNodesCh: make(chan string, 20), failNodesCh: make(chan string), ownAddr: ownIP.To4(), ownAddrPacket: NewPacketWithIP(ownIP), conn: conn, broadcastAddr: udpAddr, broadcastInterval: defaultBroadcastInterval, } for _, opt := range opts { opt(&d) } return &d, nil } // Discovery a service to notify neighbours about yourself and keep track other neighbours alive. type Discovery struct { log *zap.Logger debug bool mu sync.Mutex knownNodes KnownNodes newNodesCh chan string failNodesCh chan string ownAddr net.IP ownAddrPacket Packet conn net.PacketConn broadcastAddr *net.UDPAddr broadcastInterval time.Duration } func (d *Discovery) Start(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() d.log.Info("start discovery", zap.Stringer("address", &d.ownAddr)) timer := time.NewTicker(d.broadcastInterval) defer timer.Stop() listenStop := make(chan struct{}) go func() { if err := d.listen(); err != nil { d.log.Error("listen failed", zap.Error(err)) } close(listenStop) }() stop := func() { if err := d.conn.Close(); err != nil { d.log.Warn("close connection", zap.Error(err)) } close(d.newNodesCh) } for { select { case <-ctx.Done(): stop() return case <-listenStop: d.log.Error("listener stopped, stop discovery") stop() return case <-timer.C: d.broadcast() } } } func (d *Discovery) listen() error { packet := NewPacket() for { readSize, addr, err := d.conn.ReadFrom(packet) if err != nil { if strings.Contains(err.Error(), "use of closed network connection") { d.log.Warn("listen connection closed") return nil } if opErr, ok := err.(*net.OpError); ok && opErr.Temporary() { continue } return fmt.Errorf("read from: %w", err) } if !packet.MagicOk() { d.log.Warn("data without magic bytes received") continue } nodeAddr := packet.IP(readSize) d.log.Debug("received node address", zap.Stringer("address", nodeAddr)) clientIP, _, _ := strings.Cut(addr.String(), ":") if nodeAddr.String() != clientIP { d.log.Warn("received addr mismatch", zap.Stringer("received", nodeAddr), zap.String("detected", clientIP)) } d.addNode(nodeAddr.String()) } } func (d *Discovery) broadcast() { d.log.Debug("broadcast") if _, err := d.conn.WriteTo(d.ownAddrPacket, d.broadcastAddr); err != nil { d.log.Error("write broadcast message", zap.Error(err)) } } func (d *Discovery) addNode(addr string) { if !d.debug && addr == d.ownAddr.String() { return } d.mu.Lock() if _, ok := d.knownNodes[addr]; !ok { d.log.Info("new node address", zap.String("address", addr)) d.knownNodes[addr] = struct{}{} d.newNodesCh <- addr } d.mu.Unlock() } func (d *Discovery) removeNode(addr string) { d.mu.Lock() if _, ok := d.knownNodes[addr]; ok { d.log.Warn("node failed, removed", zap.String("address", addr)) delete(d.knownNodes, addr) } d.mu.Unlock() } // NewNodes returns channel with new discovered node addresses. func (d *Discovery) NewNodes() NewNodes { return d.newNodesCh } // FailNode remove address from list of known nodes. Next broadcast message from this node will be sent to the `NewNodes` channel. func (d *Discovery) FailNode(addr string) { d.removeNode(addr) }