diff --git a/server/cmd/api/main.go b/server/cmd/api/main.go index 7e47a0b6..c1cf4f54 100644 --- a/server/cmd/api/main.go +++ b/server/cmd/api/main.go @@ -28,6 +28,7 @@ import ( oapi "github.com/onkernel/kernel-images/server/lib/oapi" "github.com/onkernel/kernel-images/server/lib/recorder" "github.com/onkernel/kernel-images/server/lib/scaletozero" + "github.com/onkernel/kernel-images/server/lib/webrtcscreen" ) func main() { @@ -135,6 +136,22 @@ func main() { fs.ServeHTTP(w, r) }) + // WebRTC relay: connects to Neko as a headless viewer and re-serves + // the VP8 video stream to external WebRTC clients via a single + // WebSocket signaling endpoint. The Neko connection is lazy — + // it only starts when the first client connects. + relay, err := webrtcscreen.NewRelay(ctx, webrtcscreen.RelayConfig{ + NekoBaseURL: "http://127.0.0.1:8080", + NekoUser: "admin", + NekoPass: adminPassword, + Logger: slogger, + }) + if err != nil { + slogger.Error("failed to create webrtc relay", "err", err) + os.Exit(1) + } + r.Get("/display/webrtc", relay.HandleWebSocket) + srv := &http.Server{ Addr: fmt.Sprintf(":%d", config.Port), Handler: r, diff --git a/server/cmd/webrtc-screenshot/main.go b/server/cmd/webrtc-screenshot/main.go new file mode 100644 index 00000000..67ccf849 --- /dev/null +++ b/server/cmd/webrtc-screenshot/main.go @@ -0,0 +1,300 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "image/jpeg" + "log/slog" + "os" + "os/signal" + "path/filepath" + "sync" + "syscall" + "time" + + cws "github.com/coder/websocket" + "github.com/onkernel/kernel-images/server/lib/vpxdecoder" + "github.com/pion/rtp/codecs" + "github.com/pion/webrtc/v3" +) + +func main() { + serverURL := flag.String("server", "ws://127.0.0.1:10001/display/webrtc", "WebRTC signaling WebSocket URL") + outputPath := flag.String("output", "/tmp/screen.jpg", "Path to write JPEG screenshots") + quality := flag.Int("quality", 85, "JPEG quality (1-100)") + flag.Parse() + + if *quality < 1 || *quality > 100 { + fmt.Fprintf(os.Stderr, "error: --quality must be between 1 and 100, got %d\n", *quality) + os.Exit(1) + } + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo})) + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + fb := &frameBuffer{ + path: *outputPath, + quality: *quality, + logger: logger, + } + + for { + err := run(ctx, logger, *serverURL, fb) + if ctx.Err() != nil { + logger.Info("shutting down") + return + } + logger.Warn("connection lost, reconnecting in 2s", "error", err) + select { + case <-time.After(2 * time.Second): + case <-ctx.Done(): + return + } + } +} + +func run(ctx context.Context, logger *slog.Logger, serverURL string, fb *frameBuffer) error { + connectCtx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + // Connect to signaling WebSocket. + ws, _, err := cws.Dial(connectCtx, serverURL, nil) + if err != nil { + return fmt.Errorf("ws dial: %w", err) + } + defer ws.Close(cws.StatusGoingAway, "done") + + // Create PeerConnection. + pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + return fmt.Errorf("new peer connection: %w", err) + } + defer pc.Close() + + // We want to receive video only. + if _, err := pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionRecvonly, + }); err != nil { + return fmt.Errorf("add transceiver: %w", err) + } + + trackCh := make(chan *webrtc.TrackRemote, 1) + pc.OnTrack(func(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) { + if track.Kind() == webrtc.RTPCodecTypeVideo { + select { + case trackCh <- track: + default: + } + } + }) + + disconnected := make(chan struct{}) + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + logger.Info("peer connection state", "state", state.String()) + if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { + select { + case <-disconnected: + default: + close(disconnected) + } + } + }) + + // Create offer (with all ICE candidates gathered). + offer, err := pc.CreateOffer(nil) + if err != nil { + return fmt.Errorf("create offer: %w", err) + } + gatherDone := webrtc.GatheringCompletePromise(pc) + if err := pc.SetLocalDescription(offer); err != nil { + return fmt.Errorf("set local desc: %w", err) + } + select { + case <-gatherDone: + case <-connectCtx.Done(): + return connectCtx.Err() + } + + // Send offer to server. + offerMsg, _ := json.Marshal(map[string]string{ + "type": "offer", + "sdp": pc.LocalDescription().SDP, + }) + if err := ws.Write(connectCtx, cws.MessageText, offerMsg); err != nil { + return fmt.Errorf("send offer: %w", err) + } + + // Receive answer. + _, answerData, err := ws.Read(connectCtx) + if err != nil { + return fmt.Errorf("read answer: %w", err) + } + var answer struct { + Type string `json:"type"` + SDP string `json:"sdp"` + } + if err := json.Unmarshal(answerData, &answer); err != nil { + return fmt.Errorf("invalid answer: %w", err) + } + if answer.Type != "answer" { + return fmt.Errorf("unexpected message type: got %q, want \"answer\"", answer.Type) + } + + if err := pc.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: answer.SDP, + }); err != nil { + return fmt.Errorf("set remote desc: %w", err) + } + + logger.Info("WebRTC connected, waiting for video track") + + // Wait for video track. + var track *webrtc.TrackRemote + select { + case track = <-trackCh: + logger.Info("video track received", + "codec", track.Codec().MimeType, + "ssrc", track.SSRC(), + ) + case <-time.After(10 * time.Second): + return fmt.Errorf("timeout waiting for video track") + case <-ctx.Done(): + return ctx.Err() + } + + // Decode loop: depacketize VP8, decode every frame, write JPEG. + return fb.decodeLoop(ctx, track, disconnected) +} + +// frameBuffer holds the VP8 decoder state and handles writing JPEGs. +type frameBuffer struct { + path string + quality int + logger *slog.Logger + + mu sync.Mutex + frames int64 +} + +func (fb *frameBuffer) decodeLoop(ctx context.Context, track *webrtc.TrackRemote, disconnected <-chan struct{}) error { + dec, err := vpxdecoder.New() + if err != nil { + return fmt.Errorf("vpx decoder init: %w", err) + } + defer dec.Close() + + fb.mu.Lock() + fb.frames = 0 + fb.mu.Unlock() + + var ( + frameBuf bytes.Buffer + frameStarted bool + ) + + statsStart := time.Now() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-disconnected: + return fmt.Errorf("peer connection lost") + default: + } + + pkt, _, err := track.ReadRTP() + if err != nil { + return fmt.Errorf("read rtp: %w", err) + } + + // Depacketize VP8 from RTP. + vp8Pkt := &codecs.VP8Packet{} + payload, err := vp8Pkt.Unmarshal(pkt.Payload) + if err != nil { + continue + } + + // S=1 + PID=0 → start of new frame. + if vp8Pkt.S == 1 && vp8Pkt.PID == 0 { + frameBuf.Reset() + frameStarted = true + } + + if !frameStarted { + continue + } + + frameBuf.Write(payload) + + // Marker bit → last packet of frame. + if !pkt.Marker { + continue + } + frameStarted = false + + if frameBuf.Len() == 0 { + continue + } + + img, err := dec.Decode(frameBuf.Bytes()) + if err != nil { + fb.logger.Debug("decode failed", "error", err, "size", frameBuf.Len()) + continue + } + + var jpegBuf bytes.Buffer + if err := jpeg.Encode(&jpegBuf, img, &jpeg.Options{Quality: fb.quality}); err != nil { + fb.logger.Warn("jpeg encode failed", "error", err) + continue + } + + fb.writeToFile(jpegBuf.Bytes()) + + fb.mu.Lock() + fb.frames++ + count := fb.frames + fb.mu.Unlock() + + if count%100 == 0 { + elapsed := time.Since(statsStart) + fb.logger.Info("frame stats", + "frames", count, + "fps", fmt.Sprintf("%.1f", float64(count)/elapsed.Seconds()), + "size_kb", jpegBuf.Len()/1024, + "resolution", fmt.Sprintf("%dx%d", img.Rect.Dx(), img.Rect.Dy()), + ) + } + } +} + +func (fb *frameBuffer) writeToFile(data []byte) { + dir := filepath.Dir(fb.path) + tmp, err := os.CreateTemp(dir, ".screenshot-*.tmp") + if err != nil { + fb.logger.Warn("failed to create temp file", "error", err) + return + } + tmpName := tmp.Name() + + if _, err := tmp.Write(data); err != nil { + tmp.Close() + os.Remove(tmpName) + return + } + if err := tmp.Close(); err != nil { + os.Remove(tmpName) + return + } + if err := os.Rename(tmpName, fb.path); err != nil { + fb.logger.Warn("rename failed", "error", err) + os.Remove(tmpName) + return + } +} diff --git a/server/go.mod b/server/go.mod index ea5629b9..6cc04b53 100644 --- a/server/go.mod +++ b/server/go.mod @@ -19,10 +19,13 @@ require ( github.com/m1k1o/neko/server v0.0.0-20251008185748-46e2fc7d3866 github.com/nrednav/cuid2 v1.1.0 github.com/oapi-codegen/runtime v1.1.2 + github.com/pion/rtcp v1.2.15 + github.com/pion/rtp v1.8.21 + github.com/pion/webrtc/v3 v3.3.6 github.com/samber/lo v1.52.0 github.com/stretchr/testify v1.11.1 github.com/testcontainers/testcontainers-go v0.40.0 - golang.org/x/sync v0.17.0 + golang.org/x/sync v0.19.0 golang.org/x/sys v0.39.0 golang.org/x/term v0.37.0 ) @@ -72,6 +75,20 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect + github.com/pion/datachannel v1.5.10 // indirect + github.com/pion/dtls/v2 v2.2.12 // indirect + github.com/pion/ice/v2 v2.3.38 // indirect + github.com/pion/interceptor v0.1.40 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/mdns v0.0.12 // indirect + github.com/pion/randutil v0.1.0 // indirect + github.com/pion/sctp v1.8.39 // indirect + github.com/pion/sdp/v3 v3.0.15 // indirect + github.com/pion/srtp/v2 v2.0.20 // indirect + github.com/pion/stun v0.6.1 // indirect + github.com/pion/transport/v2 v2.2.10 // indirect + github.com/pion/transport/v3 v3.0.7 // indirect + github.com/pion/turn/v2 v2.1.6 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect @@ -80,6 +97,7 @@ require ( github.com/sirupsen/logrus v1.9.3 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect + github.com/wlynxg/anet v0.0.5 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect @@ -90,7 +108,8 @@ require ( go.opentelemetry.io/otel/trace v1.39.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect golang.org/x/crypto v0.43.0 // indirect - golang.org/x/text v0.30.0 // indirect + golang.org/x/net v0.45.0 // indirect + golang.org/x/text v0.34.0 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/server/go.sum b/server/go.sum index ac2dd499..7a16705c 100644 --- a/server/go.sum +++ b/server/go.sum @@ -78,6 +78,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= +github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= @@ -93,8 +94,11 @@ github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dv github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw= github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= @@ -141,6 +145,49 @@ github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJw github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= +github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o= +github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oLo8Rs4Py/M= +github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= +github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk= +github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= +github.com/pion/ice/v2 v2.3.38 h1:DEpt13igPfvkE2+1Q+6e8mP30dtWnQD3CtMIKoRDRmA= +github.com/pion/ice/v2 v2.3.38/go.mod h1:mBF7lnigdqgtB+YHkaY/Y6s6tsyRyo4u4rPGRuOjUBQ= +github.com/pion/interceptor v0.1.40 h1:e0BjnPcGpr2CFQgKhrQisBU7V3GXK6wrfYrGYaU6Jq4= +github.com/pion/interceptor v0.1.40/go.mod h1:Z6kqH7M/FYirg3frjGJ21VLSRJGBXB/KqaTIrdqnOic= +github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/mdns v0.0.12 h1:CiMYlY+O0azojWDmxdNr7ADGrnZ+V6Ilfner+6mSVK8= +github.com/pion/mdns v0.0.12/go.mod h1:VExJjv8to/6Wqm1FXK+Ii/Z9tsVk/F5sD/N70cnYFbk= +github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= +github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= +github.com/pion/rtcp v1.2.12/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9L4= +github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo= +github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0= +github.com/pion/rtp v1.8.3/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= +github.com/pion/rtp v1.8.21 h1:3yrOwmZFyUpcIosNcWRpQaU+UXIJ6yxLuJ8Bx0mw37Y= +github.com/pion/rtp v1.8.21/go.mod h1:bAu2UFKScgzyFqvUKmbvzSdPr+NGbZtv6UB2hesqXBk= +github.com/pion/sctp v1.8.39 h1:PJma40vRHa3UTO3C4MyeJDQ+KIobVYRZQZ0Nt7SjQnE= +github.com/pion/sctp v1.8.39/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE= +github.com/pion/sdp/v3 v3.0.15 h1:F0I1zds+K/+37ZrzdADmx2Q44OFDOPRLhPnNTaUX9hk= +github.com/pion/sdp/v3 v3.0.15/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E= +github.com/pion/srtp/v2 v2.0.20 h1:HNNny4s+OUmG280ETrCdgFndp4ufx3/uy85EawYEhTk= +github.com/pion/srtp/v2 v2.0.20/go.mod h1:0KJQjA99A6/a0DOVTu1PhDSw0CXF2jTkqOoMg3ODqdA= +github.com/pion/stun v0.6.1 h1:8lp6YejULeHBF8NmV8e2787BogQhduZugh5PdhDyyN4= +github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/8= +github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= +github.com/pion/transport/v2 v2.2.3/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= +github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= +github.com/pion/transport/v2 v2.2.10 h1:ucLBLE8nuxiHfvkFKnkDQRYWYfp8ejf4YBOPfaQpw6Q= +github.com/pion/transport/v2 v2.2.10/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E= +github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= +github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= +github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= +github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= +github.com/pion/turn/v2 v2.1.6 h1:Xr2niVsiPTB0FPtt+yAWKFUkU1eotQbGgpTIld4x1Gc= +github.com/pion/turn/v2 v2.1.6/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= +github.com/pion/webrtc/v3 v3.3.6 h1:7XAh4RPtlY1Vul6/GmZrv7z+NnxKA6If0KStXBI2ZLE= +github.com/pion/webrtc/v3 v3.3.6/go.mod h1:zyN7th4mZpV27eXybfR/cnUf3J2DRy8zw/mdjD9JTNM= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -161,10 +208,16 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU= @@ -175,6 +228,10 @@ github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= +github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= +github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= @@ -195,27 +252,73 @@ go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6 go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM= golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= +golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= -golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= -golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY= google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE= @@ -226,6 +329,7 @@ google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/server/lib/vpxdecoder/decoder.go b/server/lib/vpxdecoder/decoder.go new file mode 100644 index 00000000..48b72ccc --- /dev/null +++ b/server/lib/vpxdecoder/decoder.go @@ -0,0 +1,100 @@ +package vpxdecoder + +/* +#cgo pkg-config: vpx +#include +#include +#include +#include + +// vpx_codec_dec_init is a macro; wrap it for CGo. +static vpx_codec_err_t init_vp8_decoder(vpx_codec_ctx_t *ctx) { + vpx_codec_dec_cfg_t cfg = {0}; + return vpx_codec_dec_init(ctx, vpx_codec_vp8_dx(), &cfg, 0); +} +*/ +import "C" + +import ( + "fmt" + "image" + "unsafe" +) + +// Decoder is a VP8 video decoder backed by libvpx via CGo. +// It decodes both keyframes AND inter-frames, maintaining full +// reference frame state, so every frame produces a decoded image. +type Decoder struct { + codec C.vpx_codec_ctx_t +} + +func New() (*Decoder, error) { + d := &Decoder{} + status := C.init_vp8_decoder(&d.codec) + if status != C.VPX_CODEC_OK { + return nil, fmt.Errorf("vpx_codec_dec_init failed: status=%d", int(status)) + } + return d, nil +} + +// Decode decodes a single VP8 frame (keyframe or inter-frame) and returns +// the decoded image as YCbCr 4:2:0. The returned image's data is valid +// until the next Decode call (libvpx reuses internal buffers). +func (d *Decoder) Decode(data []byte) (*image.YCbCr, error) { + if len(data) == 0 { + return nil, fmt.Errorf("empty frame data") + } + + status := C.vpx_codec_decode( + &d.codec, + (*C.uint8_t)(unsafe.Pointer(&data[0])), + C.uint(len(data)), + nil, 0, + ) + if status != C.VPX_CODEC_OK { + return nil, fmt.Errorf("vpx_codec_decode failed: status=%d", int(status)) + } + + var iter C.vpx_codec_iter_t + img := C.vpx_codec_get_frame(&d.codec, &iter) + if img == nil { + return nil, fmt.Errorf("no frame available after decode") + } + + w := int(img.d_w) + h := int(img.d_h) + + // libvpx outputs I420 (YUV 4:2:0 planar). + yStride := int(img.stride[0]) + uStride := int(img.stride[1]) + vStride := int(img.stride[2]) + + // Copy plane data to Go-managed memory so it's safe to hold + // after the next Decode call. + yLen := yStride * h + uLen := uStride * ((h + 1) / 2) + vLen := vStride * ((h + 1) / 2) + + yData := C.GoBytes(unsafe.Pointer(img.planes[0]), C.int(yLen)) + uData := C.GoBytes(unsafe.Pointer(img.planes[1]), C.int(uLen)) + vData := C.GoBytes(unsafe.Pointer(img.planes[2]), C.int(vLen)) + + cStride := uStride + if vStride > uStride { + cStride = vStride + } + + return &image.YCbCr{ + Y: yData, + Cb: uData, + Cr: vData, + YStride: yStride, + CStride: cStride, + SubsampleRatio: image.YCbCrSubsampleRatio420, + Rect: image.Rect(0, 0, w, h), + }, nil +} + +func (d *Decoder) Close() { + C.vpx_codec_destroy(&d.codec) +} diff --git a/server/lib/webrtcscreen/relay.go b/server/lib/webrtcscreen/relay.go new file mode 100644 index 00000000..a2e045f8 --- /dev/null +++ b/server/lib/webrtcscreen/relay.go @@ -0,0 +1,585 @@ +package webrtcscreen + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "sync" + "time" + + cws "github.com/coder/websocket" + "github.com/pion/rtcp" + "github.com/pion/webrtc/v3" +) + +// Relay is a WebRTC SFU that connects to a local Neko instance, +// receives its VP8 video stream, and re-serves it to external +// WebRTC clients. The API server mounts HandleWebSocket on a +// single endpoint (e.g., /display/webrtc) — that's all external +// clients need to connect and receive the live screen. +// +// The Neko connection is lazy: it is only established when the +// first client connects via HandleWebSocket. +type Relay struct { + logger *slog.Logger + cfg RelayConfig + ctx context.Context + + mu sync.RWMutex + localTrack *webrtc.TrackLocalStaticRTP + nekoPC *webrtc.PeerConnection + nekoWS *cws.Conn + ready chan struct{} // closed when localTrack is receiving data + + startOnce sync.Once +} + +type RelayConfig struct { + NekoBaseURL string + NekoUser string + NekoPass string + Logger *slog.Logger +} + +func NewRelay(ctx context.Context, cfg RelayConfig) (*Relay, error) { + if cfg.Logger == nil { + cfg.Logger = slog.Default() + } + + localTrack, err := webrtc.NewTrackLocalStaticRTP( + webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeVP8, ClockRate: 90000}, + "video", "screen", + ) + if err != nil { + return nil, fmt.Errorf("creating local track: %w", err) + } + + return &Relay{ + logger: cfg.Logger.With("component", "webrtc-relay"), + cfg: cfg, + ctx: ctx, + localTrack: localTrack, + ready: make(chan struct{}), + }, nil +} + +// ensureRunning starts the Neko connection loop in the background +// on the first call. Subsequent calls are no-ops. +func (r *Relay) ensureRunning() { + r.startOnce.Do(func() { + r.logger.Info("first client request, starting neko connection") + go func() { + defer r.Close() + for { + err := r.Start(r.ctx) + if r.ctx.Err() != nil { + return + } + r.logger.Warn("webrtc relay disconnected, reconnecting in 3s", "err", err) + select { + case <-r.ctx.Done(): + return + case <-time.After(3 * time.Second): + } + } + }() + }) +} + +// Start connects to Neko and begins relaying video. It blocks until +// the Neko connection drops or ctx is cancelled. Callers should call +// Start in a loop for automatic reconnection. +func (r *Relay) Start(ctx context.Context) error { + r.mu.Lock() + select { + case <-r.ready: + // Previous run closed this channel; create a fresh one for this run. + r.ready = make(chan struct{}) + default: + // Still open (first call or never closed), keep it. + } + r.mu.Unlock() + + token, err := r.nekoLogin(ctx) + if err != nil { + return fmt.Errorf("neko login: %w", err) + } + + wsURL := r.cfg.NekoBaseURL + "/api/ws" + ws, _, err := cws.Dial(ctx, wsURL, &cws.DialOptions{ + HTTPHeader: http.Header{ + "Authorization": []string{"Bearer " + token}, + }, + }) + if err != nil { + return fmt.Errorf("neko ws dial: %w", err) + } + ws.SetReadLimit(1 << 20) + + r.mu.Lock() + r.nekoWS = ws + r.mu.Unlock() + + defer func() { + ws.Close(cws.StatusGoingAway, "done") + r.mu.Lock() + r.nekoWS = nil + r.mu.Unlock() + }() + + initPayload, err := r.waitForEvent(ctx, ws, "system/init") + if err != nil { + return fmt.Errorf("waiting for system/init: %w", err) + } + + var initData struct { + HeartbeatInterval float64 `json:"heartbeat_interval"` + } + if initPayload != nil { + _ = json.Unmarshal(initPayload, &initData) + } + if initData.HeartbeatInterval > 0 { + go func() { + ticker := time.NewTicker(time.Duration(initData.HeartbeatInterval * float64(time.Second))) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := sendWSMsg(ctx, ws, "client/heartbeat", nil); err != nil { + return + } + } + } + }() + } + + pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + return fmt.Errorf("creating neko peer connection: %w", err) + } + r.mu.Lock() + r.nekoPC = pc + r.mu.Unlock() + defer func() { + pc.Close() + r.mu.Lock() + r.nekoPC = nil + r.mu.Unlock() + }() + + trackReceived := make(chan struct{}, 1) + var forwardOnce sync.Once + + pc.OnTrack(func(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) { + if track.Kind() != webrtc.RTPCodecTypeVideo { + return + } + r.logger.Info("neko video track received", + "codec", track.Codec().MimeType, + "ssrc", track.SSRC(), + ) + + select { + case trackReceived <- struct{}{}: + default: + } + + forwardOnce.Do(func() { + r.mu.Lock() + select { + case <-r.ready: + default: + close(r.ready) + } + r.mu.Unlock() + + go r.forwardRTP(track) + }) + }) + + disconnected := make(chan struct{}) + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + r.logger.Info("neko peer connection state", "state", state.String()) + if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { + select { + case <-disconnected: + default: + close(disconnected) + } + } + }) + + // Send signal/request to Neko (audio disabled). + reqPayload := json.RawMessage(`{"video":{},"audio":{"disabled":true}}`) + if err := sendWSMsg(ctx, ws, "signal/request", reqPayload); err != nil { + return fmt.Errorf("sending signal/request: %w", err) + } + + offerSDP, err := r.waitForOffer(ctx, ws) + if err != nil { + return fmt.Errorf("waiting for neko SDP offer: %w", err) + } + + if err := pc.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: offerSDP, + }); err != nil { + return fmt.Errorf("setting neko remote desc: %w", err) + } + + answer, err := pc.CreateAnswer(nil) + if err != nil { + return fmt.Errorf("creating answer: %w", err) + } + + gatherDone := webrtc.GatheringCompletePromise(pc) + if err := pc.SetLocalDescription(answer); err != nil { + return fmt.Errorf("setting local desc: %w", err) + } + + select { + case <-gatherDone: + case <-ctx.Done(): + return ctx.Err() + } + + answerJSON, _ := json.Marshal(struct { + SDP string `json:"sdp"` + }{SDP: pc.LocalDescription().SDP}) + if err := sendWSMsg(ctx, ws, "signal/answer", answerJSON); err != nil { + return fmt.Errorf("sending signal/answer: %w", err) + } + + r.logger.Info("neko signaling complete, waiting for video track") + + // Background reader for Neko WS (heartbeats, ICE candidates, etc.) + go r.nekoWSLoop(ctx, ws, pc) + + select { + case <-trackReceived: + r.logger.Info("neko video track active, relay ready") + case <-time.After(10 * time.Second): + return fmt.Errorf("timeout waiting for neko video track") + case <-ctx.Done(): + return ctx.Err() + } + + // Block until disconnection or cancellation. + select { + case <-disconnected: + return fmt.Errorf("neko peer connection lost") + case <-ctx.Done(): + return ctx.Err() + } +} + +// HandleWebSocket is the HTTP handler for external WebRTC client signaling. +// Mount as: r.Get("/display/webrtc", relay.HandleWebSocket) +// +// Protocol (two messages total, no trickle ICE): +// +// Client → Server: {"type":"offer","sdp":"v=0\r\n..."} +// Server → Client: {"type":"answer","sdp":"v=0\r\n..."} +// +// After the exchange, WebRTC media flows directly. The WebSocket +// can be closed. +func (r *Relay) HandleWebSocket(w http.ResponseWriter, req *http.Request) { + r.ensureRunning() + + // Wait for the relay to be connected to Neko before accepting the client. + select { + case <-r.Ready(): + case <-time.After(15 * time.Second): + http.Error(w, "relay not ready", http.StatusServiceUnavailable) + return + case <-req.Context().Done(): + return + } + + ws, err := cws.Accept(w, req, nil) + if err != nil { + r.logger.Error("websocket accept failed", "error", err) + return + } + defer ws.Close(cws.StatusNormalClosure, "") + + ctx := req.Context() + + // Read client's SDP offer. + _, data, err := ws.Read(ctx) + if err != nil { + r.logger.Warn("failed to read client offer", "error", err) + return + } + + var offer struct { + Type string `json:"type"` + SDP string `json:"sdp"` + } + if err := json.Unmarshal(data, &offer); err != nil { + r.logger.Warn("invalid client offer", "error", err) + ws.Close(cws.StatusInvalidFramePayloadData, "expected offer") + return + } + if offer.Type != "offer" { + r.logger.Warn("unexpected message type from client", "type", offer.Type) + ws.Close(cws.StatusInvalidFramePayloadData, "expected offer") + return + } + + // Create PeerConnection for this client with the relayed track. + pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + r.logger.Error("failed to create client peer connection", "error", err) + ws.Close(cws.StatusInternalError, "peer connection failed") + return + } + + rtpSender, err := pc.AddTrack(r.localTrack) + if err != nil { + pc.Close() + r.logger.Error("failed to add track to client peer connection", "error", err) + ws.Close(cws.StatusInternalError, "add track failed") + return + } + + // Read and discard RTCP from the client (required by pion). + go func() { + buf := make([]byte, 1500) + for { + if _, _, err := rtpSender.Read(buf); err != nil { + return + } + } + }() + + // Register state callback before signaling begins so we never miss + // a terminal transition. Disconnected is transient and can self-recover, + // so only Failed/Closed are treated as terminal. + done := make(chan struct{}) + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + if state == webrtc.PeerConnectionStateFailed || + state == webrtc.PeerConnectionStateClosed { + select { + case <-done: + default: + close(done) + } + } + }) + + if err := pc.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: offer.SDP, + }); err != nil { + pc.Close() + r.logger.Error("failed to set client remote description", "error", err) + ws.Close(cws.StatusInternalError, "sdp failed") + return + } + + answer, err := pc.CreateAnswer(nil) + if err != nil { + pc.Close() + r.logger.Error("failed to create answer for client", "error", err) + ws.Close(cws.StatusInternalError, "answer failed") + return + } + + gatherDone := webrtc.GatheringCompletePromise(pc) + if err := pc.SetLocalDescription(answer); err != nil { + pc.Close() + r.logger.Error("failed to set local description for client", "error", err) + ws.Close(cws.StatusInternalError, "local desc failed") + return + } + + select { + case <-gatherDone: + case <-ctx.Done(): + pc.Close() + return + } + + answerMsg, _ := json.Marshal(map[string]string{ + "type": "answer", + "sdp": pc.LocalDescription().SDP, + }) + if err := ws.Write(ctx, cws.MessageText, answerMsg); err != nil { + pc.Close() + r.logger.Error("failed to send answer to client", "error", err) + return + } + + r.logger.Info("external client connected via WebRTC") + + // Request a keyframe from Neko so the new client gets one immediately. + r.requestKeyframe() + + // Keep the WebSocket open until the PeerConnection closes, so the + // client can detect disconnection cleanly. + select { + case <-done: + case <-ctx.Done(): + } + pc.Close() +} + +// Close tears down the relay. +func (r *Relay) Close() { + r.mu.Lock() + defer r.mu.Unlock() + + if r.nekoPC != nil { + r.nekoPC.Close() + r.nekoPC = nil + } + if r.nekoWS != nil { + r.nekoWS.Close(cws.StatusGoingAway, "shutdown") + r.nekoWS = nil + } +} + +// Ready returns a channel that is closed once the relay has received +// the video track from Neko and is ready to serve clients. +func (r *Relay) Ready() <-chan struct{} { + r.mu.RLock() + defer r.mu.RUnlock() + return r.ready +} + +func (r *Relay) forwardRTP(track *webrtc.TrackRemote) { + for { + pkt, _, err := track.ReadRTP() + if err != nil { + r.logger.Info("neko track read ended", "error", err) + return + } + if err := r.localTrack.WriteRTP(pkt); err != nil { + r.logger.Debug("local track write failed", "error", err) + } + } +} + +func (r *Relay) requestKeyframe() { + r.mu.RLock() + pc := r.nekoPC + r.mu.RUnlock() + if pc == nil { + return + } + + for _, receiver := range pc.GetReceivers() { + t := receiver.Track() + if t != nil && t.Kind() == webrtc.RTPCodecTypeVideo { + _ = pc.WriteRTCP([]rtcp.Packet{ + &rtcp.PictureLossIndication{ + MediaSSRC: uint32(t.SSRC()), + }, + }) + return + } + } +} + +// nekoLogin calls Neko's HTTP login API and returns the bearer token. +func (r *Relay) nekoLogin(ctx context.Context) (string, error) { + payload, _ := json.Marshal(map[string]string{ + "username": r.cfg.NekoUser, + "password": r.cfg.NekoPass, + }) + req, err := http.NewRequestWithContext(ctx, "POST", r.cfg.NekoBaseURL+"/api/login", bytes.NewReader(payload)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("login returned %d", resp.StatusCode) + } + + var result struct { + Token string `json:"token"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + if result.Token == "" { + return "", fmt.Errorf("empty token") + } + return result.Token, nil +} + +// Neko WS message envelope. +type nekoMsg struct { + Event string `json:"event"` + Payload json.RawMessage `json:"payload,omitempty"` +} + +func sendWSMsg(ctx context.Context, ws *cws.Conn, event string, payload json.RawMessage) error { + data, _ := json.Marshal(nekoMsg{Event: event, Payload: payload}) + return ws.Write(ctx, cws.MessageText, data) +} + +func (r *Relay) waitForEvent(ctx context.Context, ws *cws.Conn, event string) (json.RawMessage, error) { + for { + _, data, err := ws.Read(ctx) + if err != nil { + return nil, err + } + var msg nekoMsg + if json.Unmarshal(data, &msg) == nil && msg.Event == event { + return msg.Payload, nil + } + } +} + +func (r *Relay) waitForOffer(ctx context.Context, ws *cws.Conn) (string, error) { + for { + _, data, err := ws.Read(ctx) + if err != nil { + return "", err + } + var msg nekoMsg + if json.Unmarshal(data, &msg) == nil && msg.Event == "signal/provide" { + var provide struct { + SDP string `json:"sdp"` + } + if err := json.Unmarshal(msg.Payload, &provide); err != nil { + return "", fmt.Errorf("parsing signal/provide: %w", err) + } + return provide.SDP, nil + } + } +} + +func (r *Relay) nekoWSLoop(ctx context.Context, ws *cws.Conn, pc *webrtc.PeerConnection) { + for { + _, data, err := ws.Read(ctx) + if err != nil { + return + } + var msg nekoMsg + if json.Unmarshal(data, &msg) != nil { + continue + } + switch msg.Event { + case "signal/candidate": + var candidate webrtc.ICECandidateInit + if json.Unmarshal(msg.Payload, &candidate) == nil { + _ = pc.AddICECandidate(candidate) + } + } + } +}