Skip to content
17 changes: 17 additions & 0 deletions server/cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
oapi "github.com/onkernel/kernel-images/server/lib/oapi"
"github.com/onkernel/kernel-images/server/lib/recorder"
"github.com/onkernel/kernel-images/server/lib/scaletozero"
"github.com/onkernel/kernel-images/server/lib/webrtcscreen"
)

func main() {
Expand Down Expand Up @@ -135,6 +136,22 @@ func main() {
fs.ServeHTTP(w, r)
})

// WebRTC relay: connects to Neko as a headless viewer and re-serves
// the VP8 video stream to external WebRTC clients via a single
// WebSocket signaling endpoint. The Neko connection is lazy —
// it only starts when the first client connects.
relay, err := webrtcscreen.NewRelay(ctx, webrtcscreen.RelayConfig{
NekoBaseURL: "http://127.0.0.1:8080",
NekoUser: "admin",
NekoPass: adminPassword,
Logger: slogger,
})
if err != nil {
slogger.Error("failed to create webrtc relay", "err", err)
os.Exit(1)
}
r.Get("/display/webrtc", relay.HandleWebSocket)

srv := &http.Server{
Addr: fmt.Sprintf(":%d", config.Port),
Handler: r,
Expand Down
300 changes: 300 additions & 0 deletions server/cmd/webrtc-screenshot/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
package main

import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"image/jpeg"
"log/slog"
"os"
"os/signal"
"path/filepath"
"sync"
"syscall"
"time"

cws "github.com/coder/websocket"
"github.com/onkernel/kernel-images/server/lib/vpxdecoder"
"github.com/pion/rtp/codecs"
"github.com/pion/webrtc/v3"
)

func main() {
serverURL := flag.String("server", "ws://127.0.0.1:10001/display/webrtc", "WebRTC signaling WebSocket URL")
outputPath := flag.String("output", "/tmp/screen.jpg", "Path to write JPEG screenshots")
quality := flag.Int("quality", 85, "JPEG quality (1-100)")
flag.Parse()

if *quality < 1 || *quality > 100 {
fmt.Fprintf(os.Stderr, "error: --quality must be between 1 and 100, got %d\n", *quality)
os.Exit(1)
}

logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo}))

ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()

fb := &frameBuffer{
path: *outputPath,
quality: *quality,
logger: logger,
}

for {
err := run(ctx, logger, *serverURL, fb)
if ctx.Err() != nil {
logger.Info("shutting down")
return
}
logger.Warn("connection lost, reconnecting in 2s", "error", err)
select {
case <-time.After(2 * time.Second):
case <-ctx.Done():
return
}
}
}

func run(ctx context.Context, logger *slog.Logger, serverURL string, fb *frameBuffer) error {
connectCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()

// Connect to signaling WebSocket.
ws, _, err := cws.Dial(connectCtx, serverURL, nil)
if err != nil {
return fmt.Errorf("ws dial: %w", err)
}
defer ws.Close(cws.StatusGoingAway, "done")

// Create PeerConnection.
pc, err := webrtc.NewPeerConnection(webrtc.Configuration{})
if err != nil {
return fmt.Errorf("new peer connection: %w", err)
}
defer pc.Close()

// We want to receive video only.
if _, err := pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{
Direction: webrtc.RTPTransceiverDirectionRecvonly,
}); err != nil {
return fmt.Errorf("add transceiver: %w", err)
}

trackCh := make(chan *webrtc.TrackRemote, 1)
pc.OnTrack(func(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) {
if track.Kind() == webrtc.RTPCodecTypeVideo {
select {
case trackCh <- track:
default:
}
}
})

disconnected := make(chan struct{})
pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
logger.Info("peer connection state", "state", state.String())
if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed {
select {
case <-disconnected:
default:
close(disconnected)
}
}
})

// Create offer (with all ICE candidates gathered).
offer, err := pc.CreateOffer(nil)
if err != nil {
return fmt.Errorf("create offer: %w", err)
}
gatherDone := webrtc.GatheringCompletePromise(pc)
if err := pc.SetLocalDescription(offer); err != nil {
return fmt.Errorf("set local desc: %w", err)
}
select {
case <-gatherDone:
case <-connectCtx.Done():
return connectCtx.Err()
}

// Send offer to server.
offerMsg, _ := json.Marshal(map[string]string{
"type": "offer",
"sdp": pc.LocalDescription().SDP,
})
if err := ws.Write(connectCtx, cws.MessageText, offerMsg); err != nil {
return fmt.Errorf("send offer: %w", err)
}

// Receive answer.
_, answerData, err := ws.Read(connectCtx)
if err != nil {
return fmt.Errorf("read answer: %w", err)
}
var answer struct {
Type string `json:"type"`
SDP string `json:"sdp"`
}
if err := json.Unmarshal(answerData, &answer); err != nil {
return fmt.Errorf("invalid answer: %w", err)
}
if answer.Type != "answer" {
return fmt.Errorf("unexpected message type: got %q, want \"answer\"", answer.Type)
}

if err := pc.SetRemoteDescription(webrtc.SessionDescription{
Type: webrtc.SDPTypeAnswer,
SDP: answer.SDP,
}); err != nil {
return fmt.Errorf("set remote desc: %w", err)
}

logger.Info("WebRTC connected, waiting for video track")

