package internal import ( "encoding/json" "fmt" "net/url" "sync" ) var connectionPool = sync.Pool{ New: func() any { return &connection{ send: make(chan *message, 256), } }, } func newConnection(topics []string) (c *connection) { c = connectionPool.Get().(*connection) c.id = uuidv7() c.topics = topics c.closed = false return } type connection struct { id string send chan *message topics []string closed bool } func (c *connection) Announce(h Hub, active bool) { for _, topic := range c.topics { b, _ := json.Marshal(c.toSubscription(topic, active)) h.Broadcast(newMessage( "Subscription", []string{subscriptionTopic}, string(b), )) } } func (c *connection) close() bool { if c.closed { return false } c.closed = true close(c.send) c.send = make(chan *message, 256) connectionPool.Put(c) return true } func (c *connection) toSubscription(topic string, active bool) subscription { return subscription{ ID: fmt.Sprintf("/.well-known/mercure/subscriptions/%s/%s", url.QueryEscape(topic), url.QueryEscape(c.id)), Type: "Subscription", Topic: topic, Subscriber: c.id, Active: active, Payload: make(map[string]any), } }
package internal import ( "context" "maps" "sync" ) type Hub interface { Run(context.Context) Register(*connection) Unregister(*connection) Broadcast(*message) Connections() map[*connection]bool } type hub struct { metrics *metrics subscriptions map[string]map[*connection]bool register chan *connection unregister chan *connection broadcast chan *message mutex sync.RWMutex } func newHub(m *metrics) *hub { return &hub{ metrics: m, subscriptions: make(map[string]map[*connection]bool), register: make(chan *connection), unregister: make(chan *connection), broadcast: make(chan *message), } } func (h *hub) Run(ctx context.Context) { for { select { case <-ctx.Done(): return case conn := <-h.register: h.mutex.Lock() for _, topic := range conn.topics { if _, ok := h.subscriptions[topic]; !ok { h.subscriptions[topic] = make(map[*connection]bool) } h.subscriptions[topic][conn] = true } h.mutex.Unlock() case conn := <-h.unregister: h.mutex.Lock() for _, topic := range conn.topics { delete(h.subscriptions[topic], conn) } conn.close() h.mutex.Unlock() case msg := <-h.broadcast: h.mutex.RLock() for _, t := range msg.Topics { if connections, ok := h.subscriptions[t]; ok { for conn := range connections { select { case conn.send <- msg: default: if conn.close() { h.metrics.Terminate() } } } } } h.mutex.RUnlock() } } } func (h *hub) Broadcast(msg *message) { h.broadcast <- msg } func (h *hub) Register(conn *connection) { h.register <- conn } func (h *hub) Unregister(conn *connection) { h.unregister <- conn } func (h *hub) Connections() map[*connection]bool { m2 := make(map[*connection]bool) h.mutex.RLock() for _, m := range h.subscriptions { maps.Copy(m2, m) } h.mutex.RUnlock() return m2 }
package internal import ( "context" "hash/crc32" "maps" ) type hubMulti struct { hubs []*hub } func newHubMulti(hubCount int, m *metrics) (h *hubMulti) { hubCount = max(hubCount, 1) h = &hubMulti{ hubs: make([]*hub, hubCount), } for i := range h.hubs { h.hubs[i] = newHub(m) } return } func (h *hubMulti) Run(ctx context.Context) { for _, h := range h.hubs { go h.Run(ctx) } return } func (h *hubMulti) Register(c *connection) { for _, topic := range c.topics { h.hubs[h.hash(topic)].Register(c) } return } func (h *hubMulti) Unregister(c *connection) { for _, topic := range c.topics { h.hubs[h.hash(topic)].Unregister(c) } return } func (h *hubMulti) Broadcast(m *message) { for _, topic := range m.Topics { h.hubs[h.hash(topic)].Broadcast(m) } return } func (h *hubMulti) Connections() map[*connection]bool { m2 := make(map[*connection]bool) for _, h := range h.hubs { maps.Copy(m2, h.Connections()) } return m2 } func (h *hubMulti) hash(topic string) int { return int(crc32.ChecksumIEEE([]byte(topic))) % len(h.hubs) }
package internal import ( "encoding/json" "io" "log" "net/http" "time" "github.com/lestrrat-go/httpcc" "github.com/lestrrat-go/jwx/v2/jwk" ) func jwksKeys(c *http.Client, url string) (keys []any, maxage time.Duration) { if len(url) > 0 { resp, err := c.Get(url) if err != nil { log.Println(err) return } defer resp.Body.Close() if resp.StatusCode != 200 { log.Printf("JWKS URL returned %d", resp.StatusCode) return } var jwks = map[string][]any{} body, err := io.ReadAll(resp.Body) if err != nil { log.Println(err) return } if err = json.Unmarshal(body, &jwks); err != nil { log.Println(err) return } for _, k := range jwks["keys"] { kjson, _ := json.Marshal(k) if err := jwk.ParseRawKey(kjson, &k); err != nil { log.Println(err) return } keys = append(keys, k) } maxage = 3600 directives, err := httpcc.ParseResponse(resp.Header.Get(`Cache-Control`)) if err == nil { val, present := directives.MaxAge() if present { maxage = time.Duration(max(int(val), 60)) } } } return }
package internal import ( "bytes" "crypto/ecdsa" "crypto/rsa" "crypto/x509" "encoding/pem" "log" "net/http" "slices" "strings" "github.com/golang-jwt/jwt/v5" ) var ( algECDSA = []string{"ES256", "ES384", "ES512"} algHMAC = []string{"HS256", "HS384", "HS512"} algRSA = []string{"RS256", "RS384", "RS512"} algRSAPSS = []string{"PS256", "PS384", "PS512"} algEdDSA = []string{"EdDSA"} ) func jwtKeys(alg, key string) (keys []any) { if alg == "" { return } if slices.Contains(algHMAC, alg) { for k := range bytes.SplitSeq([]byte(key), []byte("\n")) { keys = append(keys, k) } return } if slices.Contains(algECDSA, alg) || slices.Contains(algRSA, alg) || slices.Contains(algRSAPSS, alg) { return x509keys(alg, key) } if slices.Contains(algEdDSA, alg) { log.Println("EdDSA key alg not supported") return } log.Printf("Unrecognized key alg: %s", alg) return } func x509keys(alg, key string) (keys []any) { var rest = []byte(key) var block *pem.Block var i int for len(rest) > 0 { i++ block, rest = pem.Decode(rest) if block == nil { log.Printf("Unable to decode %s block #%d", alg, i) return } pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { log.Printf("Unable to parse key %s #%d", alg, i) return } switch alg[:2] { case "ES": if k, ok := pubInterface.(*ecdsa.PublicKey); ok { keys = append(keys, k) } case "RS": if k, ok := pubInterface.(*rsa.PublicKey); ok { keys = append(keys, k) } case "PS": if k, ok := pubInterface.(*rsa.PublicKey); ok { keys = append(keys, k) } } } return } type tokenClaims struct { Mercure struct { Publish []string `json:"publish"` Subscribe []string `json:"subscribe"` } `json:"mercure"` jwt.RegisteredClaims } func jwtTokenClaims(r *http.Request, keys []any, debug bool) *tokenClaims { tokenStr := r.Header.Get("Authorization") if parts := strings.Split(tokenStr, " "); len(parts) == 2 { tokenStr = parts[1] } if tokenStr == "" { cookies := r.CookiesNamed("mercureAuthorization") if len(cookies) > 0 { tokenStr = cookies[0].Value } } if tokenStr == "" { tokenStr = r.Form.Get("authorization") } if tokenStr == "" { return nil } claims := new(tokenClaims) var token *jwt.Token var err error for _, k := range keys { token, err = jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (any, error) { return k, nil }) if err != nil { continue } if token.Valid { break } } if token == nil || !token.Valid { if debug { log.Printf("Invalid token: %v: %s", err, tokenStr) } return nil } return token.Claims.(*tokenClaims) }
package internal import ( "encoding/json" "fmt" "io" ) type message struct { ID string Type string Topics []string Data string } func newMessage(msgType string, topics []string, data string) (m *message) { return &message{ ID: uuidv7(), Type: msgType, Topics: topics, Data: data, } } func (msg *message) WriteTo(w io.Writer) (in int64, err error) { var out []byte if len(msg.ID) > 0 { out = fmt.Appendf(out, "id: %v\n", msg.ID) } if len(msg.Type) > 0 { out = fmt.Appendf(out, "type: %v\n", msg.Type) } if len(msg.Data) > 0 { out = fmt.Appendf(out, "data: %s\n", msg.Data) } if len(out) == 0 { return } n, err := w.Write(append(out, []byte("\n")...)) return int64(n), err } func (msg *message) ToJson() (out []byte) { out, _ = json.Marshal(msg) return } func (msg *message) FromJson(in []byte) { json.Unmarshal(in, msg) } func (msg *message) timestamp() uint64 { return msgIDtimestamp(msg.ID) }
package internal import ( "context" "log" "net/http" "time" "github.com/logbn/mvfifo" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" ) type metrics struct { cache *mvfifo.Cache ctx context.Context listen string server *http.Server connections_active prometheus.Gauge connections_terminated prometheus.Counter connections_total prometheus.Counter message_cache_age prometheus.Gauge message_cache_count prometheus.Gauge message_cache_size prometheus.Gauge messages_published prometheus.Counter messages_sent prometheus.Counter subscriptions_active prometheus.Gauge subscriptions_total prometheus.Counter } func NewMetrics(listen string, cache *mvfifo.Cache) *metrics { return &metrics{listen: listen, cache: cache} } func (m *metrics) Start(ctx context.Context) { if m == nil || m.listen == "" { return } log.Printf("Starting metrics on %s", m.listen) m.server = &http.Server{ Addr: m.listen, Handler: promhttp.Handler(), } m.init() go func() { if err := m.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatal(err) } }() go func() { t := time.NewTicker(time.Second) defer t.Stop() for { select { case <-t.C: cur, _ := m.cache.First() m.message_cache_age.Set(float64((cur*100 - uint64(time.Now().UnixNano())) / uint64(time.Second))) m.message_cache_count.Set(float64(m.cache.Len())) m.message_cache_size.Set(float64(m.cache.Size())) case <-ctx.Done(): m.Stop() return } } }() } func (m *metrics) Stop() { if m != nil { m.server.Shutdown(m.ctx) } } func (m *metrics) Connect() { if m != nil { m.connections_total.Inc() m.connections_active.Inc() } } func (m *metrics) Disconnect() { if m != nil { m.connections_active.Dec() } } func (m *metrics) Terminate() { if m != nil { m.connections_terminated.Inc() } } func (m *metrics) Subscribe(n int) { if m != nil { m.subscriptions_total.Add(float64(n)) m.subscriptions_active.Add(float64(n)) } } func (m *metrics) Unsubscribe(n int) { if m != nil { m.subscriptions_active.Sub(float64(n)) } } func (m *metrics) Publish() { if m != nil { m.messages_published.Inc() } } func (m *metrics) Send() { if m != nil { m.messages_sent.Inc() } } func (m *metrics) init() { m.connections_active = promauto.NewGauge(prometheus.GaugeOpts{ Name: "mercure_lite_connections_active", Help: "Number of active connections", }) m.connections_total = promauto.NewCounter(prometheus.CounterOpts{ Name: "mercure_lite_connections_total", Help: "Total number of connections created", }) m.connections_terminated = promauto.NewCounter(prometheus.CounterOpts{ Name: "mercure_lite_connections_terminated", Help: "Total number of connections terminated", }) m.message_cache_age = promauto.NewGauge(prometheus.GaugeOpts{ Name: "mercure_lite_message_cache_age", Help: "Age of oldest message in the cache", }) m.message_cache_count = promauto.NewGauge(prometheus.GaugeOpts{ Name: "mercure_lite_message_cache_count", Help: "Number of messages presently stored in the cache", }) m.message_cache_size = promauto.NewGauge(prometheus.GaugeOpts{ Name: "mercure_lite_message_cache_size", Help: "Number of bytes presently stored in the cache", }) m.messages_published = promauto.NewCounter(prometheus.CounterOpts{ Name: "mercure_lite_messages_published", Help: "Total number of messages published", }) m.messages_sent = promauto.NewCounter(prometheus.CounterOpts{ Name: "mercure_lite_messages_sent", Help: "Total number of messages sent", }) m.subscriptions_active = promauto.NewGauge(prometheus.GaugeOpts{ Name: "mercure_lite_subscriptions_active", Help: "Number of active subsriptions", }) m.subscriptions_total = promauto.NewCounter(prometheus.CounterOpts{ Name: "mercure_lite_subscriptions_total", Help: "Total number of subscriptions created", }) }
package internal import ( "context" "encoding/json" "fmt" "log" "net/http" "slices" "strings" "sync" "time" "github.com/benbjohnson/clock" "github.com/logbn/expset" "github.com/logbn/mvfifo" "github.com/yosida95/uritemplate" ) var ( subscriptionTopic = "/.well-known/mercure/subscriptions/topic/subscriber" pingPeriod = 30 * time.Second ) type server struct { cache *mvfifo.Cache cfg Config clock clock.Clock ctx context.Context ctxCancel context.CancelFunc done chan bool httpClient *http.Client hub Hub metrics *metrics mutex sync.RWMutex pubJwksRefresh time.Duration pubKeys []any pubKeysJwks []any recentTopics *expset.Set[string] server *http.Server subJwksRefresh time.Duration subKeys []any subKeysJwks []any } func NewServer(cfg Config) *server { var cache = mvfifo.NewCache(mvfifo.WithMaxSizeBytes(max(cfg.CACHE_SIZE_MB, 16) << 20)) var m *metrics if len(cfg.METRICS) > 0 { m = NewMetrics(cfg.METRICS, cache) } return &server{ cache: cache, cfg: cfg, clock: clock.New(), httpClient: &http.Client{Timeout: 5 * time.Second}, hub: newHubMulti(cfg.HUB_COUNT, m), metrics: m, recentTopics: expset.New[string](), } } func (s *server) Start(ctx context.Context) (err error) { if s.ctx != nil { return } var maxage time.Duration s.ctx, s.ctxCancel = context.WithCancel(ctx) s.pubKeys = jwtKeys(s.cfg.PUBLISHER.JWT_ALG, s.cfg.PUBLISHER.JWT_KEY) s.pubKeysJwks, maxage = jwksKeys(s.httpClient, s.cfg.PUBLISHER.JWKS_URL) if len(s.allPubKeys()) == 0 { return fmt.Errorf("No publish keys available") } s.pubJwksRefresh = time.Duration(max(int(maxage), 60)) * time.Second s.subKeys = jwtKeys(s.cfg.SUBSCRIBER.JWT_ALG, s.cfg.SUBSCRIBER.JWT_KEY) s.subKeysJwks, maxage = jwksKeys(s.httpClient, s.cfg.SUBSCRIBER.JWKS_URL) if len(s.allSubKeys()) == 0 { return fmt.Errorf("No subscriber keys available") } s.subJwksRefresh = time.Duration(max(int(maxage), 60)) * time.Second s.done = make(chan bool) s.startJwksRefresh() go s.hub.Run(s.ctx) s.server = &http.Server{ Addr: s.cfg.LISTEN, Handler: s, } go func() { log.Printf("Starting server on %s", s.cfg.LISTEN) if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("Error in ListenAndServe: %s", err) } }() go func() { select { case <-ctx.Done(): s.Stop() case <-s.done: } }() s.metrics.Start(ctx) return nil } func (s *server) Stop() { if s.ctx == nil { return } close(s.done) timeout, cancel := context.WithTimeout(s.ctx, time.Second) defer cancel() s.server.Shutdown(timeout) s.metrics.Stop() s.ctxCancel() s.ctx = nil } func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !strings.HasPrefix(r.URL.Path, "/.well-known/mercure") { w.WriteHeader(404) return } switch r.URL.Path { case "/.well-known/mercure": switch strings.ToUpper(r.Method) { case "POST": s.publish(w, r) case "OPTIONS": s.options(w, r) case "GET": s.subscribe(w, r) default: w.WriteHeader(405) } case "/.well-known/mercure/subscriptions": switch strings.ToUpper(r.Method) { case "GET": s.list(w, r) default: w.WriteHeader(405) } default: w.WriteHeader(404) } } func (s *server) publish(w http.ResponseWriter, r *http.Request) { r.ParseForm() msg := newMessage( r.Form.Get("type"), s.verifyPublish(r, r.Form["topic"]), r.Form.Get("data"), ) if len(msg.Topics) == 0 { w.WriteHeader(403) return } for _, topic := range msg.Topics { if s.recentTopics.Has(topic) { s.cache.Add(topic, msg.timestamp(), msg.ToJson()) } } s.hub.Broadcast(msg) w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Write([]byte(msg.ID)) s.metrics.Publish() } func (s *server) options(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Access-Control-Allow-Origin", s.cfg.CORS_ORIGINS) w.Header().Set("Access-Control-Allow-Headers", "Authorization, Last-Event-ID, Cache-Control") w.Header().Set("Access-Control-Allow-Credentials", "true") } func (s *server) subscribe(w http.ResponseWriter, r *http.Request) { r.ParseForm() topics := r.Form["topic"] topics, jwtExpires := s.verifySubscribe(r, topics) if len(topics) < 1 { return } if jwtExpires < 1 { jwtExpires = time.Hour * 1e6 } topics, err := s.normalize(topics) if err != nil { log.Print(err) w.WriteHeader(400) return } for _, topic := range topics { s.recentTopics.Add(topic, time.Hour) } w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "private, no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("Access-Control-Allow-Origin", s.cfg.CORS_ORIGINS) w.Header().Set("Access-Control-Allow-Headers", "Authorization") w.Header().Set("Access-Control-Allow-Credentials", "true") lastEventID := r.Header.Get("Last-Event-ID") lastEventCursor := msgIDtimestamp(lastEventID) if lastEventCursor > 0 { var msg = &message{} for _, topic := range topics { for _, data := range s.cache.IterAfter(topic, lastEventCursor) { msg.FromJson(data) msg.WriteTo(w) } } } if _, err := w.Write([]byte(":\n")); err != nil { return } conn := newConnection(topics) conn.Announce(s.hub, true) s.hub.Register(conn) defer s.hub.Unregister(conn) defer conn.Announce(s.hub, false) defer s.metrics.Disconnect() defer s.metrics.Unsubscribe(len(conn.topics)) s.metrics.Connect() s.metrics.Subscribe(len(conn.topics)) flush := w.(http.Flusher).Flush flush() ping := time.NewTicker(pingPeriod) defer ping.Stop() var last string for { select { case msg, ok := <-conn.send: if !ok { return } if msg.ID == last { break } if _, err := msg.WriteTo(w); err != nil { return } flush() last = msg.ID s.metrics.Send() case <-ping.C: if _, err := w.Write([]byte(":\n")); err != nil { return } flush() for _, topic := range topics { s.recentTopics.Add(topic, time.Hour) } case <-r.Context().Done(): return case <-s.clock.After(jwtExpires): return } } } func (s *server) normalize(topics []string) ([]string, error) { for i := range topics { t, err := uritemplate.New(topics[i]) if err != nil { return topics, fmt.Errorf("Invalid topic: %s", topics[i]) } if t.Match(subscriptionTopic) != nil { topics[i] = subscriptionTopic } } return topics, nil } func (s *server) verifySubscribe(r *http.Request, topics []string) (res []string, jwtExpires time.Duration) { claims := jwtTokenClaims(r, s.allSubKeys(), s.cfg.DEBUG) if claims == nil { return } if claims.RegisteredClaims.ExpiresAt != nil { jwtExpires = s.clock.Until(claims.RegisteredClaims.ExpiresAt.Truncate(time.Second)) } all := slices.Contains(claims.Mercure.Subscribe, "*") for _, t := range topics { if all || slices.Contains(claims.Mercure.Subscribe, t) { res = append(res, t) } } return } func (s *server) verifyPublish(r *http.Request, topics []string) (res []string) { claims := jwtTokenClaims(r, s.allPubKeys(), s.cfg.DEBUG) if claims == nil { return } all := slices.Contains(claims.Mercure.Publish, "*") for _, t := range topics { if all || slices.Contains(claims.Mercure.Publish, t) { res = append(res, t) } } return } func (s *server) list(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/ld+json") list := subscriptionList{ Context: "github.com/pantopic/mercure-lite", ID: subscriptionTopic, Type: "Subscriptions", LastEventID: uuidv7(), } for c := range s.hub.Connections() { for _, topic := range c.topics { list.Subscriptions = append(list.Subscriptions, c.toSubscription(topic, true)) } } b, _ := json.Marshal(list) w.Write(b) } func (s *server) allPubKeys() []any { s.mutex.RLock() defer s.mutex.RUnlock() return append(s.pubKeys, s.pubKeysJwks...) } func (s *server) allSubKeys() []any { s.mutex.RLock() defer s.mutex.RUnlock() return append(s.subKeys, s.subKeysJwks...) } func (s *server) startJwksRefresh() { if s.subJwksRefresh > 0 { go func() { t := s.clock.Ticker(s.subJwksRefresh) defer t.Stop() for { select { case <-t.C: keys, maxage := jwksKeys(s.httpClient, s.cfg.SUBSCRIBER.JWKS_URL) if maxage != s.subJwksRefresh && maxage > 0 { s.subJwksRefresh = maxage * time.Second t.Reset(s.subJwksRefresh) } if len(keys) < 1 { break } s.mutex.Lock() s.subKeysJwks = keys s.mutex.Unlock() case <-s.done: return } } }() } if s.pubJwksRefresh > 0 { go func() { t := s.clock.Ticker(s.pubJwksRefresh) defer t.Stop() for { select { case <-t.C: keys, maxage := jwksKeys(s.httpClient, s.cfg.PUBLISHER.JWKS_URL) if maxage != s.pubJwksRefresh && maxage > 0 { s.pubJwksRefresh = maxage * time.Second t.Reset(s.pubJwksRefresh) } if len(keys) < 1 { break } s.mutex.Lock() s.pubKeysJwks = keys s.mutex.Unlock() case <-s.done: return } } }() } }
package internal import ( "fmt" "github.com/gofrs/uuid/v5" "github.com/pantopic/mercure-lite" ) type ( Config = mercurelite.Config ConfigJWT = mercurelite.ConfigJWT ) func uuidv7() string { uuid, _ := uuid.NewV7() return fmt.Sprintf("urn:uuid:%s", uuid) } func msgIDtimestamp(id string) uint64 { if len(id) != 45 { return 0 } u, err := uuid.FromString(id[9:]) if err != nil { return 0 } t, err := uuid.TimestampFromV7(u) if err != nil { return 0 } return uint64(t) }