269 lines
5.5 KiB
Go
269 lines
5.5 KiB
Go
package discovery
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// defaultBroadcastInterval is the value NewDiscovery assigns to
// Discovery.broadcastInterval before any Option is applied.
const defaultBroadcastInterval = 1 * time.Second
|
|
|
|
// KnownNodes is a set of node addresses; only the keys carry information.
type KnownNodes map[string]struct{}

// NewNodes is a receive-only stream of newly discovered node addresses.
type NewNodes <-chan string
|
|
|
|
// NewDiscoverySet returns a set of discovery services for all running and non-loopback network interfaces.
|
|
func NewDiscoverySet(log *zap.Logger, discoverPort uint16, opts ...Option) (DiscoverySet, error) {
|
|
iFaces, err := net.Interfaces()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list interfaces: %w", err)
|
|
}
|
|
|
|
set := make(DiscoverySet, 0, len(iFaces))
|
|
var errs []error
|
|
|
|
for _, iFace := range iFaces {
|
|
if iFace.Flags&net.FlagLoopback == net.FlagLoopback {
|
|
continue
|
|
}
|
|
|
|
if iFace.Flags&net.FlagRunning != net.FlagRunning {
|
|
continue
|
|
}
|
|
|
|
discover, err := NewDiscovery(iFace, log, discoverPort, opts...)
|
|
if err != nil {
|
|
errs = append(errs, err)
|
|
|
|
continue
|
|
}
|
|
|
|
set = append(set, discover)
|
|
}
|
|
|
|
if len(set) == 0 {
|
|
return nil, errors.Join(errs...)
|
|
}
|
|
|
|
return set, nil
|
|
}
|
|
|
|
// NewDiscovery returns new initialized discovery service for specified network interface.
|
|
func NewDiscovery(iFace net.Interface, log *zap.Logger, discoverPort uint16, opts ...Option) (*Discovery, error) {
|
|
addrs, err := iFace.Addrs()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get interface address: %w", err)
|
|
}
|
|
|
|
broadcastIP := net.IP(make([]byte, 4))
|
|
var ownIP net.IP
|
|
|
|
for _, addr := range addrs {
|
|
if ipnet, ok := addr.(*net.IPNet); ok {
|
|
ip4 := ipnet.IP.To4()
|
|
if ip4 == nil {
|
|
continue
|
|
}
|
|
|
|
for i := range ip4 {
|
|
broadcastIP[i] = ip4[i] | ^ipnet.Mask[i]
|
|
}
|
|
|
|
ownIP = ip4
|
|
|
|
break
|
|
}
|
|
}
|
|
|
|
if broadcastIP.To4() == nil {
|
|
return nil, fmt.Errorf("no broadcast address")
|
|
}
|
|
|
|
if ownIP == nil {
|
|
return nil, fmt.Errorf("no own address")
|
|
}
|
|
|
|
broadcastAddrString := broadcastIP.String() + ":" + strconv.Itoa(int(discoverPort))
|
|
|
|
conn, err := net.ListenPacket("udp4", broadcastAddrString)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("listen packet: %w", err)
|
|
}
|
|
|
|
udpAddr, err := net.ResolveUDPAddr("udp4", broadcastAddrString)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("resolve udp address: %w", err)
|
|
}
|
|
|
|
d := Discovery{
|
|
log: log.With(zap.Stringer("broadcast_address", broadcastIP)),
|
|
knownNodes: make(KnownNodes),
|
|
newNodesCh: make(chan string, 20),
|
|
failNodesCh: make(chan string),
|
|
ownAddr: ownIP.To4(),
|
|
ownAddrPacket: NewPacketWithIP(ownIP),
|
|
conn: conn,
|
|
broadcastAddr: udpAddr,
|
|
broadcastInterval: defaultBroadcastInterval,
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(&d)
|
|
}
|
|
|
|
return &d, nil
|
|
}
|
|
|
|
// Discovery a service to notify neighbours about yourself and keep track other neighbours alive.
|
|
type Discovery struct {
|
|
log *zap.Logger
|
|
debug bool
|
|
|
|
mu sync.Mutex
|
|
knownNodes KnownNodes
|
|
|
|
newNodesCh chan string
|
|
failNodesCh chan string
|
|
|
|
ownAddr net.IP
|
|
ownAddrPacket Packet
|
|
conn net.PacketConn
|
|
broadcastAddr *net.UDPAddr
|
|
broadcastInterval time.Duration
|
|
}
|
|
|
|
func (d *Discovery) Start(ctx context.Context, wg *sync.WaitGroup) {
|
|
defer wg.Done()
|
|
|
|
d.log.Info("start discovery", zap.Stringer("address", &d.ownAddr))
|
|
|
|
timer := time.NewTicker(d.broadcastInterval)
|
|
defer timer.Stop()
|
|
|
|
listenStop := make(chan struct{})
|
|
|
|
go func() {
|
|
if err := d.listen(); err != nil {
|
|
d.log.Error("listen failed", zap.Error(err))
|
|
}
|
|
|
|
close(listenStop)
|
|
}()
|
|
|
|
stop := func() {
|
|
if err := d.conn.Close(); err != nil {
|
|
d.log.Warn("close connection", zap.Error(err))
|
|
}
|
|
|
|
close(d.newNodesCh)
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
stop()
|
|
|
|
return
|
|
|
|
case <-listenStop:
|
|
d.log.Error("listener stopped, stop discovery")
|
|
stop()
|
|
|
|
return
|
|
|
|
case <-timer.C:
|
|
d.broadcast()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *Discovery) listen() error {
|
|
packet := NewPacket()
|
|
|
|
for {
|
|
readSize, addr, err := d.conn.ReadFrom(packet)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "use of closed network connection") {
|
|
d.log.Warn("listen connection closed")
|
|
|
|
return nil
|
|
}
|
|
|
|
if opErr, ok := err.(*net.OpError); ok && opErr.Temporary() {
|
|
continue
|
|
}
|
|
|
|
return fmt.Errorf("read from: %w", err)
|
|
}
|
|
|
|
if !packet.MagicOk() {
|
|
d.log.Warn("data without magic bytes received")
|
|
|
|
continue
|
|
}
|
|
|
|
nodeAddr := packet.IP(readSize)
|
|
d.log.Debug("received node address", zap.Stringer("address", nodeAddr))
|
|
|
|
clientIP, _, _ := strings.Cut(addr.String(), ":")
|
|
if nodeAddr.String() != clientIP {
|
|
d.log.Warn("received addr mismatch", zap.Stringer("received", nodeAddr), zap.String("detected", clientIP))
|
|
}
|
|
|
|
d.addNode(nodeAddr.String())
|
|
}
|
|
}
|
|
|
|
func (d *Discovery) broadcast() {
|
|
d.log.Debug("broadcast")
|
|
|
|
if _, err := d.conn.WriteTo(d.ownAddrPacket, d.broadcastAddr); err != nil {
|
|
d.log.Error("write broadcast message", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
func (d *Discovery) addNode(addr string) {
|
|
if !d.debug && addr == d.ownAddr.String() {
|
|
return
|
|
}
|
|
|
|
d.mu.Lock()
|
|
|
|
if _, ok := d.knownNodes[addr]; !ok {
|
|
d.log.Info("new node address", zap.String("address", addr))
|
|
d.knownNodes[addr] = struct{}{}
|
|
d.newNodesCh <- addr
|
|
}
|
|
|
|
d.mu.Unlock()
|
|
}
|
|
|
|
func (d *Discovery) removeNode(addr string) {
|
|
d.mu.Lock()
|
|
|
|
if _, ok := d.knownNodes[addr]; ok {
|
|
d.log.Warn("node failed, removed", zap.String("address", addr))
|
|
delete(d.knownNodes, addr)
|
|
}
|
|
|
|
d.mu.Unlock()
|
|
}
|
|
|
|
// NewNodes returns channel with new discovered node addresses.
|
|
func (d *Discovery) NewNodes() NewNodes {
|
|
return d.newNodesCh
|
|
}
|
|
|
|
// FailNode remove address from list of known nodes. Next broadcast message from this node will be sent to the `NewNodes` channel.
|
|
func (d *Discovery) FailNode(addr string) {
|
|
d.removeNode(addr)
|
|
}
|