// Wait for video track.
var track *webrtc.TrackRemote
select {
case track = <-trackCh:
logger.Info("video track received",
"codec", track.Codec().MimeType,
"ssrc", track.SSRC(),
)
case <-time.After(10 * time.Second):
return fmt.Errorf("timeout waiting for video track")
case <-ctx.Done():
return ctx.Err()
}

// Decode loop: depacketize VP8, decode every frame, write JPEG.
return fb.decodeLoop(ctx, track, disconnected)
}

// frameBuffer holds the VP8 decoder state and handles writing JPEGs.
type frameBuffer struct {
path string
quality int
logger *slog.Logger

mu sync.Mutex
frames int64
}

func (fb *frameBuffer) decodeLoop(ctx context.Context, track *webrtc.TrackRemote, disconnected <-chan struct{}) error {
dec, err := vpxdecoder.New()
if err != nil {
return fmt.Errorf("vpx decoder init: %w", err)
}
defer dec.Close()

fb.mu.Lock()
fb.frames = 0
fb.mu.Unlock()

var (
frameBuf bytes.Buffer
frameStarted bool
)

statsStart := time.Now()

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-disconnected:
return fmt.Errorf("peer connection lost")
default:
}

pkt, _, err := track.ReadRTP()
if err != nil {
return fmt.Errorf("read rtp: %w", err)
}

// Depacketize VP8 from RTP.
vp8Pkt := &codecs.VP8Packet{}
payload, err := vp8Pkt.Unmarshal(pkt.Payload)
if err != nil {
continue
}

// S=1 + PID=0 → start of new frame.
if vp8Pkt.S == 1 && vp8Pkt.PID == 0 {
frameBuf.Reset()
frameStarted = true
}

if !frameStarted {
continue
}

frameBuf.Write(payload)

// Marker bit → last packet of frame.
if !pkt.Marker {
continue
}
frameStarted = false

if frameBuf.Len() == 0 {
continue
}

img, err := dec.Decode(frameBuf.Bytes())
if err != nil {
fb.logger.Debug("decode failed", "error", err, "size", frameBuf.Len())
continue
}

var jpegBuf bytes.Buffer
if err := jpeg.Encode(&jpegBuf, img, &jpeg.Options{Quality: fb.quality}); err != nil {
fb.logger.Warn("jpeg encode failed", "error", err)
continue
}

fb.writeToFile(jpegBuf.Bytes())

fb.mu.Lock()
fb.frames++
count := fb.frames
fb.mu.Unlock()

if count%100 == 0 {
elapsed := time.Since(statsStart)
fb.logger.Info("frame stats",
"frames", count,
"fps", fmt.Sprintf("%.1f", float64(count)/elapsed.Seconds()),
"size_kb", jpegBuf.Len()/1024,
"resolution", fmt.Sprintf("%dx%d", img.Rect.Dx(), img.Rect.Dy()),
)
}
}
}

func (fb *frameBuffer) writeToFile(data []byte) {
dir := filepath.Dir(fb.path)
tmp, err := os.CreateTemp(dir, ".screenshot-*.tmp")
if err != nil {
fb.logger.Warn("failed to create temp file", "error", err)
return
}
tmpName := tmp.Name()

if _, err := tmp.Write(data); err != nil {
tmp.Close()
os.Remove(tmpName)
return
}
if err := tmp.Close(); err != nil {
os.Remove(tmpName)
return
}
if err := os.Rename(tmpName, fb.path); err != nil {
fb.logger.Warn("rename failed", "error", err)
os.Remove(tmpName)
return
}
}
23 changes: 21 additions & 2 deletions server/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ require (
github.com/m1k1o/neko/server v0.0.0-20251008185748-46e2fc7d3866
github.com/nrednav/cuid2 v1.1.0
github.com/oapi-codegen/runtime v1.1.2
github.com/pion/rtcp v1.2.15
github.com/pion/rtp v1.8.21
github.com/pion/webrtc/v3 v3.3.6
github.com/samber/lo v1.52.0
github.com/stretchr/testify v1.11.1
github.com/testcontainers/testcontainers-go v0.40.0
golang.org/x/sync v0.17.0
golang.org/x/sync v0.19.0
golang.org/x/sys v0.39.0
golang.org/x/term v0.37.0
)
Expand Down Expand Up @@ -72,6 +75,20 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/pion/datachannel v1.5.10 // indirect
github.com/pion/dtls/v2 v2.2.12 // indirect
github.com/pion/ice/v2 v2.3.38 // indirect
github.com/pion/interceptor v0.1.40 // indirect
github.com/pion/logging v0.2.4 // indirect
github.com/pion/mdns v0.0.12 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/sctp v1.8.39 // indirect
github.com/pion/sdp/v3 v3.0.15 // indirect
github.com/pion/srtp/v2 v2.0.20 // indirect
github.com/pion/stun v0.6.1 // indirect
github.com/pion/transport/v2 v2.2.10 // indirect
github.com/pion/transport/v3 v3.0.7 // indirect
github.com/pion/turn/v2 v2.1.6 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
Expand All @@ -80,6 +97,7 @@ require (
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
Expand All @@ -90,7 +108,8 @@ require (
go.opentelemetry.io/otel/trace v1.39.0 // indirect
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
golang.org/x/crypto v0.43.0 // indirect
golang.org/x/text v0.30.0 // indirect
golang.org/x/net v0.45.0 // indirect
golang.org/x/text v0.34.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
Loading
Loading