From 9999034d3f04f217e602607e678f7af8a7554a4b Mon Sep 17 00:00:00 2001 From: Kayne Tu Date: Tue, 16 Jun 2026 14:27:00 -0700 Subject: [PATCH 1/4] feat(slinky-slurm): Add slinky slurm conformance validator and cuj demo doc --- demos/README.md | 1 + demos/cuj1-slinky-slurm.md | 279 ++++ docs/contributor/validator.md | 22 + docs/user/component-catalog.md | 2 +- go.mod | 3 + go.sum | 8 + recipes/checks/slinky-slurm/health-check.yaml | 2 +- .../evidence/gb200-eks-ubuntu-training.yaml | 14 + recipes/evidence/h100-gke-cos-training.yaml | 14 + .../h100-eks-ubuntu-training-slurm.yaml | 3 + .../overlays/h100-gke-cos-training-slurm.yaml | 3 + .../overlays/h100-kind-training-slurm.yaml | 5 + recipes/validators/catalog.yaml | 7 + validators/conformance/main.go | 1 + validators/conformance/pod_exec.go | 100 ++ validators/conformance/pod_exec_test.go | 216 +++ .../conformance/slinky_slurm_health_check.go | 375 +++++ .../slinky_slurm_health_check_test.go | 465 ++++++ .../github.com/gorilla/websocket/.gitignore | 25 + vendor/github.com/gorilla/websocket/AUTHORS | 9 + vendor/github.com/gorilla/websocket/LICENSE | 22 + vendor/github.com/gorilla/websocket/README.md | 32 + vendor/github.com/gorilla/websocket/client.go | 517 +++++++ .../gorilla/websocket/compression.go | 152 ++ vendor/github.com/gorilla/websocket/conn.go | 1246 +++++++++++++++++ vendor/github.com/gorilla/websocket/doc.go | 227 +++ vendor/github.com/gorilla/websocket/join.go | 42 + vendor/github.com/gorilla/websocket/json.go | 60 + vendor/github.com/gorilla/websocket/mask.go | 55 + .../github.com/gorilla/websocket/mask_safe.go | 16 + .../github.com/gorilla/websocket/prepared.go | 102 ++ vendor/github.com/gorilla/websocket/proxy.go | 104 ++ vendor/github.com/gorilla/websocket/server.go | 373 +++++ vendor/github.com/gorilla/websocket/util.go | 298 ++++ .../moby/spdystream/CONTRIBUTING.md | 13 + vendor/github.com/moby/spdystream/LICENSE | 202 +++ vendor/github.com/moby/spdystream/MAINTAINERS | 40 + vendor/github.com/moby/spdystream/NOTICE | 17 + vendor/github.com/moby/spdystream/README.md | 77 + .../github.com/moby/spdystream/connection.go | 1000 +++++++++++++ vendor/github.com/moby/spdystream/handlers.go | 52 + vendor/github.com/moby/spdystream/priority.go | 114 ++ .../github.com/moby/spdystream/spdy/LICENSE | 27 + .../github.com/moby/spdystream/spdy/PATENTS | 22 + .../moby/spdystream/spdy/dictionary.go | 187 +++ .../moby/spdystream/spdy/options.go | 25 + .../github.com/moby/spdystream/spdy/read.go | 382 +++++ .../github.com/moby/spdystream/spdy/types.go | 308 ++++ .../github.com/moby/spdystream/spdy/write.go | 355 +++++ vendor/github.com/moby/spdystream/stream.go | 345 +++++ vendor/github.com/moby/spdystream/utils.go | 32 + .../golang.org/x/net/internal/socks/client.go | 168 +++ .../golang.org/x/net/internal/socks/socks.go | 317 +++++ vendor/golang.org/x/net/proxy/dial.go | 54 + vendor/golang.org/x/net/proxy/direct.go | 31 + vendor/golang.org/x/net/proxy/per_host.go | 153 ++ vendor/golang.org/x/net/proxy/proxy.go | 149 ++ vendor/golang.org/x/net/proxy/socks5.go | 42 + vendor/golang.org/x/net/websocket/client.go | 139 ++ vendor/golang.org/x/net/websocket/dial.go | 29 + vendor/golang.org/x/net/websocket/hybi.go | 583 ++++++++ vendor/golang.org/x/net/websocket/server.go | 113 ++ .../golang.org/x/net/websocket/websocket.go | 449 ++++++ .../apimachinery/pkg/util/httpstream/doc.go | 20 + .../pkg/util/httpstream/httpstream.go | 201 +++ .../pkg/util/httpstream/spdy/doc.go | 20 + .../pkg/util/httpstream/spdy/spdy.go | 236 ++++ .../pkg/util/remotecommand/constants.go | 67 + .../client-go/tools/remotecommand/OWNERS | 10 + .../client-go/tools/remotecommand/doc.go | 20 + .../tools/remotecommand/errorstream.go | 55 + .../client-go/tools/remotecommand/fallback.go | 60 + .../client-go/tools/remotecommand/reader.go | 41 + .../tools/remotecommand/remotecommand.go | 59 + .../client-go/tools/remotecommand/resize.go | 34 + .../client-go/tools/remotecommand/spdy.go | 176 +++ .../client-go/tools/remotecommand/v1.go | 164 +++ .../client-go/tools/remotecommand/v2.go | 205 +++ .../client-go/tools/remotecommand/v3.go | 117 ++ .../client-go/tools/remotecommand/v4.go | 125 ++ .../client-go/tools/remotecommand/v5.go | 37 + .../tools/remotecommand/websocket.go | 537 +++++++ .../k8s.io/client-go/transport/spdy/spdy.go | 317 +++++ .../transport/websocket/roundtripper.go | 224 +++ vendor/k8s.io/client-go/util/exec/exec.go | 52 + vendor/k8s.io/streaming/LICENSE | 202 +++ vendor/k8s.io/streaming/pkg/httpstream/doc.go | 19 + .../streaming/pkg/httpstream/httpstream.go | 201 +++ .../pkg/httpstream/spdy/connection.go | 206 +++ .../pkg/httpstream/spdy/roundtripper.go | 572 ++++++++ .../streaming/pkg/httpstream/spdy/upgrade.go | 120 ++ .../streaming/pkg/httpstream/wsstream/conn.go | 466 ++++++ .../streaming/pkg/httpstream/wsstream/doc.go | 69 + .../pkg/httpstream/wsstream/stream.go | 193 +++ .../k8s.io/streaming/pkg/runtime/runtime.go | 62 + vendor/modules.txt | 23 + 96 files changed, 15133 insertions(+), 2 deletions(-) create mode 100644 demos/cuj1-slinky-slurm.md create mode 100644 validators/conformance/pod_exec.go create mode 100644 validators/conformance/pod_exec_test.go create mode 100644 validators/conformance/slinky_slurm_health_check.go create mode 100644 validators/conformance/slinky_slurm_health_check_test.go create mode 100644 vendor/github.com/gorilla/websocket/.gitignore create mode 100644 vendor/github.com/gorilla/websocket/AUTHORS create mode 100644 vendor/github.com/gorilla/websocket/LICENSE create mode 100644 vendor/github.com/gorilla/websocket/README.md create mode 100644 vendor/github.com/gorilla/websocket/client.go create mode 100644 vendor/github.com/gorilla/websocket/compression.go create mode 100644 vendor/github.com/gorilla/websocket/conn.go create mode 100644 vendor/github.com/gorilla/websocket/doc.go create mode 100644 vendor/github.com/gorilla/websocket/join.go create mode 100644 vendor/github.com/gorilla/websocket/json.go create mode 100644 vendor/github.com/gorilla/websocket/mask.go create mode 100644 vendor/github.com/gorilla/websocket/mask_safe.go create mode 100644 vendor/github.com/gorilla/websocket/prepared.go create mode 100644 vendor/github.com/gorilla/websocket/proxy.go create mode 100644 vendor/github.com/gorilla/websocket/server.go create mode 100644 vendor/github.com/gorilla/websocket/util.go create mode 100644 vendor/github.com/moby/spdystream/CONTRIBUTING.md create mode 100644 vendor/github.com/moby/spdystream/LICENSE create mode 100644 vendor/github.com/moby/spdystream/MAINTAINERS create mode 100644 vendor/github.com/moby/spdystream/NOTICE create mode 100644 vendor/github.com/moby/spdystream/README.md create mode 100644 vendor/github.com/moby/spdystream/connection.go create mode 100644 vendor/github.com/moby/spdystream/handlers.go create mode 100644 vendor/github.com/moby/spdystream/priority.go create mode 100644 vendor/github.com/moby/spdystream/spdy/LICENSE create mode 100644 vendor/github.com/moby/spdystream/spdy/PATENTS create mode 100644 vendor/github.com/moby/spdystream/spdy/dictionary.go create mode 100644 vendor/github.com/moby/spdystream/spdy/options.go create mode 100644 vendor/github.com/moby/spdystream/spdy/read.go create mode 100644 vendor/github.com/moby/spdystream/spdy/types.go create mode 100644 vendor/github.com/moby/spdystream/spdy/write.go create mode 100644 vendor/github.com/moby/spdystream/stream.go create mode 100644 vendor/github.com/moby/spdystream/utils.go create mode 100644 vendor/golang.org/x/net/internal/socks/client.go create mode 100644 vendor/golang.org/x/net/internal/socks/socks.go create mode 100644 vendor/golang.org/x/net/proxy/dial.go create mode 100644 vendor/golang.org/x/net/proxy/direct.go create mode 100644 vendor/golang.org/x/net/proxy/per_host.go create mode 100644 vendor/golang.org/x/net/proxy/proxy.go create mode 100644 vendor/golang.org/x/net/proxy/socks5.go create mode 100644 vendor/golang.org/x/net/websocket/client.go create mode 100644 vendor/golang.org/x/net/websocket/dial.go create mode 100644 vendor/golang.org/x/net/websocket/hybi.go create mode 100644 vendor/golang.org/x/net/websocket/server.go create mode 100644 vendor/golang.org/x/net/websocket/websocket.go create mode 100644 vendor/k8s.io/apimachinery/pkg/util/httpstream/doc.go create mode 100644 vendor/k8s.io/apimachinery/pkg/util/httpstream/httpstream.go create mode 100644 vendor/k8s.io/apimachinery/pkg/util/httpstream/spdy/doc.go create mode 100644 vendor/k8s.io/apimachinery/pkg/util/httpstream/spdy/spdy.go create mode 100644 vendor/k8s.io/apimachinery/pkg/util/remotecommand/constants.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/OWNERS create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/doc.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/errorstream.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/fallback.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/reader.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/remotecommand.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/resize.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/spdy.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/v1.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/v2.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/v3.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/v4.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/v5.go create mode 100644 vendor/k8s.io/client-go/tools/remotecommand/websocket.go create mode 100644 vendor/k8s.io/client-go/transport/spdy/spdy.go create mode 100644 vendor/k8s.io/client-go/transport/websocket/roundtripper.go create mode 100644 vendor/k8s.io/client-go/util/exec/exec.go create mode 100644 vendor/k8s.io/streaming/LICENSE create mode 100644 vendor/k8s.io/streaming/pkg/httpstream/doc.go create mode 100644 vendor/k8s.io/streaming/pkg/httpstream/httpstream.go create mode 100644 vendor/k8s.io/streaming/pkg/httpstream/spdy/connection.go create mode 100644 vendor/k8s.io/streaming/pkg/httpstream/spdy/roundtripper.go create mode 100644 vendor/k8s.io/streaming/pkg/httpstream/spdy/upgrade.go create mode 100644 vendor/k8s.io/streaming/pkg/httpstream/wsstream/conn.go create mode 100644 vendor/k8s.io/streaming/pkg/httpstream/wsstream/doc.go create mode 100644 vendor/k8s.io/streaming/pkg/httpstream/wsstream/stream.go create mode 100644 vendor/k8s.io/streaming/pkg/runtime/runtime.go diff --git a/demos/README.md b/demos/README.md index b5b73e2b5..143226fe4 100644 --- a/demos/README.md +++ b/demos/README.md @@ -7,6 +7,7 @@ Runbooks for testing and demonstrating AICR end-to-end workflows on live cluster | Demo | Description | |------|-------------| | [cuj1-training.md](cuj1-training.md) | CUJ1 (training) - EKS + GKE end-to-end, plus a config-driven GKE + signed-evidence variant | +| [cuj1-slinky-slurm.md](cuj1-slinky-slurm.md) | CUJ1 - Slinky Slurm on EKS / GKE / Kind (recipe → bundle → validate → `srun`) | | [cuj2-inference.md](cuj2-inference.md) | CUJ2 (inference) - EKS + GKE end-to-end with the Dynamo platform | | [cuj2-demo.md](cuj2-demo.md) | CUJ2 (inference) - Annotated slide-style demo walkthrough (training vs inference) | | [recipe-data-architecture.md](recipe-data-architecture.md) | Recipe metadata system: inheritance, criteria matching, deployment order, runtime external data | diff --git a/demos/cuj1-slinky-slurm.md b/demos/cuj1-slinky-slurm.md new file mode 100644 index 000000000..919d12f5a --- /dev/null +++ b/demos/cuj1-slinky-slurm.md @@ -0,0 +1,279 @@ +# AICR - Critical User Journey (CUJ) 1 — Slinky Slurm + +End-to-end walkthrough: **generate recipe (Query Mode) → bundle → deploy → validate → `srun` smoke job**. + +Slurm leaves are built from criteria flags (`--service`, `--platform slurm`, …), not from `aicr snapshot` — snapshot intake for Slurm is not supported today. See [Query Mode](../docs/user/cli-reference.md#aicr-recipe) in the CLI reference. + +## Assumptions + +- `kubectl` is configured for the target cluster. +- GPU leaves assume H100 nodes with drivers (or Kind for the CPU-only path). +- Node pools use a `**nodeGroup`** label (adjust if your cluster uses different keys). +- Inspect taints before bundling: `kubectl get nodes -o custom-columns=NAME:.metadata.name,GROUP:.metadata.labels.nodeGroup,TAINTS:.spec.taints` + +## Workflow + +```text + aicr recipe aicr bundle ./deploy.sh aicr validate srun smoke + (Query Mode) ──▶ (scheduling) ──▶ (install) ──▶ (phases) ──▶ (manual) +``` + +1. **Generate recipe (Query Mode)** — `aicr recipe --service … --platform slurm` resolves a slurm leaf overlay to `recipe.yaml`. +2. **Generate bundle** — apply `--system-*` / `--accelerated-*` scheduling and optional `--set` / `--set-json` on `slinkyslurm`. +3. **Install** — run `deploy.sh`; cert-manager and Slinky operator come up, then the cluster chart in `slurm`. +4. **Validate** — run `deployment` (Chainsaw component health) and `conformance` (`slinky-slurm-health` from the login pod). **Performance validation is not supported yet** on slurm leaves. +5. **Smoke job** — `kubectl exec` into the login pod and run `srun` to confirm scheduling. + +## Generate Recipe (Query Mode) + +Pick the row that matches your cluster. Each resolves to a slurm leaf with three inline Slinky components: `slinky-slurm-operator-crds`, `slinky-slurm-operator`, and `slinky-slurm`. + + +| Cloud | Command | Leaf overlay | +| -------- | ------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------- | +| **EKS** | `aicr recipe --service eks --accelerator h100 --intent training --os ubuntu --platform slurm -o recipe.yaml` | `h100-eks-ubuntu-training-slurm` | +| **GKE** | `aicr recipe --service gke --accelerator h100 --intent training --os cos --platform slurm -o recipe.yaml` | `h100-gke-cos-training-slurm` | +| **Kind** | `aicr recipe --service kind --accelerator h100 --intent training --platform slurm -o recipe.yaml` | `h100-kind-training-slurm` (CPU-only NodeSet; no GPU GRES) | + + +H100 cloud leaves bake in `Gres=gpu:h100:8` and matching `nvidia.com/gpu: 8` slurmd limits so `srun --gres=gpu:N` works after deploy. + +## Generate Bundle + +### Scheduling model + +AICR injects placement from bundle flags using each component's registry paths: + + +| Flag | Typical targets | +| --------------------------------------------------------------- | --------------------------------------------------- | +| `--system-node-selector` / `--system-node-toleration` | cert-manager, **slurm-operator**, prometheus, … | +| `--accelerated-node-selector` / `--accelerated-node-toleration` | `**nodesets.slinky`** (slurmd workers) | +| `--set-json slinkyslurm:…` | Per-leaf overrides on the cluster chart (see below) | + + +**Registry default for `slinky-slurm`:** `controller`, `restapi`, and `loginsets.slinky` use the **system** paths; `nodesets.slinky` uses **accelerated** paths. On split clusters (system pool + GPU pool), override the control plane onto the pool you want with `--set-json` (runs **after** selector injection and wins on those paths). + +**Operator note:** slurm-operator chart v1.1.0 ignores `nodeSelector`; it schedules from **tolerations** only. On EKS, include **both** `NoSchedule` and `NoExecute` for each taint key — nodes often carry both effects. + +**Override aliases:** `slinkyslurm`, `slurmcluster` (cluster chart); `slurm`, `slurmoperator` (operator chart). See `valueOverrideKeys` in `recipes/registry.yaml`. + +**Scalar vs structured overrides:** + +- `--set slinkyslurm:nodesets.slinky.replicas=2` — replicas, simple scalars. +- `--set-json slinkyslurm:controller.podSpec=…` — full `nodeSelector` / `tolerations` objects (required when overriding system-injected scheduling on control-plane paths). + +### EKS (dual taints: `system-workload` / `worker-workload`) + +Example layout: 3× `system-worker`, 1× `gpu-worker`. Operator + platform stack on system nodes; slurmd on GPU; controller / login / restapi pinned to GPU via `--set-json`. + +```shell +WORKER_TOLS='[{"key":"dedicated","operator":"Equal","value":"worker-workload","effect":"NoSchedule"},{"key":"dedicated","operator":"Equal","value":"worker-workload","effect":"NoExecute"}]' + +aicr bundle \ + --recipe recipe.yaml \ + --deployer helm \ + --system-node-selector nodeGroup=system-worker \ + --system-node-toleration dedicated=system-workload:NoSchedule \ + --system-node-toleration dedicated=system-workload:NoExecute \ + --accelerated-node-selector nodeGroup=gpu-worker \ + --accelerated-node-toleration dedicated=worker-workload:NoSchedule \ + --accelerated-node-toleration dedicated=worker-workload:NoExecute \ + --storage-class \ + --set slinkyslurm:nodesets.slinky.replicas=1 \ + --set-json "slinkyslurm:controller.podSpec={\"nodeSelector\":{\"nodeGroup\":\"gpu-worker\"},\"tolerations\":${WORKER_TOLS}}" \ + --set-json "slinkyslurm:restapi.podSpec={\"nodeSelector\":{\"nodeGroup\":\"gpu-worker\"},\"tolerations\":${WORKER_TOLS}}" \ + --set-json "slinkyslurm:loginsets.slinky.podSpec={\"nodeSelector\":{\"nodeGroup\":\"gpu-worker\"},\"tolerations\":${WORKER_TOLS}}" \ + --output bundle +``` + +Set `replicas` to your GPU node count when you have multiple workers. + +### GKE (system + cpu + gpu pools; GPU taint only) + +Example layout: 3× `system-worker` (no taints), 1× `cpu-worker` (no taints), 2× `gpu-worker` (`dedicated=gpu-workload:NoSchedule`). Control plane on **cpu-worker**; slurmd on **gpu-worker**. + +```shell +aicr bundle \ + --recipe recipe.yaml \ + --deployer helm \ + --system-node-selector nodeGroup=system-worker \ + --accelerated-node-selector nodeGroup=gpu-worker \ + --accelerated-node-toleration dedicated=gpu-workload:NoSchedule \ + --storage-class \ + --set slinkyslurm:nodesets.slinky.replicas=2 \ + --set-json 'slinkyslurm:controller.podSpec={"nodeSelector":{"nodeGroup":"cpu-worker"}}' \ + --set-json 'slinkyslurm:restapi.podSpec={"nodeSelector":{"nodeGroup":"cpu-worker"}}' \ + --set-json 'slinkyslurm:loginsets.slinky.podSpec={"nodeSelector":{"nodeGroup":"cpu-worker"}}' \ + --output bundle +``` + +GKE system nodes should **not** carry custom taints (konnectivity and other managed pods break). No `--system-node-toleration` on GKE when system/cpu pools are untainted. + +Optional: `--accelerated-node-toleration nvidia.com/gpu=present:NoSchedule` (harmless if that taint is absent). + +### Kind (CPU-only smoke / CI) + +No GPU pools or taints; omit accelerated flags unless your Kind config adds them. + +```shell +aicr bundle \ + --recipe recipe.yaml \ + --deployer helm \ + --output bundle +``` + +For automated no-GPU checks, see `make kwok-e2e` / `make check-health COMPONENT=slinky-slurm` in the repo Makefile. + +### Storage class + +Set `--storage-class` to a StorageClass that exists (`kubectl get storageclass`). The kube-prometheus-stack overlay uses a `volumeClaimTemplate` without a default `storageClassName`; a missing/default SC leaves PVCs Pending. + +## Install Bundle + +```shell +cd ./bundle && chmod +x deploy.sh && ./deploy.sh +``` + +Deploy order: `cert-manager` → `slinky-slurm-operator-crds` → `slinky-slurm-operator` → `slinky-slurm`. + +```shell +kubectl rollout status -n slinky deploy/slurm-operator +kubectl get pods -n slurm +kubectl wait --for=jsonpath='{.status.conditions[?(@.type=="Available")].status}'=True \ + -n slurm deploy/slinky-slurm-login-slinky --timeout=10m +``` + +If nodewright is already installed, skip those sections in `deploy.sh` to avoid upgrade conflicts. + +## Validate Cluster + +Use **deployment** and **conformance**. Performance validation is **not supported yet** on slurm leaves — there is no Slurm-native NCCL (or equivalent) check in AICR today; a K8s Pod benchmark would bypass slurmd and is the wrong path on a Slinky-managed cluster. + + +| Phase | What it checks | +| ------------- | ---------------------------------------------------------------------------------------------------------------------- | +| `deployment` | Component Chainsaw health (CRs, Deployments, DaemonSets ready), including `slinky-slurm` readiness (long retry budget) | +| `conformance` | `slinky-slurm-health`: `scontrol ping`, idle/mix node gate, bounded `srun --immediate=5 --time=0:01 hostname` | +| `performance` | **Not supported yet** on slurm leaves | +| `all` | Runs deployment → conformance → performance in sequence; the performance step has nothing to run on slurm leaves | + + +### All phases + +```shell +aicr validate \ + --recipe recipe.yaml \ + --phase all \ + --output report.json +``` + +Prefer `--phase deployment --phase conformance` when you only want the supported checks. + +### Specific phases + +```shell +# After deploy.sh — component + CR readiness (Chainsaw) +aicr validate \ + --recipe recipe.yaml \ + --phase deployment \ + --output report-deployment.json + +# Slurm behavior from login pod (conformance Job) +aicr validate \ + --recipe recipe.yaml \ + --phase conformance \ + --output report-conformance.json + +# Both — common after install +aicr validate \ + --recipe recipe.yaml \ + --phase deployment \ + --phase conformance \ + --output report.json +``` + +### Scheduling flags on validate + +When validate captures cluster state inline (no `-s`), pass `**--node-selector**` and `**--toleration**` so the snapshot agent Job can schedule on tainted nodes. Match your **system** pool (not the GPU pool) unless you intend to run the agent on GPU nodes. + +**EKS example** (agent on system nodes): + +```shell +aicr validate \ + --recipe recipe.yaml \ + --node-selector nodeGroup=system-worker \ + --toleration dedicated=system-workload:NoSchedule \ + --toleration dedicated=system-workload:NoExecute \ + --phase deployment \ + --phase conformance \ + --output report.json +``` + +**GKE example** (untainted system pool; `--toleration` optional): + +```shell +aicr validate \ + --recipe recipe.yaml \ + --node-selector nodeGroup=system-worker \ + --toleration dedicated=gpu-workload:NoSchedule \ + --phase deployment \ + --phase conformance \ + --output report.json +``` + +`--toleration` on validate applies to inner conformance/deployment Jobs; pair it with `--node-selector` when the default GPU auto-selector (`nvidia.com/gpu.present=true`) would land on tainted nodes you cannot tolerate. + +Readiness constraints (K8s version, OS, …) still run before any phase; they use measurements from the inline capture path above. + +## Run Job + +SSH is disabled by default on the login chart; use `kubectl exec`. + +```shell +kubectl exec -n slurm deploy/slinky-slurm-login-slinky -- sinfo +kubectl exec -n slurm deploy/slinky-slurm-login-slinky -- \ + srun --immediate=5 --time=0:01 hostname +``` + +Multi-node (when `replicas >= 2`): + +```shell +kubectl exec -n slurm deploy/slinky-slurm-login-slinky -- srun -N2 hostname +``` + +GPU GRES smoke (H100 cloud leaves): + +```shell +kubectl exec -n slurm deploy/slinky-slurm-login-slinky -- \ + sh -c 'srun -N2 --gres=gpu:8 nvidia-smi -L | sort -u | wc -l' +``` + +## Cleanup + +Cluster instance only (keep operator + CRDs): + +```shell +helm uninstall slinky-slurm -n slurm +``` + +Full Slurm stack: + +```shell +helm uninstall slinky-slurm -n slurm +helm uninstall slinky-slurm-operator -n slinky +helm uninstall slinky-slurm-operator-crds -n slinky +kubectl delete ns slurm slinky --ignore-not-found +``` + +Helm does not remove CRDs or PVCs by default; delete manually when you need a clean re-install. + +## Success + +- `deployment` + `conformance` phases pass in the CTRF report. +- `sinfo` shows NodeSet nodes idle. +- `srun hostname` returns worker hostnames. +- On GPU leaves, `srun --gres=gpu:8 nvidia-smi -L` reaches all GPUs per node. + +> Multi-node NCCL via `srun` + Pyxis/Enroot is the natural Slurm-native performance path; it is out of scope for this smoke CUJ and not covered by `aicr validate --phase performance` today. + diff --git a/docs/contributor/validator.md b/docs/contributor/validator.md index 8212f587e..bf074b111 100644 --- a/docs/contributor/validator.md +++ b/docs/contributor/validator.md @@ -732,6 +732,28 @@ make check-health-all # everything in recipes/checks/ make validate-local RECIPE=recipe.yaml # full pipeline in Kind ``` +### Timeout budgeting + +During `aicr validate --phase deployment`, registry health checks in +`recipes/checks//health-check.yaml` run in-process inside +the `expected-resources` check (`validators/chainsaw/inprocess.go`). + +A Test's `spec.timeouts.assert` is the **whole-Test budget** — one +deadline shared across every step and retry. Slurm's +[`health-check.yaml`](https://github.com/NVIDIA/aicr/blob/main/recipes/checks/slinky-slurm/health-check.yaml) +uses `assert: 7m` so workload-readiness steps can converge before the +pod-phase guard runs. + +The `expected-resources` catalog timeout (8m in +`recipes/validators/catalog.yaml`) is the **outer** envelope. It must +exceed the longest in-tree `assert` value plus headroom for +pre-chainsaw work, chainsaw teardown, and log flush +(`defaults.JobEnvelopeMargin`). If assert runs too close to that +catalog deadline, the Job can SIGKILL the pod before chainsaw reports +the failing step — operators see truncated output instead of a useful +failure. Raise the catalog `timeout` in tandem when you need a longer +assert budget (`TestExpectedResourcesCatalogEnvelope` guards this). + ## Constraint evaluation algorithm `pkg/constraints` is shared by surface 1, surface 2's recipe diff --git a/docs/user/component-catalog.md b/docs/user/component-catalog.md index e54a57cf8..55d7930b3 100644 --- a/docs/user/component-catalog.md +++ b/docs/user/component-catalog.md @@ -47,7 +47,7 @@ Not every component appears in every recipe. The recipe engine selects component - **Base components** (cert-manager, kube-prometheus-stack) appear in most recipes. - **Cloud-specific components** (aws-efa, aws-ebs-csi-driver) are added when the service matches. - **Intent-specific components** (agentgateway, agentgateway-crds) are added based on workload intent (e.g., inference recipes include the inference gateway). -- **Platform-specific components** (slinky-slurm-operator, slinky-slurm, kubeflow-trainer, dynamo-platform) are added when the recipe selects a matching `--platform`. For `--platform slurm`, all three Slinky pieces (`slinky-slurm-operator-crds`, `slinky-slurm-operator`, `slinky-slurm`) are declared inline per slurm leaf overlay — the same shape `dynamo-platform` uses across `*-inference-dynamo` leaves. Leaves that want the operator only inline the CRDs + operator and omit the `slinky-slurm` componentRef. +- **Platform-specific components** (slinky-slurm-operator, slinky-slurm, kubeflow-trainer, dynamo-platform) are added when the recipe selects a matching `--platform`. For `--platform slurm`, all three Slinky pieces (`slinky-slurm-operator-crds`, `slinky-slurm-operator`, `slinky-slurm`) are declared inline per slurm leaf overlay — the same shape `dynamo-platform` uses across `*-inference-dynamo` leaves. Leaves that want the operator only inline the CRDs + operator and omit the `slinky-slurm` componentRef. For an end-to-end walkthrough (recipe → bundle → install → validate → `srun` smoke job on EKS, GKE, or Kind), see [`demos/cuj1-slinky-slurm.md`](https://github.com/NVIDIA/aicr/blob/main/demos/cuj1-slinky-slurm.md). - **Accelerator/OS-specific tuning** (nodewright-customizations, nvidia-dra-driver-gpu) varies by hardware and OS combination. ### NFD Topology Updater diff --git a/go.mod b/go.mod index 380d8fbc2..ae3bc40d5 100644 --- a/go.mod +++ b/go.mod @@ -128,6 +128,7 @@ require ( github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.16 // indirect github.com/googleapis/gax-go/v2 v2.22.0 // indirect + github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-retryablehttp v0.7.8 // indirect @@ -145,6 +146,7 @@ require ( github.com/mattn/go-isatty v0.0.22 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect + github.com/moby/spdystream v0.5.1 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00 // indirect @@ -202,6 +204,7 @@ require ( k8s.io/klog/v2 v2.140.0 // indirect k8s.io/kube-openapi v0.0.0-20260603220949-865597e52e25 // indirect k8s.io/kubernetes v1.36.2 // indirect + k8s.io/streaming v0.36.2 // indirect sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect sigs.k8s.io/randfill v1.0.0 // indirect sigs.k8s.io/structured-merge-diff/v6 v6.4.0 // indirect diff --git a/go.sum b/go.sum index 9872fe876..466c6b873 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,8 @@ github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYW github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/aquilax/truncate v1.0.1 h1:+hqGSRxnQ0F5wdPCGbi1XW4ipQ6vzpli23V9Rd+I/mc= github.com/aquilax/truncate v1.0.1/go.mod h1:BeMESIDMlvlS3bmg4BVvBbbZUNwWtS8uzYPAKXwwhLw= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-sdk-go-v2 v1.42.0 h1:XvXMJTkFQtpBKIWZnmr9ZEOc2InWM2yldjXEJ/bymhA= @@ -253,6 +255,8 @@ github.com/googleapis/gax-go/v2 v2.22.0 h1:PjIWBpgGIVKGoCXuiCoP64altEJCj3/Ei+kSU github.com/googleapis/gax-go/v2 v2.22.0/go.mod h1:irWBbALSr0Sk3qlqb9SyJ1h68WjgeFuiOzI4Rqw5+aY= github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5THxAzdVpqr6/geYxZytqFMBCOtn/ujyeo= +github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA= github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 h1:UH//fgunKIs4JdUbpDl1VZCDaL56wXCB/5+wF6uHfaI= github.com/grpc-ecosystem/go-grpc-middleware v1.4.0/go.mod h1:g5qyo/la0ALbONm6Vbp88Yd8NsDy6rZz+RcrMPxvld8= github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0 h1:5VipnvEpbqr2gA2VbM+nYVbkIF28c5ZQfqCBQ5g2xfk= @@ -331,6 +335,8 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= +github.com/moby/spdystream v0.5.1 h1:9sNYeYZUcci9R6/w7KDaFWEWeV4LStVG78Mpyq/Zm/Y= +github.com/moby/spdystream v0.5.1/go.mod h1:xBAYlnt/ay+11ShkdFKNAG7LsyK/tmNBVvVOwrfMgdI= github.com/moby/sys/mountinfo v0.7.2 h1:1shs6aH5s4o5H2zQLn796ADW1wMrIwHsyJ2v9KouLrg= github.com/moby/sys/mountinfo v0.7.2/go.mod h1:1YOa8w8Ih7uW0wALDUgT1dTTSBrZ+HiBLGws92L2RU4= github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g= @@ -583,6 +589,8 @@ k8s.io/kube-openapi v0.0.0-20260603220949-865597e52e25 h1:mPMaPMpBij2V1Wv/fR+HW1 k8s.io/kube-openapi v0.0.0-20260603220949-865597e52e25/go.mod h1:V/QaCUYDa+0QpcHhVVc5l99Uz56wEMEXBSj9oCDkNDY= k8s.io/kubernetes v1.36.2 h1:qsCug7E1dMVO+rNuCENKG64Z7SVS/fTqDHGI1/NCmTg= k8s.io/kubernetes v1.36.2/go.mod h1:MLdeJ3qw2CWH9BFml5GvptxQVQckz54fJOZ/WuixpFE= +k8s.io/streaming v0.36.2 h1:NSKthPPg9UFSKsRauVJUVGH2Dvn8fhKmY4qrMkw/p98= +k8s.io/streaming v0.36.2/go.mod h1:z6fV3D+NVkoeqRMtWwlUZK6U17SY/LqNzOxWL6GyR/s= k8s.io/utils v0.0.0-20260507154919-ff6756f316d2 h1:wU4tMEhLGgIbLvXQb1cfN+EcM0wf7zC6CPF+C79jroc= k8s.io/utils v0.0.0-20260507154919-ff6756f316d2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk= oras.land/oras-go/v2 v2.6.1 h1:bonOEkjLfp8tt6qXWRRWP6p1F+9octchOf2EqnWB4Zs= diff --git a/recipes/checks/slinky-slurm/health-check.yaml b/recipes/checks/slinky-slurm/health-check.yaml index 7b70c895d..0ab244151 100644 --- a/recipes/checks/slinky-slurm/health-check.yaml +++ b/recipes/checks/slinky-slurm/health-check.yaml @@ -27,7 +27,7 @@ metadata: name: slinky-slurm-health-check spec: timeouts: - assert: 10m + assert: 7m steps: - name: validate-controller-cr try: diff --git a/recipes/evidence/gb200-eks-ubuntu-training.yaml b/recipes/evidence/gb200-eks-ubuntu-training.yaml index 182e5feb5..00fb65fae 100644 --- a/recipes/evidence/gb200-eks-ubuntu-training.yaml +++ b/recipes/evidence/gb200-eks-ubuntu-training.yaml @@ -1,3 +1,17 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + attestations: - attestedAt: 2026-06-15T16:58:03Z bundle: diff --git a/recipes/evidence/h100-gke-cos-training.yaml b/recipes/evidence/h100-gke-cos-training.yaml index a8a08c9d9..40d8a3ff6 100644 --- a/recipes/evidence/h100-gke-cos-training.yaml +++ b/recipes/evidence/h100-gke-cos-training.yaml @@ -1,3 +1,17 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + attestations: - attestedAt: 2026-06-15T15:22:56Z bundle: diff --git a/recipes/overlays/h100-eks-ubuntu-training-slurm.yaml b/recipes/overlays/h100-eks-ubuntu-training-slurm.yaml index c15271b24..885cbe4a3 100644 --- a/recipes/overlays/h100-eks-ubuntu-training-slurm.yaml +++ b/recipes/overlays/h100-eks-ubuntu-training-slurm.yaml @@ -95,6 +95,9 @@ spec: # EFA libfabric stack already present on the parent EKS leaf. # Deployment and conformance checks are inherited unchanged. validation: + conformance: + checks: + - slinky-slurm-health performance: checks: [] constraints: [] diff --git a/recipes/overlays/h100-gke-cos-training-slurm.yaml b/recipes/overlays/h100-gke-cos-training-slurm.yaml index 688326312..bd59326c6 100644 --- a/recipes/overlays/h100-gke-cos-training-slurm.yaml +++ b/recipes/overlays/h100-gke-cos-training-slurm.yaml @@ -98,6 +98,9 @@ spec: # gke-nccl-tcpxo. Deployment and conformance checks are inherited # unchanged. validation: + conformance: + checks: + - slinky-slurm-health performance: checks: [] constraints: [] diff --git a/recipes/overlays/h100-kind-training-slurm.yaml b/recipes/overlays/h100-kind-training-slurm.yaml index 0792876b1..f97defd67 100644 --- a/recipes/overlays/h100-kind-training-slurm.yaml +++ b/recipes/overlays/h100-kind-training-slurm.yaml @@ -64,3 +64,8 @@ spec: dependencyRefs: - slinky-slurm-operator - slinky-slurm-operator-crds + + validation: + conformance: + checks: + - slinky-slurm-health diff --git a/recipes/validators/catalog.yaml b/recipes/validators/catalog.yaml index 431160ebe..dfc9b8295 100644 --- a/recipes/validators/catalog.yaml +++ b/recipes/validators/catalog.yaml @@ -208,6 +208,13 @@ validators: timeout: 10m args: ["secure-accelerator-access"] env: [] + - name: slinky-slurm-health + phase: conformance + description: "Verify Slinky Slurm controller, node inventory, and job submission health" + image: ghcr.io/nvidia/aicr-validators/conformance:latest + timeout: 5m + args: ["slinky-slurm-health"] + env: [] - name: gpu-operator-health phase: conformance description: "Verify GPU operator health (conformance diagnostic)" diff --git a/validators/conformance/main.go b/validators/conformance/main.go index 99cf65cdf..e9ade7f02 100644 --- a/validators/conformance/main.go +++ b/validators/conformance/main.go @@ -37,6 +37,7 @@ func main() { "cluster-autoscaling": CheckClusterAutoscaling, "robust-controller": CheckRobustController, "secure-accelerator-access": CheckSecureAcceleratorAccess, + "slinky-slurm-health": CheckSlinkySlurmHealth, "gpu-operator-health": CheckGPUOperatorHealth, "platform-health": CheckPlatformHealth, }) diff --git a/validators/conformance/pod_exec.go b/validators/conformance/pod_exec.go new file mode 100644 index 000000000..6347289a1 --- /dev/null +++ b/validators/conformance/pod_exec.go @@ -0,0 +1,100 @@ +// Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "context" + stderrors "errors" + "fmt" + "net/http" + "net/url" + + "github.com/NVIDIA/aicr/pkg/errors" + "github.com/NVIDIA/aicr/validators" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/scheme" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/remotecommand" + k8sexec "k8s.io/client-go/util/exec" +) + +type podExecResult struct { + Stdout string + Stderr string + ExitCode int +} + +type podExecFunc func(context.Context, *validators.Context, string, string, []string) (podExecResult, error) + +type podExecExecutorFactory func(*rest.Config, string, string) (remotecommand.Executor, error) + +var newPodExecExecutor podExecExecutorFactory = func(config *rest.Config, method, requestURL string) (remotecommand.Executor, error) { + parsedURL, err := url.Parse(requestURL) + if err != nil { + return nil, err + } + return remotecommand.NewSPDYExecutor(config, method, parsedURL) +} + +func execPodCommand(streamCtx context.Context, ctx *validators.Context, namespace, podName string, command []string) (podExecResult, error) { + pod, err := ctx.Clientset.CoreV1().Pods(namespace).Get(streamCtx, podName, metav1.GetOptions{}) + if err != nil { + return podExecResult{}, errors.Wrap(errors.ErrCodeInternal, + fmt.Sprintf("failed to get pod %s/%s before exec", namespace, podName), err) + } + if len(pod.Spec.Containers) == 0 { + return podExecResult{}, errors.New(errors.ErrCodeInternal, + fmt.Sprintf("pod %s/%s has no containers", namespace, podName)) + } + + req := ctx.Clientset.CoreV1().RESTClient().Post(). + Resource("pods"). + Name(podName). + Namespace(namespace). + SubResource("exec"). + VersionedParams(&corev1.PodExecOptions{ + Container: pod.Spec.Containers[0].Name, + Command: command, + Stdout: true, + Stderr: true, + }, scheme.ParameterCodec) + + executor, err := newPodExecExecutor(ctx.RESTConfig, http.MethodPost, req.URL().String()) + if err != nil { + return podExecResult{}, errors.Wrap(errors.ErrCodeInternal, "failed to create pod exec executor", err) + } + + var stdout, stderr bytes.Buffer + streamErr := executor.StreamWithContext(streamCtx, remotecommand.StreamOptions{ + Stdout: &stdout, + Stderr: &stderr, + }) + result := podExecResult{ + Stdout: stdout.String(), + Stderr: stderr.String(), + } + if streamErr == nil { + return result, nil + } + + var exitErr k8sexec.ExitError + if stderrors.As(streamErr, &exitErr) { + result.ExitCode = exitErr.ExitStatus() + return result, nil + } + return result, streamErr +} diff --git a/validators/conformance/pod_exec_test.go b/validators/conformance/pod_exec_test.go new file mode 100644 index 000000000..b35cdc573 --- /dev/null +++ b/validators/conformance/pod_exec_test.go @@ -0,0 +1,216 @@ +// Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/NVIDIA/aicr/pkg/errors" + "github.com/NVIDIA/aicr/validators" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + k8sfake "k8s.io/client-go/kubernetes/fake" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/remotecommand" +) + +func TestExecPodCommandBuildsExecRequestAndStreamsOutput(t *testing.T) { + ctx := podExecHTTPContext(t, corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "login-0", Namespace: "slurm"}, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + {Name: "login"}, + {Name: "sidecar"}, + }, + }, + }) + + var gotMethod string + var gotURL string + restore := replacePodExecExecutorForTest(func(_ *rest.Config, method string, url string) (remotecommand.Executor, error) { + gotMethod = method + gotURL = url + return fakePodExecutor{ + stream: func(_ context.Context, opts remotecommand.StreamOptions) error { + if _, err := opts.Stdout.Write([]byte("login-0\n")); err != nil { + t.Fatalf("write stdout: %v", err) + } + if _, err := opts.Stderr.Write([]byte("warning\n")); err != nil { + t.Fatalf("write stderr: %v", err) + } + return nil + }, + }, nil + }) + defer restore() + + result, err := execPodCommand(context.Background(), ctx, "slurm", "login-0", []string{"srun", "hostname"}) + if err != nil { + t.Fatalf("execPodCommand() error = %v", err) + } + if gotMethod != http.MethodPost { + t.Fatalf("method = %s, want POST", gotMethod) + } + for _, want := range []string{ + "/api/v1/namespaces/slurm/pods/login-0/exec", + "container=login", + "command=srun", + "command=hostname", + "stdout=true", + "stderr=true", + } { + if !strings.Contains(gotURL, want) { + t.Fatalf("exec URL = %s, want containing %q", gotURL, want) + } + } + if result.Stdout != "login-0\n" { + t.Fatalf("stdout = %q, want login hostname", result.Stdout) + } + if result.Stderr != "warning\n" { + t.Fatalf("stderr = %q, want warning", result.Stderr) + } + if result.ExitCode != 0 { + t.Fatalf("exit code = %d, want 0", result.ExitCode) + } +} + +func TestExecPodCommandReturnsPreStreamErrors(t *testing.T) { + tests := []struct { + name string + ctx *validators.Context + wantErr string + }{ + { + name: "missing pod", + ctx: &validators.Context{ + Ctx: context.Background(), + Clientset: k8sfake.NewSimpleClientset(), + RESTConfig: &rest.Config{Host: "https://example.test"}, + }, + wantErr: "failed to get pod slurm/missing before exec", + }, + { + name: "no containers", + ctx: &validators.Context{ + Ctx: context.Background(), + Clientset: k8sfake.NewSimpleClientset(&corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "missing", Namespace: "slurm"}, + }), + RESTConfig: &rest.Config{Host: "https://example.test"}, + }, + wantErr: "has no containers", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := execPodCommand(context.Background(), tt.ctx, "slurm", "missing", []string{"hostname"}) + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("error = %v, want containing %q", err, tt.wantErr) + } + }) + } +} + +func TestExecPodCommandReturnsExecutorFactoryError(t *testing.T) { + ctx := podExecHTTPContext(t, corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "login-0", Namespace: "slurm"}, + Spec: corev1.PodSpec{Containers: []corev1.Container{{Name: "login"}}}, + }) + + errBoom := errors.New(errors.ErrCodeInternal, "factory failed") + restore := replacePodExecExecutorForTest(func(*rest.Config, string, string) (remotecommand.Executor, error) { + return nil, errBoom + }) + defer restore() + + _, err := execPodCommand(context.Background(), ctx, "slurm", "login-0", []string{"hostname"}) + if err == nil || !strings.Contains(err.Error(), "failed to create pod exec executor") || !strings.Contains(err.Error(), "factory failed") { + t.Fatalf("error = %v, want wrapped factory failure", err) + } +} + +func TestExecPodCommandReturnsStreamError(t *testing.T) { + ctx := podExecHTTPContext(t, corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "login-0", Namespace: "slurm"}, + Spec: corev1.PodSpec{Containers: []corev1.Container{{Name: "login"}}}, + }) + + errBoom := errors.New(errors.ErrCodeInternal, "stream failed") + restore := replacePodExecExecutorForTest(func(*rest.Config, string, string) (remotecommand.Executor, error) { + return fakePodExecutor{ + stream: func(context.Context, remotecommand.StreamOptions) error { + return errBoom + }, + }, nil + }) + defer restore() + + _, err := execPodCommand(context.Background(), ctx, "slurm", "login-0", []string{"hostname"}) + if err == nil || !strings.Contains(err.Error(), "stream failed") { + t.Fatalf("error = %v, want stream failure", err) + } +} + +func podExecHTTPContext(t *testing.T, pod corev1.Pod) *validators.Context { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wantPath := fmt.Sprintf("/api/v1/namespaces/%s/pods/%s", pod.Namespace, pod.Name) + if r.URL.Path != wantPath { + t.Fatalf("path = %s, want %s", r.URL.Path, wantPath) + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(pod); err != nil { + t.Fatalf("encode pod: %v", err) + } + })) + t.Cleanup(server.Close) + + clientset, err := kubernetes.NewForConfig(&rest.Config{Host: server.URL}) + if err != nil { + t.Fatalf("create clientset: %v", err) + } + return &validators.Context{ + Ctx: context.Background(), + Clientset: clientset, + RESTConfig: &rest.Config{Host: server.URL}, + } +} + +type fakePodExecutor struct { + stream func(context.Context, remotecommand.StreamOptions) error +} + +func (f fakePodExecutor) Stream(remotecommand.StreamOptions) error { + return nil +} + +func (f fakePodExecutor) StreamWithContext(ctx context.Context, opts remotecommand.StreamOptions) error { + return f.stream(ctx, opts) +} + +func replacePodExecExecutorForTest(fn podExecExecutorFactory) func() { + old := newPodExecExecutor + newPodExecExecutor = fn + return func() { newPodExecExecutor = old } +} diff --git a/validators/conformance/slinky_slurm_health_check.go b/validators/conformance/slinky_slurm_health_check.go new file mode 100644 index 000000000..cca78ac83 --- /dev/null +++ b/validators/conformance/slinky_slurm_health_check.go @@ -0,0 +1,375 @@ +// Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + "strings" + + "github.com/NVIDIA/aicr/pkg/errors" + "github.com/NVIDIA/aicr/validators" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +const ( + slinkySlurmComponent = "slinky-slurm" + slinkySlurmNamespace = "slurm" + kwokNodeAnnotation = "kwok.x-k8s.io/node" +) + +var ( + slinkyLoginSetGVR = schema.GroupVersionResource{ + Group: "slinky.slurm.net", + Version: "v1beta1", + Resource: "loginsets", + } + slinkyNodeSetGVR = schema.GroupVersionResource{ + Group: "slinky.slurm.net", + Version: "v1beta1", + Resource: "nodesets", + } +) + +type slinkySlurmHealthCommand struct { + label string + command []string + requireStdout bool +} + +// slinkySlurmSinfoIdleMixShell requires at least one idle or mixed Slurm node. +// grep -q exits 0 when sinfo prints data lines and 1 when inventory is empty. +const slinkySlurmSinfoIdleMixShell = "sinfo -h -Ne -t idle,mix | grep -q ." + +var slinkySlurmHealthCommands = []slinkySlurmHealthCommand{ + { + label: "scontrol ping", + command: []string{"scontrol", "ping"}, + requireStdout: true, + }, + { + label: "sinfo idle/mix", + command: []string{"/bin/sh", "-c", slinkySlurmSinfoIdleMixShell}, + requireStdout: false, + }, + { + label: "srun hostname", + command: []string{"srun", "--immediate=5", "--time=0:03", "hostname"}, + requireStdout: true, + }, +} + +var slinkyExecCommand podExecFunc = execPodCommand + +// CheckSlinkySlurmHealth validates that a Slinky-managed Slurm cluster is +// reachable from the login pod, has idle or mixed worker nodes, and can +// schedule a minimal job without queueing indefinitely. +func CheckSlinkySlurmHealth(ctx *validators.Context) error { + if ctx.Clientset == nil { + return errors.New(errors.ErrCodeInvalidRequest, "kubernetes client is not available") + } + if ctx.RESTConfig == nil { + return errors.New(errors.ErrCodeInvalidRequest, "RESTConfig is not available") + } + if ctx.ValidationInput == nil { + return errors.New(errors.ErrCodeInvalidRequest, "validation is not available") + } + if !recipeHasEnabledComponent(ctx, slinkySlurmComponent) { + return validators.Skip("slinky-slurm component not present in recipe") + } + + if err := discoverSlinkySetAPIs(ctx); err != nil { + return err + } + if err := skipIfAllNodeSetPodsAreKWOK(ctx); err != nil { + return err + } + + loginPod, err := findReadySlinkyLoginPod(ctx) + if err != nil { + return err + } + recordSlinkyInventories(ctx, loginPod) + + failures := runSlinkySlurmHealthCommands(ctx, loginPod.Name) + if len(failures) > 0 { + return errors.New(errors.ErrCodeInternal, + "Slinky Slurm health commands failed:\n"+strings.Join(failures, "\n")) + } + + return nil +} + +func recipeHasEnabledComponent(ctx *validators.Context, name string) bool { + if ctx.ValidationInput == nil { + return false + } + for _, ref := range ctx.ValidationInput.ComponentRefs { + if ref.Name == name && ref.IsEnabled() { + return true + } + } + return false +} + +func runSlinkySlurmHealthCommands(ctx *validators.Context, loginPodName string) []string { + var failures []string + for _, check := range slinkySlurmHealthCommands { + result, execErr := slinkyExecCommand(ctx.Ctx, ctx, slinkySlurmNamespace, loginPodName, check.command) + recordSlinkyExecResult(ctx, loginPodName, check, result, execErr) + if execErr != nil { + failures = append(failures, fmt.Sprintf("%s: exec failed: %v", check.label, execErr)) + continue + } + if result.ExitCode != 0 { + failures = append(failures, fmt.Sprintf("%s: exit code %d", check.label, result.ExitCode)) + continue + } + if check.requireStdout && strings.TrimSpace(result.Stdout) == "" { + failures = append(failures, fmt.Sprintf("%s: empty stdout", check.label)) + } + } + return failures +} + +func discoverSlinkySetAPIs(ctx *validators.Context) error { + resources, err := ctx.Clientset.Discovery().ServerResourcesForGroupVersion("slinky.slurm.net/v1beta1") + if err != nil { + if apierrors.IsNotFound(err) { + return validators.Skip("Slinky Slurm API not available") + } + return errors.Wrap(errors.ErrCodeInternal, "failed to discover Slinky Slurm API", err) + } + + found := map[string]bool{} + for _, resource := range resources.APIResources { + isLoginSet := resource.Name == slinkyLoginSetGVR.Resource && resource.Kind == "LoginSet" + isNodeSet := resource.Name == slinkyNodeSetGVR.Resource && resource.Kind == "NodeSet" + if isLoginSet || isNodeSet { + found[resource.Name] = true + } + } + if !found[slinkyLoginSetGVR.Resource] || !found[slinkyNodeSetGVR.Resource] { + return validators.Skip("Slinky Slurm LoginSet/NodeSet API not available") + } + return nil +} + +func skipIfAllNodeSetPodsAreKWOK(ctx *validators.Context) error { + pods, err := listSlinkyNodeSetPods(ctx) + if err != nil { + return err + } + if len(pods) == 0 { + return errors.New(errors.ErrCodeNotFound, "slinky-slurm selected but no NodeSet pods were found") + } + + var resolved, kwok int + for _, pod := range pods { + if pod.Spec.NodeName == "" { + continue + } + node, getErr := ctx.Clientset.CoreV1().Nodes().Get(ctx.Ctx, pod.Spec.NodeName, metav1.GetOptions{}) + if getErr != nil { + return errors.Wrap(errors.ErrCodeInternal, + fmt.Sprintf("failed to get node %s for NodeSet pod %s", pod.Spec.NodeName, pod.Name), getErr) + } + resolved++ + if _, ok := node.Annotations[kwokNodeAnnotation]; ok { + kwok++ + } + } + if resolved > 0 && kwok == resolved { + return validators.Skip("Slinky NodeSet pods are on KWOK nodes; skipping Slurm health validation") + } + return nil +} + +func listSlinkyNodeSetPods(ctx *validators.Context) ([]corev1.Pod, error) { + return listPodsForSlinkySetSelectors(ctx, slinkyNodeSetGVR, "NodeSet") +} + +func listPodsForSlinkySetSelectors( + ctx *validators.Context, + gvr schema.GroupVersionResource, + kind string, +) ([]corev1.Pod, error) { + + sets, err := listSlinkySetsForController(ctx, gvr, kind) + if err != nil { + return nil, err + } + + pods := []corev1.Pod{} + for _, set := range sets { + if _, parseErr := labels.Parse(set.selector); parseErr != nil { + return nil, errors.Wrap(errors.ErrCodeInternal, + fmt.Sprintf("invalid %s selector for %s/%s: %q", + kind, slinkySlurmNamespace, set.name, set.selector), parseErr) + } + podList, listErr := ctx.Clientset.CoreV1().Pods(slinkySlurmNamespace).List(ctx.Ctx, metav1.ListOptions{ + LabelSelector: set.selector, + }) + if listErr != nil { + return nil, errors.Wrap(errors.ErrCodeInternal, + fmt.Sprintf("failed to list Slinky Slurm pods for %s/%s", kind, set.name), listErr) + } + pods = append(pods, podList.Items...) + } + return pods, nil +} + +func findReadySlinkyLoginPod(ctx *validators.Context) (*corev1.Pod, error) { + pods, err := listPodsForSlinkySetSelectors(ctx, slinkyLoginSetGVR, "LoginSet") + if err != nil { + return nil, err + } + + var summary strings.Builder + for _, pod := range pods { + fmt.Fprintf(&summary, "%s phase=%s ready=%t node=%s\n", + pod.Name, pod.Status.Phase, podIsReady(&pod), valueOrUnknown(pod.Spec.NodeName)) + if pod.Status.Phase == corev1.PodRunning && podIsReady(&pod) { + return &pod, nil + } + } + return nil, errors.New(errors.ErrCodeNotFound, + fmt.Sprintf("no ready login pod found for Slinky LoginSet selectors in %s:\n%s", + slinkySlurmNamespace, strings.TrimSpace(summary.String()))) +} + +type slinkySetSelection struct { + kind string + name string + selector string +} + +func listSlinkySetsForController( + ctx *validators.Context, + gvr schema.GroupVersionResource, + kind string, +) ([]slinkySetSelection, error) { + + dynClient, err := getDynamicClient(ctx) + if err != nil { + return nil, err + } + list, err := dynClient.Resource(gvr).Namespace(slinkySlurmNamespace).List(ctx.Ctx, metav1.ListOptions{}) + if err != nil { + if apierrors.IsNotFound(err) { + return nil, validators.Skip(fmt.Sprintf("Slinky Slurm %s API not available", kind)) + } + return nil, errors.Wrap(errors.ErrCodeInternal, fmt.Sprintf("failed to list Slinky Slurm %ss", kind), err) + } + + selected := make([]slinkySetSelection, 0, len(list.Items)) + for i := range list.Items { + item := &list.Items[i] + if item.GetAPIVersion() != "slinky.slurm.net/v1beta1" || item.GetKind() != kind { + continue + } + controllerName, _, _ := unstructured.NestedString(item.Object, "spec", "controllerRef", "name") + controllerNamespace, _, _ := unstructured.NestedString(item.Object, "spec", "controllerRef", "namespace") + if controllerName != slinkySlurmComponent { + continue + } + if controllerNamespace != "" && controllerNamespace != slinkySlurmNamespace { + continue + } + selector, found, selectorErr := unstructured.NestedString(item.Object, "status", "selector") + if selectorErr != nil { + return nil, errors.Wrap(errors.ErrCodeInternal, + fmt.Sprintf("failed to read selector from %s/%s", kind, item.GetName()), selectorErr) + } + if !found || strings.TrimSpace(selector) == "" { + return nil, errors.New(errors.ErrCodeNotFound, + fmt.Sprintf("Slinky Slurm %s %s/%s has no status.selector", + kind, item.GetNamespace(), item.GetName())) + } + selected = append(selected, slinkySetSelection{ + kind: kind, + name: item.GetName(), + selector: selector, + }) + } + if len(selected) == 0 { + return nil, errors.New(errors.ErrCodeNotFound, + fmt.Sprintf("no Slinky Slurm %s found for controllerRef.name=%s", kind, slinkySlurmComponent)) + } + return selected, nil +} + +func podIsReady(pod *corev1.Pod) bool { + for _, condition := range pod.Status.Conditions { + if condition.Type == corev1.PodReady && condition.Status == corev1.ConditionTrue { + return true + } + } + return false +} + +func recordSlinkyInventories(ctx *validators.Context, loginPod *corev1.Pod) { + slurmPods, slurmPodsErr := ctx.Clientset.CoreV1().Pods(slinkySlurmNamespace).List(ctx.Ctx, metav1.ListOptions{}) + if slurmPodsErr != nil { + recordRawTextArtifact(ctx, "Slinky Slurm pods", "kubectl get pods -n slurm -o wide", + fmt.Sprintf("failed to list pods: %v", slurmPodsErr)) + } else { + var podSummary strings.Builder + for _, pod := range slurmPods.Items { + fmt.Fprintf(&podSummary, "%-48s ready=%s phase=%s node=%s\n", + pod.Name, podReadyCount(pod), pod.Status.Phase, valueOrUnknown(pod.Spec.NodeName)) + } + recordRawTextArtifact(ctx, "Slinky Slurm pods", "kubectl get pods -n slurm -o wide", podSummary.String()) + } + + nodeSetPods, nodeSetErr := listSlinkyNodeSetPods(ctx) + if nodeSetErr != nil { + recordRawTextArtifact(ctx, "Slinky Slurm NodeSet pods", "kubectl get pods -n slurm", + fmt.Sprintf("failed to list NodeSet pods: %v", nodeSetErr)) + } else { + var nodeSetSummary strings.Builder + for _, pod := range nodeSetPods { + fmt.Fprintf(&nodeSetSummary, "%-48s ready=%s phase=%s node=%s\n", + pod.Name, podReadyCount(pod), pod.Status.Phase, valueOrUnknown(pod.Spec.NodeName)) + } + recordRawTextArtifact(ctx, "Slinky Slurm NodeSet pods", + "kubectl -n slurm get nodesets -o json | jq -r '.items[] | select(.apiVersion == \"slinky.slurm.net/v1beta1\") | .status.selector'", + nodeSetSummary.String()) + } + + recordRawTextArtifact(ctx, "Selected Slinky Slurm login pod", "", + fmt.Sprintf("Name: %s/%s\nReady: %t\nNode: %s", + loginPod.Namespace, loginPod.Name, podIsReady(loginPod), valueOrUnknown(loginPod.Spec.NodeName))) +} + +func recordSlinkyExecResult(ctx *validators.Context, podName string, check slinkySlurmHealthCommand, result podExecResult, execErr error) { + var body strings.Builder + fmt.Fprintf(&body, "Pod: %s/%s\n", slinkySlurmNamespace, podName) + fmt.Fprintf(&body, "Command: %s\n", strings.Join(check.command, " ")) + fmt.Fprintf(&body, "ExitCode: %d\n", result.ExitCode) + if execErr != nil { + fmt.Fprintf(&body, "Error: %v\n", execErr) + } + fmt.Fprintf(&body, "\nstdout:\n%s\n", result.Stdout) + fmt.Fprintf(&body, "\nstderr:\n%s\n", result.Stderr) + + recordRawTextArtifact(ctx, fmt.Sprintf("Slinky Slurm %s result", check.label), + fmt.Sprintf("kubectl exec -n slurm %s -- %s", podName, strings.Join(check.command, " ")), + body.String()) +} diff --git a/validators/conformance/slinky_slurm_health_check_test.go b/validators/conformance/slinky_slurm_health_check_test.go new file mode 100644 index 000000000..c8f8ade3b --- /dev/null +++ b/validators/conformance/slinky_slurm_health_check_test.go @@ -0,0 +1,465 @@ +// Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "strings" + "testing" + + "github.com/NVIDIA/aicr/pkg/errors" + "github.com/NVIDIA/aicr/pkg/recipe" + v1 "github.com/NVIDIA/aicr/pkg/validator/v1" + "github.com/NVIDIA/aicr/validators" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/discovery/fake" + dynamicfake "k8s.io/client-go/dynamic/fake" + "k8s.io/client-go/kubernetes" + k8sfake "k8s.io/client-go/kubernetes/fake" + "k8s.io/client-go/rest" +) + +var ( + testSlinkyLoginSetGVR = schema.GroupVersionResource{ + Group: "slinky.slurm.net", + Version: "v1beta1", + Resource: "loginsets", + } + testSlinkyNodeSetGVR = schema.GroupVersionResource{ + Group: "slinky.slurm.net", + Version: "v1beta1", + Resource: "nodesets", + } +) + +func TestCheckSlinkySlurmHealthSkipsWithoutSlinkyComponent(t *testing.T) { + ctx := &validators.Context{ + Ctx: context.Background(), + Clientset: k8sfake.NewSimpleClientset(), + RESTConfig: &rest.Config{Host: "https://example.test"}, + ValidationInput: &v1.ValidationInput{ + ComponentRefs: []recipe.ComponentRef{{Name: "gpu-operator"}}, + }, + } + + err := CheckSlinkySlurmHealth(ctx) + if !isSkipLike(err, "slinky-slurm") { + t.Fatalf("error = %v, want skip mentioning slinky-slurm", err) + } +} + +func TestCheckSlinkySlurmHealthRequiresContext(t *testing.T) { + tests := []struct { + name string + ctx *validators.Context + want string + }{ + { + name: "missing client", + ctx: &validators.Context{ + Ctx: context.Background(), + RESTConfig: &rest.Config{Host: "https://example.test"}, + ValidationInput: &v1.ValidationInput{ + ComponentRefs: []recipe.ComponentRef{{Name: slinkySlurmComponent}}, + }, + }, + want: "kubernetes client", + }, + { + name: "missing rest config", + ctx: &validators.Context{ + Ctx: context.Background(), + Clientset: k8sfake.NewSimpleClientset(), + ValidationInput: &v1.ValidationInput{ + ComponentRefs: []recipe.ComponentRef{{Name: slinkySlurmComponent}}, + }, + }, + want: "RESTConfig", + }, + { + name: "missing validation", + ctx: &validators.Context{ + Ctx: context.Background(), + Clientset: k8sfake.NewSimpleClientset(), + RESTConfig: &rest.Config{Host: "https://example.test"}, + }, + want: "validation", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckSlinkySlurmHealth(tt.ctx) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("error = %v, want containing %q", err, tt.want) + } + }) + } +} + +func TestCheckSlinkySlurmHealthSkipsWhenSlinkyAPIUnavailable(t *testing.T) { + ctx := &validators.Context{ + Ctx: context.Background(), + Clientset: k8sfake.NewSimpleClientset(), + RESTConfig: &rest.Config{Host: "https://example.test"}, + ValidationInput: &v1.ValidationInput{ + ComponentRefs: []recipe.ComponentRef{{Name: slinkySlurmComponent}}, + }, + } + + err := CheckSlinkySlurmHealth(ctx) + if !isSkipLike(err, "Slinky Slurm API") { + t.Fatalf("error = %v, want skip mentioning Slinky Slurm API", err) + } +} + +func TestCheckSlinkySlurmHealthExecOutcomes(t *testing.T) { + errBoom := errors.New(errors.ErrCodeInternal, "exec failed") + tests := []struct { + name string + result podExecResult + err error + wantErr string + }{ + {name: "success", result: podExecResult{Stdout: "slinky-0\n", ExitCode: 0}}, + {name: "empty stdout", result: podExecResult{Stdout: "\n", ExitCode: 0}, wantErr: "empty stdout"}, + {name: "nonzero", result: podExecResult{Stderr: "srun failed", ExitCode: 1}, wantErr: "exit code 1"}, + {name: "exec error", err: errBoom, wantErr: "exec failed"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + restore := replaceSlinkyExecForTest(func(context.Context, *validators.Context, string, string, []string) (podExecResult, error) { + return tt.result, tt.err + }) + defer restore() + + err := CheckSlinkySlurmHealth(slurmReadyTestContext(t, false)) + if tt.wantErr == "" && err != nil { + t.Fatalf("error = %v, want nil", err) + } + if tt.wantErr != "" && (err == nil || !strings.Contains(err.Error(), tt.wantErr)) { + t.Fatalf("error = %v, want containing %q", err, tt.wantErr) + } + }) + } +} + +func TestCheckSlinkySlurmHealthRunsAllHealthCommands(t *testing.T) { + var gotCommands []string + restore := replaceSlinkyExecForTest(func(_ context.Context, _ *validators.Context, _, _ string, command []string) (podExecResult, error) { + gotCommands = append(gotCommands, strings.Join(command, " ")) + return podExecResult{Stdout: strings.Join(command, " ") + "\n"}, nil + }) + defer restore() + + err := CheckSlinkySlurmHealth(slurmReadyTestContext(t, false)) + if err != nil { + t.Fatalf("error = %v, want nil", err) + } + + wantCommands := []string{ + "scontrol ping", + "/bin/sh -c " + slinkySlurmSinfoIdleMixShell, + "srun --immediate=5 --time=0:01 hostname", + } + if strings.Join(gotCommands, ",") != strings.Join(wantCommands, ",") { + t.Fatalf("commands = %v, want %v", gotCommands, wantCommands) + } +} + +func TestCheckSlinkySlurmHealthDiscoversPodsFromSlinkyCRSelectors(t *testing.T) { + var gotPodName string + restore := replaceSlinkyExecForTest(func(_ context.Context, _ *validators.Context, _, podName string, _ []string) (podExecResult, error) { + gotPodName = podName + return podExecResult{Stdout: "ok\n"}, nil + }) + defer restore() + + ctx := slurmCustomCRSelectorContext(t, false) + err := CheckSlinkySlurmHealth(ctx) + if err != nil { + t.Fatalf("error = %v, want nil", err) + } + if gotPodName != "custom-login-pod" { + t.Fatalf("exec pod = %q, want custom-login-pod", gotPodName) + } +} + +func TestCheckSlinkySlurmHealthCollectsAllCommandFailures(t *testing.T) { + restore := replaceSlinkyExecForTest(func(_ context.Context, _ *validators.Context, _, _ string, command []string) (podExecResult, error) { + joined := strings.Join(command, " ") + if strings.Contains(joined, "sinfo -h -Ne -t idle,mix") { + return podExecResult{Stderr: "down", ExitCode: 1}, nil + } + return podExecResult{Stdout: "\n"}, nil + }) + defer restore() + + err := CheckSlinkySlurmHealth(slurmReadyTestContext(t, false)) + if err == nil { + t.Fatal("error = nil, want combined health failure") + } + for _, want := range []string{ + "scontrol ping: empty stdout", + "sinfo idle/mix: exit code 1", + "srun hostname: empty stdout", + } { + if !strings.Contains(err.Error(), want) { + t.Fatalf("error = %v, want containing %q", err, want) + } + } +} + +func slurmCustomCRSelectorContext(t *testing.T, kwok bool) *validators.Context { + t.Helper() + + node := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "worker-node-0"}} + if kwok { + node.Annotations = map[string]string{kwokNodeAnnotation: "fake"} + } + + clientset := k8sfake.NewSimpleClientset( + &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: slinkySlurmNamespace}}, + node, + readyCustomLoginPod(), + readyCustomNodeSetPod(), + ) + addSlinkyDiscovery(t, clientset) + + return &validators.Context{ + Ctx: context.Background(), + Clientset: clientset, + DynamicClient: newSlinkyDynamicClient(t, customLoginSet(), customNodeSet()), + RESTConfig: &rest.Config{Host: "https://example.test"}, + ValidationInput: &v1.ValidationInput{ + ComponentRefs: []recipe.ComponentRef{{Name: slinkySlurmComponent}}, + }, + } +} + +func TestCheckSlinkySlurmHealthSkipsWhenAllNodeSetPodsAreOnKWOKNodes(t *testing.T) { + restore := replaceSlinkyExecForTest(func(context.Context, *validators.Context, string, string, []string) (podExecResult, error) { + t.Fatal("exec should not run when all NodeSet pods are on KWOK nodes") + return podExecResult{}, nil + }) + defer restore() + + err := CheckSlinkySlurmHealth(slurmReadyTestContext(t, true)) + if !isSkipLike(err, "KWOK") { + t.Fatalf("error = %v, want KWOK skip", err) + } +} + +func TestCheckSlinkySlurmHealthFailsWithoutReadyLoginPod(t *testing.T) { + ctx := slurmReadyTestContext(t, false) + err := ctx.Clientset.CoreV1().Pods(slinkySlurmNamespace).Delete(ctx.Ctx, "slinky-login-0", metav1.DeleteOptions{}) + if err != nil { + t.Fatalf("delete login pod: %v", err) + } + + err = CheckSlinkySlurmHealth(ctx) + if err == nil || !strings.Contains(err.Error(), "ready login pod") { + t.Fatalf("error = %v, want ready login pod failure", err) + } +} + +func slurmReadyTestContext(t *testing.T, kwok bool) *validators.Context { + t.Helper() + + node := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: "worker-node-0"}, + } + if kwok { + node.Annotations = map[string]string{kwokNodeAnnotation: "fake"} + } + + clientset := k8sfake.NewSimpleClientset( + &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: slinkySlurmNamespace}}, + node, + readyLoginPod(), + readyNodeSetPod(), + ) + addSlinkyDiscovery(t, clientset) + + return &validators.Context{ + Ctx: context.Background(), + Clientset: clientset, + DynamicClient: newSlinkyDynamicClient(t, defaultLoginSet(), defaultNodeSet()), + RESTConfig: &rest.Config{Host: "https://example.test"}, + ValidationInput: &v1.ValidationInput{ + ComponentRefs: []recipe.ComponentRef{{Name: slinkySlurmComponent}}, + }, + } +} + +func readyLoginPod() *corev1.Pod { + return &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "slinky-login-0", + Namespace: slinkySlurmNamespace, + Labels: map[string]string{ + "app.kubernetes.io/name": "slurm-login", + }, + }, + Spec: corev1.PodSpec{ + NodeName: "worker-node-0", + Containers: []corev1.Container{{Name: "login", Image: "slinky-login:test"}}, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + Conditions: []corev1.PodCondition{{ + Type: corev1.PodReady, + Status: corev1.ConditionTrue, + }}, + ContainerStatuses: []corev1.ContainerStatus{{Name: "login", Ready: true}}, + }, + } +} + +func readyNodeSetPod() *corev1.Pod { + return &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "slinky-nodeset-0", + Namespace: slinkySlurmNamespace, + Labels: map[string]string{ + "app.kubernetes.io/name": "slurm-nodeset", + }, + }, + Spec: corev1.PodSpec{ + NodeName: "worker-node-0", + Containers: []corev1.Container{{Name: "slurmd", Image: "slinky-slurmd:test"}}, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + Conditions: []corev1.PodCondition{{ + Type: corev1.PodReady, + Status: corev1.ConditionTrue, + }}, + ContainerStatuses: []corev1.ContainerStatus{{Name: "slurmd", Ready: true}}, + }, + } +} + +func readyCustomLoginPod() *corev1.Pod { + pod := readyLoginPod() + pod.Name = "custom-login-pod" + pod.Labels = map[string]string{ + "app.kubernetes.io/name": "login", + "app.kubernetes.io/instance": "custom-login", + } + return pod +} + +func readyCustomNodeSetPod() *corev1.Pod { + pod := readyNodeSetPod() + pod.Name = "custom-worker-0" + pod.Labels = map[string]string{ + "app.kubernetes.io/name": "slurmd", + "app.kubernetes.io/instance": "custom-worker", + } + return pod +} + +func defaultLoginSet() *unstructured.Unstructured { + return slinkySetObject("LoginSet", "slinky-slurm-login-slinky", "app.kubernetes.io/name=slurm-login") +} + +func defaultNodeSet() *unstructured.Unstructured { + return slinkySetObject("NodeSet", "slinky-slurm-worker-slinky", "app.kubernetes.io/name=slurm-nodeset") +} + +func customLoginSet() *unstructured.Unstructured { + return slinkySetObject("LoginSet", "custom-login", "app.kubernetes.io/instance=custom-login,app.kubernetes.io/name=login") +} + +func customNodeSet() *unstructured.Unstructured { + return slinkySetObject("NodeSet", "custom-worker", "app.kubernetes.io/instance=custom-worker,app.kubernetes.io/name=slurmd") +} + +func slinkySetObject(kind, name, selector string) *unstructured.Unstructured { + obj := &unstructured.Unstructured{ + Object: map[string]any{ + "apiVersion": "slinky.slurm.net/v1beta1", + "kind": kind, + "metadata": map[string]any{ + "name": name, + "namespace": slinkySlurmNamespace, + }, + "spec": map[string]any{ + "controllerRef": map[string]any{ + "name": slinkySlurmComponent, + "namespace": slinkySlurmNamespace, + }, + }, + "status": map[string]any{ + "selector": selector, + }, + }, + } + obj.SetGroupVersionKind(schema.GroupVersionKind{ + Group: "slinky.slurm.net", + Version: "v1beta1", + Kind: kind, + }) + return obj +} + +func newSlinkyDynamicClient(t *testing.T, objects ...runtime.Object) *dynamicfake.FakeDynamicClient { + t.Helper() + scheme := runtime.NewScheme() + return dynamicfake.NewSimpleDynamicClientWithCustomListKinds(scheme, map[schema.GroupVersionResource]string{ + testSlinkyLoginSetGVR: "LoginSetList", + testSlinkyNodeSetGVR: "NodeSetList", + }, objects...) +} + +func addSlinkyDiscovery(t *testing.T, clientset kubernetes.Interface) { + t.Helper() + discovery, ok := clientset.Discovery().(*fake.FakeDiscovery) + if !ok { + t.Fatalf("discovery client = %T, want *fake.FakeDiscovery", clientset.Discovery()) + } + discovery.Resources = []*metav1.APIResourceList{{ + GroupVersion: "slinky.slurm.net/v1beta1", + APIResources: []metav1.APIResource{ + { + Name: "loginsets", + Kind: "LoginSet", + Namespaced: true, + }, + { + Name: "nodesets", + Kind: "NodeSet", + Namespaced: true, + }, + }, + }} +} + +func isSkipLike(err error, want string) bool { + return err != nil && + (strings.Contains(err.Error(), want) || strings.Contains(strings.ToLower(err.Error()), strings.ToLower(want))) +} + +func replaceSlinkyExecForTest(fn podExecFunc) func() { + old := slinkyExecCommand + slinkyExecCommand = fn + return func() { slinkyExecCommand = old } +} diff --git a/vendor/github.com/gorilla/websocket/.gitignore b/vendor/github.com/gorilla/websocket/.gitignore new file mode 100644 index 000000000..cd3fcd1ef --- /dev/null +++ b/vendor/github.com/gorilla/websocket/.gitignore @@ -0,0 +1,25 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe + +.idea/ +*.iml diff --git a/vendor/github.com/gorilla/websocket/AUTHORS b/vendor/github.com/gorilla/websocket/AUTHORS new file mode 100644 index 000000000..1931f4006 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/AUTHORS @@ -0,0 +1,9 @@ +# This is the official list of Gorilla WebSocket authors for copyright +# purposes. +# +# Please keep the list sorted. + +Gary Burd +Google LLC (https://opensource.google.com/) +Joachim Bauch + diff --git a/vendor/github.com/gorilla/websocket/LICENSE b/vendor/github.com/gorilla/websocket/LICENSE new file mode 100644 index 000000000..9171c9722 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/gorilla/websocket/README.md b/vendor/github.com/gorilla/websocket/README.md new file mode 100644 index 000000000..ff8bfab0b --- /dev/null +++ b/vendor/github.com/gorilla/websocket/README.md @@ -0,0 +1,32 @@ +# Gorilla WebSocket + +[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) +[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) + +Gorilla WebSocket is a [Go](http://golang.org/) implementation of the +[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. + + +### Documentation + +* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) +* [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat) +* [Command example](https://github.com/gorilla/websocket/tree/main/examples/command) +* [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo) +* [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch) + +### Status + +The Gorilla WebSocket package provides a complete and tested implementation of +the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. The +package API is stable. + +### Installation + + go get github.com/gorilla/websocket + +### Protocol Compliance + +The Gorilla WebSocket package passes the server tests in the [Autobahn Test +Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn +subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn). diff --git a/vendor/github.com/gorilla/websocket/client.go b/vendor/github.com/gorilla/websocket/client.go new file mode 100644 index 000000000..00917ea34 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/client.go @@ -0,0 +1,517 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptrace" + "net/url" + "strings" + "time" +) + +// ErrBadHandshake is returned when the server response to opening handshake is +// invalid. +var ErrBadHandshake = errors.New("websocket: bad handshake") + +var errInvalidCompression = errors.New("websocket: invalid compression negotiation") + +// NewClient creates a new client connection using the given net connection. +// The URL u specifies the host and request URI. Use requestHeader to specify +// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies +// (Cookie). Use the response.Header to get the selected subprotocol +// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). +// +// If the WebSocket handshake fails, ErrBadHandshake is returned along with a +// non-nil *http.Response so that callers can handle redirects, authentication, +// etc. +// +// Deprecated: Use Dialer instead. +func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { + d := Dialer{ + ReadBufferSize: readBufSize, + WriteBufferSize: writeBufSize, + NetDial: func(net, addr string) (net.Conn, error) { + return netConn, nil + }, + } + return d.Dial(u.String(), requestHeader) +} + +// A Dialer contains options for connecting to WebSocket server. +// +// It is safe to call Dialer's methods concurrently. +type Dialer struct { + // The following custom dial functions can be set to establish + // connections to either the backend server or the proxy (if it + // exists). The scheme of the dialed entity (either backend or + // proxy) determines which custom dial function is selected: + // either NetDialTLSContext for HTTPS or NetDialContext/NetDial + // for HTTP. Since the "Proxy" function can determine the scheme + // dynamically, it can make sense to set multiple custom dial + // functions simultaneously. + // + // NetDial specifies the dial function for creating TCP connections. If + // NetDial is nil, net.Dialer DialContext is used. + // If "Proxy" field is also set, this function dials the proxy--not + // the backend server. + NetDial func(network, addr string) (net.Conn, error) + + // NetDialContext specifies the dial function for creating TCP connections. If + // NetDialContext is nil, NetDial is used. + // If "Proxy" field is also set, this function dials the proxy--not + // the backend server. + NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If + // NetDialTLSContext is nil, NetDialContext is used. + // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and + // TLSClientConfig is ignored. + // If "Proxy" field is also set, this function dials the proxy (and performs + // the TLS handshake with the proxy, ignoring TLSClientConfig). In this TLS proxy + // dialing case the TLSClientConfig could still be necessary for TLS to the backend server. + NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) + + // TLSClientConfig specifies the TLS configuration to use with tls.Client. + // If nil, the default configuration is used. + // If NetDialTLSContext is set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. + TLSClientConfig *tls.Config + + // HandshakeTimeout specifies the duration for the handshake to complete. + HandshakeTimeout time.Duration + + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer + // size is zero, then a useful default size is used. The I/O buffer sizes + // do not limit the size of the messages that can be sent or received. + ReadBufferSize, WriteBufferSize int + + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. + WriteBufferPool BufferPool + + // Subprotocols specifies the client's requested subprotocols. + Subprotocols []string + + // EnableCompression specifies if the client should attempt to negotiate + // per message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool + + // Jar specifies the cookie jar. + // If Jar is nil, cookies are not sent in requests and ignored + // in responses. + Jar http.CookieJar +} + +// Dial creates a new client connection by calling DialContext with a background context. +func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + return d.DialContext(context.Background(), urlStr, requestHeader) +} + +var errMalformedURL = errors.New("malformed ws or wss URL") + +func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { + hostPort = u.Host + hostNoPort = u.Host + if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { + hostNoPort = hostNoPort[:i] + } else { + switch u.Scheme { + case "wss": + hostPort += ":443" + case "https": + hostPort += ":443" + default: + hostPort += ":80" + } + } + return hostPort, hostNoPort +} + +// DefaultDialer is a dialer with all fields set to the default values. +var DefaultDialer = &Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, +} + +// nilDialer is dialer to use when receiver is nil. +var nilDialer = *DefaultDialer + +// DialContext creates a new client connection. Use requestHeader to specify the +// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). +// Use the response.Header to get the selected subprotocol +// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). +// +// The context will be used in the request and in the Dialer. +// +// If the WebSocket handshake fails, ErrBadHandshake is returned along with a +// non-nil *http.Response so that callers can handle redirects, authentication, +// etcetera. The response body may not contain the entire response and does not +// need to be closed by the application. +func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + if d == nil { + d = &nilDialer + } + + challengeKey, err := generateChallengeKey() + if err != nil { + return nil, nil, err + } + + u, err := url.Parse(urlStr) + if err != nil { + return nil, nil, err + } + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return nil, nil, errMalformedURL + } + + if u.User != nil { + // User name and password are not allowed in websocket URIs. + return nil, nil, errMalformedURL + } + + req := &http.Request{ + Method: http.MethodGet, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + req = req.WithContext(ctx) + + // Set the cookies present in the cookie jar of the dialer + if d.Jar != nil { + for _, cookie := range d.Jar.Cookies(u) { + req.AddCookie(cookie) + } + } + + // Set the request headers using the capitalization for names and values in + // RFC examples. Although the capitalization shouldn't matter, there are + // servers that depend on it. The Header.Set method is not used because the + // method canonicalizes the header names. + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{challengeKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + if len(d.Subprotocols) > 0 { + req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} + } + for k, vs := range requestHeader { + switch { + case k == "Host": + if len(vs) > 0 { + req.Host = vs[0] + } + case k == "Upgrade" || + k == "Connection" || + k == "Sec-Websocket-Key" || + k == "Sec-Websocket-Version" || + k == "Sec-Websocket-Extensions" || + (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): + return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) + case k == "Sec-Websocket-Protocol": + req.Header["Sec-WebSocket-Protocol"] = vs + default: + req.Header[k] = vs + } + } + + if d.EnableCompression { + req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} + } + + if d.HandshakeTimeout != 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) + defer cancel() + } + + var proxyURL *url.URL + if d.Proxy != nil { + proxyURL, err = d.Proxy(req) + if err != nil { + return nil, nil, err + } + } + netDial, err := d.netDialFn(ctx, proxyURL, u) + if err != nil { + return nil, nil, err + } + + hostPort, hostNoPort := hostPortNoPort(u) + trace := httptrace.ContextClientTrace(ctx) + if trace != nil && trace.GetConn != nil { + trace.GetConn(hostPort) + } + + netConn, err := netDial(ctx, "tcp", hostPort) + if err != nil { + return nil, nil, err + } + if trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{ + Conn: netConn, + }) + } + + // Close the network connection when returning an error. The variable + // netConn is set to nil before the success return at the end of the + // function. + defer func() { + if netConn != nil { + // It's safe to ignore the error from Close() because this code is + // only executed when returning a more important error to the + // application. + _ = netConn.Close() + } + }() + + // Do TLS handshake over established connection if a proxy exists. + if proxyURL != nil && u.Scheme == "https" { + + cfg := cloneTLSConfig(d.TLSClientConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(netConn, cfg) + netConn = tlsConn + + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } + + if err != nil { + return nil, nil, err + } + } + + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) + + if err := req.Write(netConn); err != nil { + return nil, nil, err + } + + if trace != nil && trace.GotFirstResponseByte != nil { + if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { + trace.GotFirstResponseByte() + } + } + + resp, err := http.ReadResponse(conn.br, req) + if err != nil { + if d.TLSClientConfig != nil { + for _, proto := range d.TLSClientConfig.NextProtos { + if proto != "http/1.1" { + return nil, nil, fmt.Errorf( + "websocket: protocol %q was given but is not supported;"+ + "sharing tls.Config with net/http Transport can cause this error: %w", + proto, err, + ) + } + } + } + return nil, nil, err + } + + if d.Jar != nil { + if rc := resp.Cookies(); len(rc) > 0 { + d.Jar.SetCookies(u, rc) + } + } + + if resp.StatusCode != 101 || + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) + return nil, resp, ErrBadHandshake + } + + for _, ext := range parseExtensions(resp.Header) { + if ext[""] != "permessage-deflate" { + continue + } + _, snct := ext["server_no_context_takeover"] + _, cnct := ext["client_no_context_takeover"] + if !snct || !cnct { + return nil, resp, errInvalidCompression + } + conn.newCompressionWriter = compressNoContextTakeover + conn.newDecompressionReader = decompressNoContextTakeover + break + } + + resp.Body = io.NopCloser(bytes.NewReader([]byte{})) + conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") + + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, resp, err + } + + // Success! Set netConn to nil to stop the deferred function above from + // closing the network connection. + netConn = nil + + return conn, resp, nil +} + +// Returns the dial function to establish the connection to either the backend +// server or the proxy (if it exists). If the dialed entity is HTTPS, then the +// returned dial function *also* performs the TLS handshake to the dialed entity. +// NOTE: If a proxy exists, it is possible for a second TLS handshake to be +// necessary over the established connection. +func (d *Dialer) netDialFn(ctx context.Context, proxyURL *url.URL, backendURL *url.URL) (netDialerFunc, error) { + var netDial netDialerFunc + if proxyURL != nil { + netDial = d.netDialFromURL(proxyURL) + } else { + netDial = d.netDialFromURL(backendURL) + } + // If needed, wrap the dial function to set the connection deadline. + if deadline, ok := ctx.Deadline(); ok { + netDial = netDialWithDeadline(netDial, deadline) + } + // Proxy dialing is wrapped to implement CONNECT method and possibly proxy auth. + if proxyURL != nil { + return proxyFromURL(proxyURL, netDial) + } + return netDial, nil +} + +// Returns function to create the connection depending on the Dialer's +// custom dialing functions and the passed URL of entity connecting to. +func (d *Dialer) netDialFromURL(u *url.URL) netDialerFunc { + var netDial netDialerFunc + switch { + case d.NetDialContext != nil: + netDial = d.NetDialContext + case d.NetDial != nil: + netDial = func(ctx context.Context, net, addr string) (net.Conn, error) { + return d.NetDial(net, addr) + } + default: + netDial = (&net.Dialer{}).DialContext + } + // If dialed entity is HTTPS, then either use custom TLS dialing function (if exists) + // or wrap the previously computed "netDial" to use TLS config for handshake. + if u.Scheme == "https" { + if d.NetDialTLSContext != nil { + netDial = d.NetDialTLSContext + } else { + netDial = netDialWithTLSHandshake(netDial, d.TLSClientConfig, u) + } + } + return netDial +} + +// Returns wrapped "netDial" function, performing TLS handshake after connecting. +func netDialWithTLSHandshake(netDial netDialerFunc, tlsConfig *tls.Config, u *url.URL) netDialerFunc { + return func(ctx context.Context, unused, addr string) (net.Conn, error) { + hostPort, hostNoPort := hostPortNoPort(u) + trace := httptrace.ContextClientTrace(ctx) + if trace != nil && trace.GetConn != nil { + trace.GetConn(hostPort) + } + // Creates TCP connection to addr using passed "netDial" function. + conn, err := netDial(ctx, "tcp", addr) + if err != nil { + return nil, err + } + cfg := cloneTLSConfig(tlsConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(conn, cfg) + // Do the TLS handshake using TLSConfig over the wrapped connection. + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err = doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } + if err != nil { + tlsConn.Close() + return nil, err + } + return tlsConn, nil + } +} + +// Returns wrapped "netDial" function, setting passed deadline. +func netDialWithDeadline(netDial netDialerFunc, deadline time.Time) netDialerFunc { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := netDial(ctx, network, addr) + if err != nil { + return nil, err + } + err = c.SetDeadline(deadline) + if err != nil { + c.Close() + return nil, err + } + return c, nil + } +} + +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return cfg.Clone() +} + +func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.HandshakeContext(ctx); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/gorilla/websocket/compression.go b/vendor/github.com/gorilla/websocket/compression.go new file mode 100644 index 000000000..fe1079edb --- /dev/null +++ b/vendor/github.com/gorilla/websocket/compression.go @@ -0,0 +1,152 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "compress/flate" + "errors" + "io" + "strings" + "sync" +) + +const ( + minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 + maxCompressionLevel = flate.BestCompression + defaultCompressionLevel = 1 +) + +var ( + flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool + flateReaderPool = sync.Pool{New: func() interface{} { + return flate.NewReader(nil) + }} +) + +func decompressNoContextTakeover(r io.Reader) io.ReadCloser { + const tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" + + fr, _ := flateReaderPool.Get().(io.ReadCloser) + mr := io.MultiReader(r, strings.NewReader(tail)) + if err := fr.(flate.Resetter).Reset(mr, nil); err != nil { + // Reset never fails, but handle error in case that changes. + fr = flate.NewReader(mr) + } + return &flateReadWrapper{fr} +} + +func isValidCompressionLevel(level int) bool { + return minCompressionLevel <= level && level <= maxCompressionLevel +} + +func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + p := &flateWriterPools[level-minCompressionLevel] + tw := &truncWriter{w: w} + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + fw, _ = flate.NewWriter(tw, level) + } else { + fw.Reset(tw) + } + return &flateWriteWrapper{fw: fw, tw: tw, p: p} +} + +// truncWriter is an io.Writer that writes all but the last four bytes of the +// stream to another io.Writer. +type truncWriter struct { + w io.WriteCloser + n int + p [4]byte +} + +func (w *truncWriter) Write(p []byte) (int, error) { + n := 0 + + // fill buffer first for simplicity. + if w.n < len(w.p) { + n = copy(w.p[w.n:], p) + p = p[n:] + w.n += n + if len(p) == 0 { + return n, nil + } + } + + m := len(p) + if m > len(w.p) { + m = len(w.p) + } + + if nn, err := w.w.Write(w.p[:m]); err != nil { + return n + nn, err + } + + copy(w.p[:], w.p[m:]) + copy(w.p[len(w.p)-m:], p[len(p)-m:]) + nn, err := w.w.Write(p[:len(p)-m]) + return n + nn, err +} + +type flateWriteWrapper struct { + fw *flate.Writer + tw *truncWriter + p *sync.Pool +} + +func (w *flateWriteWrapper) Write(p []byte) (int, error) { + if w.fw == nil { + return 0, errWriteClosed + } + return w.fw.Write(p) +} + +func (w *flateWriteWrapper) Close() error { + if w.fw == nil { + return errWriteClosed + } + err1 := w.fw.Flush() + w.p.Put(w.fw) + w.fw = nil + if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") + } + err2 := w.tw.w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +type flateReadWrapper struct { + fr io.ReadCloser +} + +func (r *flateReadWrapper) Read(p []byte) (int, error) { + if r.fr == nil { + return 0, io.ErrClosedPipe + } + n, err := r.fr.Read(p) + if err == io.EOF { + // Preemptively place the reader back in the pool. This helps with + // scenarios where the application does not call NextReader() soon after + // this final read. + r.Close() + } + return n, err +} + +func (r *flateReadWrapper) Close() error { + if r.fr == nil { + return io.ErrClosedPipe + } + err := r.fr.Close() + flateReaderPool.Put(r.fr) + r.fr = nil + return err +} diff --git a/vendor/github.com/gorilla/websocket/conn.go b/vendor/github.com/gorilla/websocket/conn.go new file mode 100644 index 000000000..9562ffd49 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/conn.go @@ -0,0 +1,1246 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "crypto/rand" + "encoding/binary" + "errors" + "io" + "net" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" +) + +const ( + // Frame header byte 0 bits from Section 5.2 of RFC 6455 + finalBit = 1 << 7 + rsv1Bit = 1 << 6 + rsv2Bit = 1 << 5 + rsv3Bit = 1 << 4 + + // Frame header byte 1 bits from Section 5.2 of RFC 6455 + maskBit = 1 << 7 + + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maxControlFramePayloadSize = 125 + + writeWait = time.Second + + defaultReadBufferSize = 4096 + defaultWriteBufferSize = 4096 + + continuationFrame = 0 + noFrame = -1 +) + +// Close codes defined in RFC 6455, section 11.7. +const ( + CloseNormalClosure = 1000 + CloseGoingAway = 1001 + CloseProtocolError = 1002 + CloseUnsupportedData = 1003 + CloseNoStatusReceived = 1005 + CloseAbnormalClosure = 1006 + CloseInvalidFramePayloadData = 1007 + ClosePolicyViolation = 1008 + CloseMessageTooBig = 1009 + CloseMandatoryExtension = 1010 + CloseInternalServerErr = 1011 + CloseServiceRestart = 1012 + CloseTryAgainLater = 1013 + CloseTLSHandshake = 1015 +) + +// The message types are defined in RFC 6455, section 11.8. +const ( + // TextMessage denotes a text data message. The text message payload is + // interpreted as UTF-8 encoded text data. + TextMessage = 1 + + // BinaryMessage denotes a binary data message. + BinaryMessage = 2 + + // CloseMessage denotes a close control message. The optional message + // payload contains a numeric code and text. Use the FormatCloseMessage + // function to format a close message payload. + CloseMessage = 8 + + // PingMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PingMessage = 9 + + // PongMessage denotes a pong control message. The optional message payload + // is UTF-8 encoded text. + PongMessage = 10 +) + +// ErrCloseSent is returned when the application writes a message to the +// connection after sending a close message. +var ErrCloseSent = errors.New("websocket: close sent") + +// ErrReadLimit is returned when reading a message that is larger than the +// read limit set for the connection. +var ErrReadLimit = errors.New("websocket: read limit exceeded") + +// netError satisfies the net Error interface. +type netError struct { + msg string + temporary bool + timeout bool +} + +func (e *netError) Error() string { return e.msg } +func (e *netError) Temporary() bool { return e.temporary } +func (e *netError) Timeout() bool { return e.timeout } + +// CloseError represents a close message. +type CloseError struct { + // Code is defined in RFC 6455, section 11.7. + Code int + + // Text is the optional text payload. + Text string +} + +func (e *CloseError) Error() string { + s := []byte("websocket: close ") + s = strconv.AppendInt(s, int64(e.Code), 10) + switch e.Code { + case CloseNormalClosure: + s = append(s, " (normal)"...) + case CloseGoingAway: + s = append(s, " (going away)"...) + case CloseProtocolError: + s = append(s, " (protocol error)"...) + case CloseUnsupportedData: + s = append(s, " (unsupported data)"...) + case CloseNoStatusReceived: + s = append(s, " (no status)"...) + case CloseAbnormalClosure: + s = append(s, " (abnormal closure)"...) + case CloseInvalidFramePayloadData: + s = append(s, " (invalid payload data)"...) + case ClosePolicyViolation: + s = append(s, " (policy violation)"...) + case CloseMessageTooBig: + s = append(s, " (message too big)"...) + case CloseMandatoryExtension: + s = append(s, " (mandatory extension missing)"...) + case CloseInternalServerErr: + s = append(s, " (internal server error)"...) + case CloseTLSHandshake: + s = append(s, " (TLS handshake error)"...) + } + if e.Text != "" { + s = append(s, ": "...) + s = append(s, e.Text...) + } + return string(s) +} + +// IsCloseError returns boolean indicating whether the error is a *CloseError +// with one of the specified codes. +func IsCloseError(err error, codes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range codes { + if e.Code == code { + return true + } + } + } + return false +} + +// IsUnexpectedCloseError returns boolean indicating whether the error is a +// *CloseError with a code not in the list of expected codes. +func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range expectedCodes { + if e.Code == code { + return false + } + } + return true + } + return false +} + +var ( + errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true} + errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} + errBadWriteOpCode = errors.New("websocket: bad write message type") + errWriteClosed = errors.New("websocket: write closed") + errInvalidControlFrame = errors.New("websocket: invalid control frame") +) + +// maskRand is an io.Reader for generating mask bytes. The reader is initialized +// to crypto/rand Reader. Tests swap the reader to a math/rand reader for +// reproducible results. +var maskRand = rand.Reader + +// newMaskKey returns a new 32 bit value for masking client frames. +func newMaskKey() [4]byte { + var k [4]byte + _, _ = io.ReadFull(maskRand, k[:]) + return k +} + +func isControl(frameType int) bool { + return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage +} + +func isData(frameType int) bool { + return frameType == TextMessage || frameType == BinaryMessage +} + +var validReceivedCloseCodes = map[int]bool{ + // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number + + CloseNormalClosure: true, + CloseGoingAway: true, + CloseProtocolError: true, + CloseUnsupportedData: true, + CloseNoStatusReceived: false, + CloseAbnormalClosure: false, + CloseInvalidFramePayloadData: true, + ClosePolicyViolation: true, + CloseMessageTooBig: true, + CloseMandatoryExtension: true, + CloseInternalServerErr: true, + CloseServiceRestart: true, + CloseTryAgainLater: true, + CloseTLSHandshake: false, +} + +func isValidReceivedCloseCode(code int) bool { + return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) +} + +// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this +// interface. The type of the value stored in a pool is not specified. +type BufferPool interface { + // Get gets a value from the pool or returns nil if the pool is empty. + Get() interface{} + // Put adds a value to the pool. + Put(interface{}) +} + +// writePoolData is the type added to the write buffer pool. This wrapper is +// used to prevent applications from peeking at and depending on the values +// added to the pool. +type writePoolData struct{ buf []byte } + +// The Conn type represents a WebSocket connection. +type Conn struct { + conn net.Conn + isServer bool + subprotocol string + + // Write fields + mu chan struct{} // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. + writePool BufferPool + writeBufSize int + writeDeadline time.Time + writer io.WriteCloser // the current writer returned to the application + isWriting bool // for best-effort concurrent write detection + + writeErrMu sync.Mutex + writeErr error + + enableWriteCompression bool + compressionLevel int + newCompressionWriter func(io.WriteCloser, int) io.WriteCloser + + // Read fields + reader io.ReadCloser // the current reader returned to the application + readErr error + br *bufio.Reader + // bytes remaining in current frame. + // set setReadRemaining to safely update this value and prevent overflow + readRemaining int64 + readFinal bool // true the current message has more frames. + readLength int64 // Message size. + readLimit int64 // Maximum message size. + readMaskPos int + readMaskKey [4]byte + handlePong func(string) error + handlePing func(string) error + handleClose func(int, string) error + readErrCount int + messageReader *messageReader // the current low-level reader + + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader) io.ReadCloser +} + +func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn { + + if br == nil { + if readBufferSize == 0 { + readBufferSize = defaultReadBufferSize + } else if readBufferSize < maxControlFramePayloadSize { + // must be large enough for control frame + readBufferSize = maxControlFramePayloadSize + } + br = bufio.NewReaderSize(conn, readBufferSize) + } + + if writeBufferSize <= 0 { + writeBufferSize = defaultWriteBufferSize + } + writeBufferSize += maxFrameHeaderSize + + if writeBuf == nil && writeBufferPool == nil { + writeBuf = make([]byte, writeBufferSize) + } + + mu := make(chan struct{}, 1) + mu <- struct{}{} + c := &Conn{ + isServer: isServer, + br: br, + conn: conn, + mu: mu, + readFinal: true, + writeBuf: writeBuf, + writePool: writeBufferPool, + writeBufSize: writeBufferSize, + enableWriteCompression: true, + compressionLevel: defaultCompressionLevel, + } + c.SetCloseHandler(nil) + c.SetPingHandler(nil) + c.SetPongHandler(nil) + return c +} + +// setReadRemaining tracks the number of bytes remaining on the connection. If n +// overflows, an ErrReadLimit is returned. +func (c *Conn) setReadRemaining(n int64) error { + if n < 0 { + return ErrReadLimit + } + + c.readRemaining = n + return nil +} + +// Subprotocol returns the negotiated protocol for the connection. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +// Close closes the underlying network connection without sending or waiting +// for a close message. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// Write methods + +func (c *Conn) writeFatal(err error) error { + c.writeErrMu.Lock() + if c.writeErr == nil { + c.writeErr = err + } + c.writeErrMu.Unlock() + return err +} + +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + // Discard is guaranteed to succeed because the number of bytes to discard + // is less than or equal to the number of bytes buffered. + _, _ = c.br.Discard(len(p)) + return p, err +} + +func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { + <-c.mu + defer func() { c.mu <- struct{}{} }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } + if len(buf1) == 0 { + _, err = c.conn.Write(buf0) + } else { + err = c.writeBufs(buf0, buf1) + } + if err != nil { + return c.writeFatal(err) + } + if frameType == CloseMessage { + _ = c.writeFatal(ErrCloseSent) + } + return nil +} + +func (c *Conn) writeBufs(bufs ...[]byte) error { + b := net.Buffers(bufs) + _, err := b.WriteTo(c.conn) + return err +} + +// WriteControl writes a control message with the given deadline. The allowed +// message types are CloseMessage, PingMessage and PongMessage. +func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { + if !isControl(messageType) { + return errBadWriteOpCode + } + if len(data) > maxControlFramePayloadSize { + return errInvalidControlFrame + } + + b0 := byte(messageType) | finalBit + b1 := byte(len(data)) + if !c.isServer { + b1 |= maskBit + } + + buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) + buf = append(buf, b0, b1) + + if c.isServer { + buf = append(buf, data...) + } else { + key := newMaskKey() + buf = append(buf, key[:]...) + buf = append(buf, data...) + maskBytes(key, 0, buf[6:]) + } + + if deadline.IsZero() { + // No timeout for zero time. + <-c.mu + } else { + d := time.Until(deadline) + if d < 0 { + return errWriteTimeout + } + select { + case <-c.mu: + default: + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } + } + } + + defer func() { c.mu <- struct{}{} }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + if err := c.conn.SetWriteDeadline(deadline); err != nil { + return c.writeFatal(err) + } + if _, err = c.conn.Write(buf); err != nil { + return c.writeFatal(err) + } + if messageType == CloseMessage { + _ = c.writeFatal(ErrCloseSent) + } + return err +} + +// beginMessage prepares a connection and message writer for a new message. +func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { + // Close previous writer if not already closed by the application. It's + // probably better to return an error in this situation, but we cannot + // change this without breaking existing applications. + if c.writer != nil { + c.writer.Close() + c.writer = nil + } + + if !isControl(messageType) && !isData(messageType) { + return errBadWriteOpCode + } + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + mw.c = c + mw.frameType = messageType + mw.pos = maxFrameHeaderSize + + if c.writeBuf == nil { + wpd, ok := c.writePool.Get().(writePoolData) + if ok { + c.writeBuf = wpd.buf + } else { + c.writeBuf = make([]byte, c.writeBufSize) + } + } + return nil +} + +// NextWriter returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. +// +// There can be at most one open writer on a connection. NextWriter closes the +// previous writer if the application has not already done so. +// +// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and +// PongMessage) are supported. +func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { + return nil, err + } + c.writer = &mw + if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + w := c.newCompressionWriter(c.writer, c.compressionLevel) + mw.compress = true + c.writer = w + } + return c.writer, nil +} + +type messageWriter struct { + c *Conn + compress bool // whether next call to flushFrame should set RSV1 + pos int // end of data in writeBuf. + frameType int // type of the current frame. + err error +} + +func (w *messageWriter) endMessage(err error) error { + if w.err != nil { + return err + } + c := w.c + w.err = err + c.writer = nil + if c.writePool != nil { + c.writePool.Put(writePoolData{buf: c.writeBuf}) + c.writeBuf = nil + } + return err +} + +// flushFrame writes buffered data and extra as a frame to the network. The +// final argument indicates that this is the last frame in the message. +func (w *messageWriter) flushFrame(final bool, extra []byte) error { + c := w.c + length := w.pos - maxFrameHeaderSize + len(extra) + + // Check for invalid control frames. + if isControl(w.frameType) && + (!final || length > maxControlFramePayloadSize) { + return w.endMessage(errInvalidControlFrame) + } + + b0 := byte(w.frameType) + if final { + b0 |= finalBit + } + if w.compress { + b0 |= rsv1Bit + } + w.compress = false + + b1 := byte(0) + if !c.isServer { + b1 |= maskBit + } + + // Assume that the frame starts at beginning of c.writeBuf. + framePos := 0 + if c.isServer { + // Adjust up if mask not included in the header. + framePos = 4 + } + + switch { + case length >= 65536: + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 127 + binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) + case length > 125: + framePos += 6 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 126 + binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) + default: + framePos += 8 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | byte(length) + } + + if !c.isServer { + key := newMaskKey() + copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) + if len(extra) > 0 { + return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))) + } + } + + // Write the buffers to the connection with best-effort detection of + // concurrent writes. See the concurrency section in the package + // documentation for more info. + + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + + err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) + + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + + if err != nil { + return w.endMessage(err) + } + + if final { + _ = w.endMessage(errWriteClosed) + return nil + } + + // Setup for next frame. + w.pos = maxFrameHeaderSize + w.frameType = continuationFrame + return nil +} + +func (w *messageWriter) ncopy(max int) (int, error) { + n := len(w.c.writeBuf) - w.pos + if n <= 0 { + if err := w.flushFrame(false, nil); err != nil { + return 0, err + } + n = len(w.c.writeBuf) - w.pos + } + if n > max { + n = max + } + return n, nil +} + +func (w *messageWriter) Write(p []byte) (int, error) { + if w.err != nil { + return 0, w.err + } + + if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { + // Don't buffer large messages. + err := w.flushFrame(false, p) + if err != nil { + return 0, err + } + return len(p), nil + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) WriteString(p string) (int, error) { + if w.err != nil { + return 0, w.err + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { + if w.err != nil { + return 0, w.err + } + for { + if w.pos == len(w.c.writeBuf) { + err = w.flushFrame(false, nil) + if err != nil { + break + } + } + var n int + n, err = r.Read(w.c.writeBuf[w.pos:]) + w.pos += n + nn += int64(n) + if err != nil { + if err == io.EOF { + err = nil + } + break + } + } + return nn, err +} + +func (w *messageWriter) Close() error { + if w.err != nil { + return w.err + } + return w.flushFrame(true, nil) +} + +// WritePreparedMessage writes prepared message into connection. +func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { + frameType, frameData, err := pm.frame(prepareKey{ + isServer: c.isServer, + compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), + compressionLevel: c.compressionLevel, + }) + if err != nil { + return err + } + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + err = c.write(frameType, c.writeDeadline, frameData, nil) + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + return err +} + +// WriteMessage is a helper method for getting a writer using NextWriter, +// writing the message and closing the writer. +func (c *Conn) WriteMessage(messageType int, data []byte) error { + + if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { + // Fast path with no allocations and single frame. + + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { + return err + } + n := copy(c.writeBuf[mw.pos:], data) + mw.pos += n + data = data[n:] + return mw.flushFrame(true, data) + } + + w, err := c.NextWriter(messageType) + if err != nil { + return err + } + if _, err = w.Write(data); err != nil { + return err + } + return w.Close() +} + +// SetWriteDeadline sets the write deadline on the underlying network +// connection. After a write has timed out, the websocket state is corrupt and +// all future writes will return an error. A zero value for t means writes will +// not time out. +func (c *Conn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} + +// Read methods + +func (c *Conn) advanceFrame() (int, error) { + // 1. Skip remainder of previous frame. + + if c.readRemaining > 0 { + if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil { + return noFrame, err + } + } + + // 2. Read and parse first two bytes of frame header. + // To aid debugging, collect and report all errors in the first two bytes + // of the header. + + var errors []string + + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + frameType := int(p[0] & 0xf) + final := p[0]&finalBit != 0 + rsv1 := p[0]&rsv1Bit != 0 + rsv2 := p[0]&rsv2Bit != 0 + rsv3 := p[0]&rsv3Bit != 0 + mask := p[1]&maskBit != 0 + _ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not fail because argument is >= 0 + + c.readDecompress = false + if rsv1 { + if c.newDecompressionReader != nil { + c.readDecompress = true + } else { + errors = append(errors, "RSV1 set") + } + } + + if rsv2 { + errors = append(errors, "RSV2 set") + } + + if rsv3 { + errors = append(errors, "RSV3 set") + } + + switch frameType { + case CloseMessage, PingMessage, PongMessage: + if c.readRemaining > maxControlFramePayloadSize { + errors = append(errors, "len > 125 for control") + } + if !final { + errors = append(errors, "FIN not set on control") + } + case TextMessage, BinaryMessage: + if !c.readFinal { + errors = append(errors, "data before FIN") + } + c.readFinal = final + case continuationFrame: + if c.readFinal { + errors = append(errors, "continuation after FIN") + } + c.readFinal = final + default: + errors = append(errors, "bad opcode "+strconv.Itoa(frameType)) + } + + if mask != c.isServer { + errors = append(errors, "bad MASK") + } + + if len(errors) > 0 { + return noFrame, c.handleProtocolError(strings.Join(errors, ", ")) + } + + // 3. Read and parse frame length as per + // https://tools.ietf.org/html/rfc6455#section-5.2 + // + // The length of the "Payload data", in bytes: if 0-125, that is the payload + // length. + // - If 126, the following 2 bytes interpreted as a 16-bit unsigned + // integer are the payload length. + // - If 127, the following 8 bytes interpreted as + // a 64-bit unsigned integer (the most significant bit MUST be 0) are the + // payload length. Multibyte length quantities are expressed in network byte + // order. + + switch c.readRemaining { + case 126: + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil { + return noFrame, err + } + case 127: + p, err := c.read(8) + if err != nil { + return noFrame, err + } + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil { + return noFrame, err + } + } + + // 4. Handle frame masking. + + if mask { + c.readMaskPos = 0 + p, err := c.read(len(c.readMaskKey)) + if err != nil { + return noFrame, err + } + copy(c.readMaskKey[:], p) + } + + // 5. For text and binary messages, enforce read limit and return. + + if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { + + c.readLength += c.readRemaining + // Don't allow readLength to overflow in the presence of a large readRemaining + // counter. + if c.readLength < 0 { + return noFrame, ErrReadLimit + } + + if c.readLimit > 0 && c.readLength > c.readLimit { + // Make a best effort to send a close message describing the problem. + _ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + return noFrame, ErrReadLimit + } + + return frameType, nil + } + + // 6. Read control frame payload. + + var payload []byte + if c.readRemaining > 0 { + payload, err = c.read(int(c.readRemaining)) + _ = c.setReadRemaining(0) // will not fail because argument is >= 0 + if err != nil { + return noFrame, err + } + if c.isServer { + maskBytes(c.readMaskKey, 0, payload) + } + } + + // 7. Process control frame payload. + + switch frameType { + case PongMessage: + if err := c.handlePong(string(payload)); err != nil { + return noFrame, err + } + case PingMessage: + if err := c.handlePing(string(payload)); err != nil { + return noFrame, err + } + case CloseMessage: + closeCode := CloseNoStatusReceived + closeText := "" + if len(payload) >= 2 { + closeCode = int(binary.BigEndian.Uint16(payload)) + if !isValidReceivedCloseCode(closeCode) { + return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode)) + } + closeText = string(payload[2:]) + if !utf8.ValidString(closeText) { + return noFrame, c.handleProtocolError("invalid utf8 payload in close frame") + } + } + if err := c.handleClose(closeCode, closeText); err != nil { + return noFrame, err + } + return noFrame, &CloseError{Code: closeCode, Text: closeText} + } + + return frameType, nil +} + +func (c *Conn) handleProtocolError(message string) error { + data := FormatCloseMessage(CloseProtocolError, message) + if len(data) > maxControlFramePayloadSize { + data = data[:maxControlFramePayloadSize] + } + // Make a best effor to send a close message describing the problem. + _ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) + return errors.New("websocket: " + message) +} + +// NextReader returns the next data message received from the peer. The +// returned messageType is either TextMessage or BinaryMessage. +// +// There can be at most one open reader on a connection. NextReader discards +// the previous message if the application has not already consumed it. +// +// Applications must break out of the application's read loop when this method +// returns a non-nil error value. Errors returned from this method are +// permanent. Once this method returns a non-nil error, all subsequent calls to +// this method return the same error. +func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { + // Close previous reader, only relevant for decompression. + if c.reader != nil { + c.reader.Close() + c.reader = nil + } + + c.messageReader = nil + c.readLength = 0 + + for c.readErr == nil { + frameType, err := c.advanceFrame() + if err != nil { + c.readErr = err + break + } + + if frameType == TextMessage || frameType == BinaryMessage { + c.messageReader = &messageReader{c} + c.reader = c.messageReader + if c.readDecompress { + c.reader = c.newDecompressionReader(c.reader) + } + return frameType, c.reader, nil + } + } + + // Applications that do handle the error returned from this method spin in + // tight loop on connection failure. To help application developers detect + // this error, panic on repeated reads to the failed connection. + c.readErrCount++ + if c.readErrCount >= 1000 { + panic("repeated read on failed websocket connection") + } + + return noFrame, nil, c.readErr +} + +type messageReader struct{ c *Conn } + +func (r *messageReader) Read(b []byte) (int, error) { + c := r.c + if c.messageReader != r { + return 0, io.EOF + } + + for c.readErr == nil { + + if c.readRemaining > 0 { + if int64(len(b)) > c.readRemaining { + b = b[:c.readRemaining] + } + n, err := c.br.Read(b) + c.readErr = err + if c.isServer { + c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) + } + rem := c.readRemaining + rem -= int64(n) + _ = c.setReadRemaining(rem) // rem is guaranteed to be >= 0 + if c.readRemaining > 0 && c.readErr == io.EOF { + c.readErr = errUnexpectedEOF + } + return n, c.readErr + } + + if c.readFinal { + c.messageReader = nil + return 0, io.EOF + } + + frameType, err := c.advanceFrame() + switch { + case err != nil: + c.readErr = err + case frameType == TextMessage || frameType == BinaryMessage: + c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") + } + } + + err := c.readErr + if err == io.EOF && c.messageReader == r { + err = errUnexpectedEOF + } + return 0, err +} + +func (r *messageReader) Close() error { + return nil +} + +// ReadMessage is a helper method for getting a reader using NextReader and +// reading from that reader to a buffer. +func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { + var r io.Reader + messageType, r, err = c.NextReader() + if err != nil { + return messageType, nil, err + } + p, err = io.ReadAll(r) + return messageType, p, err +} + +// SetReadDeadline sets the read deadline on the underlying network connection. +// After a read has timed out, the websocket connection state is corrupt and +// all future reads will return an error. A zero value for t means reads will +// not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a +// message exceeds the limit, the connection sends a close message to the peer +// and returns ErrReadLimit to the application. +func (c *Conn) SetReadLimit(limit int64) { + c.readLimit = limit +} + +// CloseHandler returns the current close handler +func (c *Conn) CloseHandler() func(code int, text string) error { + return c.handleClose +} + +// SetCloseHandler sets the handler for close messages received from the peer. +// The code argument to h is the received close code or CloseNoStatusReceived +// if the close message is empty. The default close handler sends a close +// message back to the peer. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// close messages as described in the section on Control Messages above. +// +// The connection read methods return a CloseError when a close message is +// received. Most applications should handle close messages as part of their +// normal error handling. Applications should only set a close handler when the +// application must perform some action before sending a close message back to +// the peer. +func (c *Conn) SetCloseHandler(h func(code int, text string) error) { + if h == nil { + h = func(code int, text string) error { + message := FormatCloseMessage(code, "") + // Make a best effor to send the close message. + _ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + return nil + } + } + c.handleClose = h +} + +// PingHandler returns the current ping handler +func (c *Conn) PingHandler() func(appData string) error { + return c.handlePing +} + +// SetPingHandler sets the handler for ping messages received from the peer. +// The appData argument to h is the PING message application data. The default +// ping handler sends a pong to the peer. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// ping messages as described in the section on Control Messages above. +func (c *Conn) SetPingHandler(h func(appData string) error) { + if h == nil { + h = func(message string) error { + // Make a best effort to send the pong message. + _ = c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) + return nil + } + } + c.handlePing = h +} + +// PongHandler returns the current pong handler +func (c *Conn) PongHandler() func(appData string) error { + return c.handlePong +} + +// SetPongHandler sets the handler for pong messages received from the peer. +// The appData argument to h is the PONG message application data. The default +// pong handler does nothing. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// pong messages as described in the section on Control Messages above. +func (c *Conn) SetPongHandler(h func(appData string) error) { + if h == nil { + h = func(string) error { return nil } + } + c.handlePong = h +} + +// NetConn returns the underlying connection that is wrapped by c. +// Note that writing to or reading from this connection directly will corrupt the +// WebSocket connection. +func (c *Conn) NetConn() net.Conn { + return c.conn +} + +// UnderlyingConn returns the internal net.Conn. This can be used to further +// modifications to connection specific flags. +// Deprecated: Use the NetConn method. +func (c *Conn) UnderlyingConn() net.Conn { + return c.conn +} + +// EnableWriteCompression enables and disables write compression of +// subsequent text and binary messages. This function is a noop if +// compression was not negotiated with the peer. +func (c *Conn) EnableWriteCompression(enable bool) { + c.enableWriteCompression = enable +} + +// SetCompressionLevel sets the flate compression level for subsequent text and +// binary messages. This function is a noop if compression was not negotiated +// with the peer. See the compress/flate package for a description of +// compression levels. +func (c *Conn) SetCompressionLevel(level int) error { + if !isValidCompressionLevel(level) { + return errors.New("websocket: invalid compression level") + } + c.compressionLevel = level + return nil +} + +// FormatCloseMessage formats closeCode and text as a WebSocket close message. +// An empty message is returned for code CloseNoStatusReceived. +func FormatCloseMessage(closeCode int, text string) []byte { + if closeCode == CloseNoStatusReceived { + // Return empty message because it's illegal to send + // CloseNoStatusReceived. Return non-nil value in case application + // checks for nil. + return []byte{} + } + buf := make([]byte, 2+len(text)) + binary.BigEndian.PutUint16(buf, uint16(closeCode)) + copy(buf[2:], text) + return buf +} diff --git a/vendor/github.com/gorilla/websocket/doc.go b/vendor/github.com/gorilla/websocket/doc.go new file mode 100644 index 000000000..8db0cef95 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/doc.go @@ -0,0 +1,227 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package websocket implements the WebSocket protocol defined in RFC 6455. +// +// Overview +// +// The Conn type represents a WebSocket connection. A server application calls +// the Upgrader.Upgrade method from an HTTP request handler to get a *Conn: +// +// var upgrader = websocket.Upgrader{ +// ReadBufferSize: 1024, +// WriteBufferSize: 1024, +// } +// +// func handler(w http.ResponseWriter, r *http.Request) { +// conn, err := upgrader.Upgrade(w, r, nil) +// if err != nil { +// log.Println(err) +// return +// } +// ... Use conn to send and receive messages. +// } +// +// Call the connection's WriteMessage and ReadMessage methods to send and +// receive messages as a slice of bytes. This snippet of code shows how to echo +// messages using these methods: +// +// for { +// messageType, p, err := conn.ReadMessage() +// if err != nil { +// log.Println(err) +// return +// } +// if err := conn.WriteMessage(messageType, p); err != nil { +// log.Println(err) +// return +// } +// } +// +// In above snippet of code, p is a []byte and messageType is an int with value +// websocket.BinaryMessage or websocket.TextMessage. +// +// An application can also send and receive messages using the io.WriteCloser +// and io.Reader interfaces. To send a message, call the connection NextWriter +// method to get an io.WriteCloser, write the message to the writer and close +// the writer when done. To receive a message, call the connection NextReader +// method to get an io.Reader and read until io.EOF is returned. This snippet +// shows how to echo messages using the NextWriter and NextReader methods: +// +// for { +// messageType, r, err := conn.NextReader() +// if err != nil { +// return +// } +// w, err := conn.NextWriter(messageType) +// if err != nil { +// return err +// } +// if _, err := io.Copy(w, r); err != nil { +// return err +// } +// if err := w.Close(); err != nil { +// return err +// } +// } +// +// Data Messages +// +// The WebSocket protocol distinguishes between text and binary data messages. +// Text messages are interpreted as UTF-8 encoded text. The interpretation of +// binary messages is left to the application. +// +// This package uses the TextMessage and BinaryMessage integer constants to +// identify the two data message types. The ReadMessage and NextReader methods +// return the type of the received message. The messageType argument to the +// WriteMessage and NextWriter methods specifies the type of a sent message. +// +// It is the application's responsibility to ensure that text messages are +// valid UTF-8 encoded text. +// +// Control Messages +// +// The WebSocket protocol defines three types of control messages: close, ping +// and pong. Call the connection WriteControl, WriteMessage or NextWriter +// methods to send a control message to the peer. +// +// Connections handle received close messages by calling the handler function +// set with the SetCloseHandler method and by returning a *CloseError from the +// NextReader, ReadMessage or the message Read method. The default close +// handler sends a close message to the peer. +// +// Connections handle received ping messages by calling the handler function +// set with the SetPingHandler method. The default ping handler sends a pong +// message to the peer. +// +// Connections handle received pong messages by calling the handler function +// set with the SetPongHandler method. The default pong handler does nothing. +// If an application sends ping messages, then the application should set a +// pong handler to receive the corresponding pong. +// +// The control message handler functions are called from the NextReader, +// ReadMessage and message reader Read methods. The default close and ping +// handlers can block these methods for a short time when the handler writes to +// the connection. +// +// The application must read the connection to process close, ping and pong +// messages sent from the peer. If the application is not otherwise interested +// in messages from the peer, then the application should start a goroutine to +// read and discard messages from the peer. A simple example is: +// +// func readLoop(c *websocket.Conn) { +// for { +// if _, _, err := c.NextReader(); err != nil { +// c.Close() +// break +// } +// } +// } +// +// Concurrency +// +// Connections support one concurrent reader and one concurrent writer. +// +// Applications are responsible for ensuring that no more than one goroutine +// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, +// WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and +// that no more than one goroutine calls the read methods (NextReader, +// SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler) +// concurrently. +// +// The Close and WriteControl methods can be called concurrently with all other +// methods. +// +// Origin Considerations +// +// Web browsers allow Javascript applications to open a WebSocket connection to +// any host. It's up to the server to enforce an origin policy using the Origin +// request header sent by the browser. +// +// The Upgrader calls the function specified in the CheckOrigin field to check +// the origin. If the CheckOrigin function returns false, then the Upgrade +// method fails the WebSocket handshake with HTTP status 403. +// +// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail +// the handshake if the Origin request header is present and the Origin host is +// not equal to the Host request header. +// +// The deprecated package-level Upgrade function does not perform origin +// checking. The application is responsible for checking the Origin header +// before calling the Upgrade function. +// +// Buffers +// +// Connections buffer network input and output to reduce the number +// of system calls when reading or writing messages. +// +// Write buffers are also used for constructing WebSocket frames. See RFC 6455, +// Section 5 for a discussion of message framing. A WebSocket frame header is +// written to the network each time a write buffer is flushed to the network. +// Decreasing the size of the write buffer can increase the amount of framing +// overhead on the connection. +// +// The buffer sizes in bytes are specified by the ReadBufferSize and +// WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default +// size of 4096 when a buffer size field is set to zero. The Upgrader reuses +// buffers created by the HTTP server when a buffer size field is set to zero. +// The HTTP server buffers have a size of 4096 at the time of this writing. +// +// The buffer sizes do not limit the size of a message that can be read or +// written by a connection. +// +// Buffers are held for the lifetime of the connection by default. If the +// Dialer or Upgrader WriteBufferPool field is set, then a connection holds the +// write buffer only when writing a message. +// +// Applications should tune the buffer sizes to balance memory use and +// performance. Increasing the buffer size uses more memory, but can reduce the +// number of system calls to read or write the network. In the case of writing, +// increasing the buffer size can reduce the number of frame headers written to +// the network. +// +// Some guidelines for setting buffer parameters are: +// +// Limit the buffer sizes to the maximum expected message size. Buffers larger +// than the largest message do not provide any benefit. +// +// Depending on the distribution of message sizes, setting the buffer size to +// a value less than the maximum expected message size can greatly reduce memory +// use with a small impact on performance. Here's an example: If 99% of the +// messages are smaller than 256 bytes and the maximum message size is 512 +// bytes, then a buffer size of 256 bytes will result in 1.01 more system calls +// than a buffer size of 512 bytes. The memory savings is 50%. +// +// A write buffer pool is useful when the application has a modest number +// writes over a large number of connections. when buffers are pooled, a larger +// buffer size has a reduced impact on total memory use and has the benefit of +// reducing system calls and frame overhead. +// +// Compression EXPERIMENTAL +// +// Per message compression extensions (RFC 7692) are experimentally supported +// by this package in a limited capacity. Setting the EnableCompression option +// to true in Dialer or Upgrader will attempt to negotiate per message deflate +// support. +// +// var upgrader = websocket.Upgrader{ +// EnableCompression: true, +// } +// +// If compression was successfully negotiated with the connection's peer, any +// message received in compressed form will be automatically decompressed. +// All Read methods will return uncompressed bytes. +// +// Per message compression of messages written to a connection can be enabled +// or disabled by calling the corresponding Conn method: +// +// conn.EnableWriteCompression(false) +// +// Currently this package does not support compression with "context takeover". +// This means that messages must be compressed and decompressed in isolation, +// without retaining sliding window or dictionary state across messages. For +// more details refer to RFC 7692. +// +// Use of compression is experimental and may result in decreased performance. +package websocket diff --git a/vendor/github.com/gorilla/websocket/join.go b/vendor/github.com/gorilla/websocket/join.go new file mode 100644 index 000000000..c64f8c829 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/join.go @@ -0,0 +1,42 @@ +// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "io" + "strings" +) + +// JoinMessages concatenates received messages to create a single io.Reader. +// The string term is appended to each message. The returned reader does not +// support concurrent calls to the Read method. +func JoinMessages(c *Conn, term string) io.Reader { + return &joinReader{c: c, term: term} +} + +type joinReader struct { + c *Conn + term string + r io.Reader +} + +func (r *joinReader) Read(p []byte) (int, error) { + if r.r == nil { + var err error + _, r.r, err = r.c.NextReader() + if err != nil { + return 0, err + } + if r.term != "" { + r.r = io.MultiReader(r.r, strings.NewReader(r.term)) + } + } + n, err := r.r.Read(p) + if err == io.EOF { + err = nil + r.r = nil + } + return n, err +} diff --git a/vendor/github.com/gorilla/websocket/json.go b/vendor/github.com/gorilla/websocket/json.go new file mode 100644 index 000000000..dc2c1f641 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/json.go @@ -0,0 +1,60 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "encoding/json" + "io" +) + +// WriteJSON writes the JSON encoding of v as a message. +// +// Deprecated: Use c.WriteJSON instead. +func WriteJSON(c *Conn, v interface{}) error { + return c.WriteJSON(v) +} + +// WriteJSON writes the JSON encoding of v as a message. +// +// See the documentation for encoding/json Marshal for details about the +// conversion of Go values to JSON. +func (c *Conn) WriteJSON(v interface{}) error { + w, err := c.NextWriter(TextMessage) + if err != nil { + return err + } + err1 := json.NewEncoder(w).Encode(v) + err2 := w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// Deprecated: Use c.ReadJSON instead. +func ReadJSON(c *Conn, v interface{}) error { + return c.ReadJSON(v) +} + +// ReadJSON reads the next JSON-encoded message from the connection and stores +// it in the value pointed to by v. +// +// See the documentation for the encoding/json Unmarshal function for details +// about the conversion of JSON to a Go value. +func (c *Conn) ReadJSON(v interface{}) error { + _, r, err := c.NextReader() + if err != nil { + return err + } + err = json.NewDecoder(r).Decode(v) + if err == io.EOF { + // One value is expected in the message. + err = io.ErrUnexpectedEOF + } + return err +} diff --git a/vendor/github.com/gorilla/websocket/mask.go b/vendor/github.com/gorilla/websocket/mask.go new file mode 100644 index 000000000..d0742bf2a --- /dev/null +++ b/vendor/github.com/gorilla/websocket/mask.go @@ -0,0 +1,55 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +//go:build !appengine +// +build !appengine + +package websocket + +import "unsafe" + +const wordSize = int(unsafe.Sizeof(uintptr(0))) + +func maskBytes(key [4]byte, pos int, b []byte) int { + // Mask one byte at a time for small buffers. + if len(b) < 2*wordSize { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 + } + + // Mask one byte at a time to word boundary. + if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { + n = wordSize - n + for i := range b[:n] { + b[i] ^= key[pos&3] + pos++ + } + b = b[n:] + } + + // Create aligned word size key. + var k [wordSize]byte + for i := range k { + k[i] = key[(pos+i)&3] + } + kw := *(*uintptr)(unsafe.Pointer(&k)) + + // Mask one word at a time. + n := (len(b) / wordSize) * wordSize + for i := 0; i < n; i += wordSize { + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw + } + + // Mask one byte at a time for remaining bytes. + b = b[n:] + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + + return pos & 3 +} diff --git a/vendor/github.com/gorilla/websocket/mask_safe.go b/vendor/github.com/gorilla/websocket/mask_safe.go new file mode 100644 index 000000000..36250ca7c --- /dev/null +++ b/vendor/github.com/gorilla/websocket/mask_safe.go @@ -0,0 +1,16 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of +// this source code is governed by a BSD-style license that can be found in the +// LICENSE file. + +//go:build appengine +// +build appengine + +package websocket + +func maskBytes(key [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 +} diff --git a/vendor/github.com/gorilla/websocket/prepared.go b/vendor/github.com/gorilla/websocket/prepared.go new file mode 100644 index 000000000..c854225e9 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/prepared.go @@ -0,0 +1,102 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "net" + "sync" + "time" +) + +// PreparedMessage caches on the wire representations of a message payload. +// Use PreparedMessage to efficiently send a message payload to multiple +// connections. PreparedMessage is especially useful when compression is used +// because the CPU and memory expensive compression operation can be executed +// once for a given set of compression options. +type PreparedMessage struct { + messageType int + data []byte + mu sync.Mutex + frames map[prepareKey]*preparedFrame +} + +// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. +type prepareKey struct { + isServer bool + compress bool + compressionLevel int +} + +// preparedFrame contains data in wire representation. +type preparedFrame struct { + once sync.Once + data []byte +} + +// NewPreparedMessage returns an initialized PreparedMessage. You can then send +// it to connection using WritePreparedMessage method. Valid wire +// representation will be calculated lazily only once for a set of current +// connection options. +func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { + pm := &PreparedMessage{ + messageType: messageType, + frames: make(map[prepareKey]*preparedFrame), + data: data, + } + + // Prepare a plain server frame. + _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) + if err != nil { + return nil, err + } + + // To protect against caller modifying the data argument, remember the data + // copied to the plain server frame. + pm.data = frameData[len(frameData)-len(data):] + return pm, nil +} + +func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { + pm.mu.Lock() + frame, ok := pm.frames[key] + if !ok { + frame = &preparedFrame{} + pm.frames[key] = frame + } + pm.mu.Unlock() + + var err error + frame.once.Do(func() { + // Prepare a frame using a 'fake' connection. + // TODO: Refactor code in conn.go to allow more direct construction of + // the frame. + mu := make(chan struct{}, 1) + mu <- struct{}{} + var nc prepareConn + c := &Conn{ + conn: &nc, + mu: mu, + isServer: key.isServer, + compressionLevel: key.compressionLevel, + enableWriteCompression: true, + writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), + } + if key.compress { + c.newCompressionWriter = compressNoContextTakeover + } + err = c.WriteMessage(pm.messageType, pm.data) + frame.data = nc.buf.Bytes() + }) + return pm.messageType, frame.data, err +} + +type prepareConn struct { + buf bytes.Buffer + net.Conn +} + +func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } +func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/vendor/github.com/gorilla/websocket/proxy.go b/vendor/github.com/gorilla/websocket/proxy.go new file mode 100644 index 000000000..d716a0588 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/proxy.go @@ -0,0 +1,104 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "errors" + "net" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/proxy" +) + +type netDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error) + +func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { + return fn(context.Background(), network, addr) +} + +func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + return fn(ctx, network, addr) +} + +func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) { + if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { + return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil + } + dialer, err := proxy.FromURL(proxyURL, forwardDial) + if err != nil { + return nil, err + } + if d, ok := dialer.(proxy.ContextDialer); ok { + return d.DialContext, nil + } + return func(ctx context.Context, net, addr string) (net.Conn, error) { + return dialer.Dial(net, addr) + }, nil +} + +type httpProxyDialer struct { + proxyURL *url.URL + forwardDial netDialerFunc +} + +func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + hostPort, _ := hostPortNoPort(hpd.proxyURL) + conn, err := hpd.forwardDial(ctx, network, hostPort) + if err != nil { + return nil, err + } + + connectHeader := make(http.Header) + if user := hpd.proxyURL.User; user != nil { + proxyUser := user.Username() + if proxyPassword, passwordSet := user.Password(); passwordSet { + credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) + connectHeader.Set("Proxy-Authorization", "Basic "+credential) + } + } + connectReq := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Opaque: addr}, + Host: addr, + Header: connectHeader, + } + + if err := connectReq.Write(conn); err != nil { + conn.Close() + return nil, err + } + + // Read response. It's OK to use and discard buffered reader here because + // the remote server does not speak until spoken to. + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + conn.Close() + return nil, err + } + + // Close the response body to silence false positives from linters. Reset + // the buffered reader first to ensure that Close() does not read from + // conn. + // Note: Applications must call resp.Body.Close() on a response returned + // http.ReadResponse to inspect trailers or read another response from the + // buffered reader. The call to resp.Body.Close() does not release + // resources. + br.Reset(bytes.NewReader(nil)) + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + _ = conn.Close() + f := strings.SplitN(resp.Status, " ", 2) + return nil, errors.New(f[1]) + } + return conn, nil +} diff --git a/vendor/github.com/gorilla/websocket/server.go b/vendor/github.com/gorilla/websocket/server.go new file mode 100644 index 000000000..02ea01fdc --- /dev/null +++ b/vendor/github.com/gorilla/websocket/server.go @@ -0,0 +1,373 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +// HandshakeError describes an error with the handshake from the peer. +type HandshakeError struct { + message string +} + +func (e HandshakeError) Error() string { return e.message } + +// Upgrader specifies parameters for upgrading an HTTP connection to a +// WebSocket connection. +// +// It is safe to call Upgrader's methods concurrently. +type Upgrader struct { + // HandshakeTimeout specifies the duration for the handshake to complete. + HandshakeTimeout time.Duration + + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer + // size is zero, then buffers allocated by the HTTP server are used. The + // I/O buffer sizes do not limit the size of the messages that can be sent + // or received. + ReadBufferSize, WriteBufferSize int + + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. + WriteBufferPool BufferPool + + // Subprotocols specifies the server's supported protocols in order of + // preference. If this field is not nil, then the Upgrade method negotiates a + // subprotocol by selecting the first match in this list with a protocol + // requested by the client. If there's no match, then no protocol is + // negotiated (the Sec-Websocket-Protocol header is not included in the + // handshake response). + Subprotocols []string + + // Error specifies the function for generating HTTP error responses. If Error + // is nil, then http.Error is used to generate the HTTP response. + Error func(w http.ResponseWriter, r *http.Request, status int, reason error) + + // CheckOrigin returns true if the request Origin header is acceptable. If + // CheckOrigin is nil, then a safe default is used: return false if the + // Origin request header is present and the origin host is not equal to + // request Host header. + // + // A CheckOrigin function should carefully validate the request origin to + // prevent cross-site request forgery. + CheckOrigin func(r *http.Request) bool + + // EnableCompression specify if the server should attempt to negotiate per + // message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool +} + +func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { + err := HandshakeError{reason} + if u.Error != nil { + u.Error(w, r, status, err) + } else { + w.Header().Set("Sec-Websocket-Version", "13") + http.Error(w, http.StatusText(status), status) + } + return nil, err +} + +// checkSameOrigin returns true if the origin is not set or is equal to the request host. +func checkSameOrigin(r *http.Request) bool { + origin := r.Header["Origin"] + if len(origin) == 0 { + return true + } + u, err := url.Parse(origin[0]) + if err != nil { + return false + } + return equalASCIIFold(u.Host, r.Host) +} + +func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { + if u.Subprotocols != nil { + clientProtocols := Subprotocols(r) + for _, clientProtocol := range clientProtocols { + for _, serverProtocol := range u.Subprotocols { + if clientProtocol == serverProtocol { + return clientProtocol + } + } + } + } else if responseHeader != nil { + return responseHeader.Get("Sec-Websocket-Protocol") + } + return "" +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie). To specify +// subprotocols supported by the server, set Upgrader.Subprotocols directly. +// +// If the upgrade fails, then Upgrade replies to the client with an HTTP error +// response. +func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { + const badHandshake = "websocket: the client is not using the websocket protocol: " + + if !tokenListContainsValue(r.Header, "Connection", "upgrade") { + return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header") + } + + if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { + w.Header().Set("Upgrade", "websocket") + return u.returnError(w, r, http.StatusUpgradeRequired, badHandshake+"'websocket' token not found in 'Upgrade' header") + } + + if r.Method != http.MethodGet { + return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") + } + + if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { + return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") + } + + if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { + return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported") + } + + checkOrigin := u.CheckOrigin + if checkOrigin == nil { + checkOrigin = checkSameOrigin + } + if !checkOrigin(r) { + return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin") + } + + challengeKey := r.Header.Get("Sec-Websocket-Key") + if !isValidChallengeKey(challengeKey) { + return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length") + } + + subprotocol := u.selectSubprotocol(r, responseHeader) + + // Negotiate PMCE + var compress bool + if u.EnableCompression { + for _, ext := range parseExtensions(r.Header) { + if ext[""] != "permessage-deflate" { + continue + } + compress = true + break + } + } + + netConn, brw, err := http.NewResponseController(w).Hijack() + if err != nil { + return u.returnError(w, r, http.StatusInternalServerError, + "websocket: hijack: "+err.Error()) + } + + // Close the network connection when returning an error. The variable + // netConn is set to nil before the success return at the end of the + // function. + defer func() { + if netConn != nil { + // It's safe to ignore the error from Close() because this code is + // only executed when returning a more important error to the + // application. + _ = netConn.Close() + } + }() + + var br *bufio.Reader + if u.ReadBufferSize == 0 && brw.Reader.Size() > 256 { + // Use hijacked buffered reader as the connection reader. + br = brw.Reader + } else if brw.Reader.Buffered() > 0 { + // Wrap the network connection to read buffered data in brw.Reader + // before reading from the network connection. This should be rare + // because a client must not send message data before receiving the + // handshake response. + netConn = &brNetConn{br: brw.Reader, Conn: netConn} + } + + buf := brw.Writer.AvailableBuffer() + + var writeBuf []byte + if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { + // Reuse hijacked write buffer as connection buffer. + writeBuf = buf + } + + c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf) + c.subprotocol = subprotocol + + if compress { + c.newCompressionWriter = compressNoContextTakeover + c.newDecompressionReader = decompressNoContextTakeover + } + + // Use larger of hijacked buffer and connection write buffer for header. + p := buf + if len(c.writeBuf) > len(p) { + p = c.writeBuf + } + p = p[:0] + + p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) + p = append(p, computeAcceptKey(challengeKey)...) + p = append(p, "\r\n"...) + if c.subprotocol != "" { + p = append(p, "Sec-WebSocket-Protocol: "...) + p = append(p, c.subprotocol...) + p = append(p, "\r\n"...) + } + if compress { + p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + } + for k, vs := range responseHeader { + if k == "Sec-Websocket-Protocol" { + continue + } + for _, v := range vs { + p = append(p, k...) + p = append(p, ": "...) + for i := 0; i < len(v); i++ { + b := v[i] + if b <= 31 { + // prevent response splitting. + b = ' ' + } + p = append(p, b) + } + p = append(p, "\r\n"...) + } + } + p = append(p, "\r\n"...) + + if u.HandshakeTimeout > 0 { + if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil { + return nil, err + } + } else { + // Clear deadlines set by HTTP server. + if err := netConn.SetDeadline(time.Time{}); err != nil { + return nil, err + } + } + + if _, err = netConn.Write(p); err != nil { + return nil, err + } + if u.HandshakeTimeout > 0 { + if err := netConn.SetWriteDeadline(time.Time{}); err != nil { + return nil, err + } + } + + // Success! Set netConn to nil to stop the deferred function above from + // closing the network connection. + netConn = nil + + return c, nil +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// Deprecated: Use websocket.Upgrader instead. +// +// Upgrade does not perform origin checking. The application is responsible for +// checking the Origin header before calling Upgrade. An example implementation +// of the same origin policy check is: +// +// if req.Header.Get("Origin") != "http://"+req.Host { +// http.Error(w, "Origin not allowed", http.StatusForbidden) +// return +// } +// +// If the endpoint supports subprotocols, then the application is responsible +// for negotiating the protocol used on the connection. Use the Subprotocols() +// function to get the subprotocols requested by the client. Use the +// Sec-Websocket-Protocol response header to specify the subprotocol selected +// by the application. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// negotiated subprotocol (Sec-Websocket-Protocol). +// +// The connection buffers IO to the underlying network connection. The +// readBufSize and writeBufSize parameters specify the size of the buffers to +// use. Messages can be larger than the buffers. +// +// If the request is not a valid WebSocket handshake, then Upgrade returns an +// error of type HandshakeError. Applications should handle this error by +// replying to the client with an HTTP error response. +func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { + u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} + u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { + // don't return errors to maintain backwards compatibility + } + u.CheckOrigin = func(r *http.Request) bool { + // allow all connections by default + return true + } + return u.Upgrade(w, r, responseHeader) +} + +// Subprotocols returns the subprotocols requested by the client in the +// Sec-Websocket-Protocol header. +func Subprotocols(r *http.Request) []string { + h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol")) + if h == "" { + return nil + } + protocols := strings.Split(h, ",") + for i := range protocols { + protocols[i] = strings.TrimSpace(protocols[i]) + } + return protocols +} + +// IsWebSocketUpgrade returns true if the client requested upgrade to the +// WebSocket protocol. +func IsWebSocketUpgrade(r *http.Request) bool { + return tokenListContainsValue(r.Header, "Connection", "upgrade") && + tokenListContainsValue(r.Header, "Upgrade", "websocket") +} + +type brNetConn struct { + br *bufio.Reader + net.Conn +} + +func (b *brNetConn) Read(p []byte) (n int, err error) { + if b.br != nil { + // Limit read to buferred data. + if n := b.br.Buffered(); len(p) > n { + p = p[:n] + } + n, err = b.br.Read(p) + if b.br.Buffered() == 0 { + b.br = nil + } + return n, err + } + return b.Conn.Read(p) +} + +// NetConn returns the underlying connection that is wrapped by b. +func (b *brNetConn) NetConn() net.Conn { + return b.Conn +} + diff --git a/vendor/github.com/gorilla/websocket/util.go b/vendor/github.com/gorilla/websocket/util.go new file mode 100644 index 000000000..31a5dee64 --- /dev/null +++ b/vendor/github.com/gorilla/websocket/util.go @@ -0,0 +1,298 @@ +// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "io" + "net/http" + "strings" + "unicode/utf8" +) + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func computeAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func generateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +// Token octets per RFC 2616. +var isTokenOctet = [256]bool{ + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, +} + +// skipSpace returns a slice of the string s with all leading RFC 2616 linear +// whitespace removed. +func skipSpace(s string) (rest string) { + i := 0 + for ; i < len(s); i++ { + if b := s[i]; b != ' ' && b != '\t' { + break + } + } + return s[i:] +} + +// nextToken returns the leading RFC 2616 token of s and the string following +// the token. +func nextToken(s string) (token, rest string) { + i := 0 + for ; i < len(s); i++ { + if !isTokenOctet[s[i]] { + break + } + } + return s[:i], s[i:] +} + +// nextTokenOrQuoted returns the leading token or quoted string per RFC 2616 +// and the string following the token or quoted string. +func nextTokenOrQuoted(s string) (value string, rest string) { + if !strings.HasPrefix(s, "\"") { + return nextToken(s) + } + s = s[1:] + for i := 0; i < len(s); i++ { + switch s[i] { + case '"': + return s[:i], s[i+1:] + case '\\': + p := make([]byte, len(s)-1) + j := copy(p, s[:i]) + escape := true + for i = i + 1; i < len(s); i++ { + b := s[i] + switch { + case escape: + escape = false + p[j] = b + j++ + case b == '\\': + escape = true + case b == '"': + return string(p[:j]), s[i+1:] + default: + p[j] = b + j++ + } + } + return "", "" + } + } + return "", "" +} + +// equalASCIIFold returns true if s is equal to t with ASCII case folding as +// defined in RFC 4790. +func equalASCIIFold(s, t string) bool { + for s != "" && t != "" { + sr, size := utf8.DecodeRuneInString(s) + s = s[size:] + tr, size := utf8.DecodeRuneInString(t) + t = t[size:] + if sr == tr { + continue + } + if 'A' <= sr && sr <= 'Z' { + sr = sr + 'a' - 'A' + } + if 'A' <= tr && tr <= 'Z' { + tr = tr + 'a' - 'A' + } + if sr != tr { + return false + } + } + return s == t +} + +// tokenListContainsValue returns true if the 1#token header with the given +// name contains a token equal to value with ASCII case folding. +func tokenListContainsValue(header http.Header, name string, value string) bool { +headers: + for _, s := range header[name] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + s = skipSpace(s) + if s != "" && s[0] != ',' { + continue headers + } + if equalASCIIFold(t, value) { + return true + } + if s == "" { + continue headers + } + s = s[1:] + } + } + return false +} + +// parseExtensions parses WebSocket extensions from a header. +func parseExtensions(header http.Header) []map[string]string { + // From RFC 6455: + // + // Sec-WebSocket-Extensions = extension-list + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ;When using the quoted-string syntax variant, the value + // ;after quoted-string unescaping MUST conform to the + // ;'token' ABNF. + + var result []map[string]string +headers: + for _, s := range header["Sec-Websocket-Extensions"] { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + continue headers + } + ext := map[string]string{"": t} + for { + s = skipSpace(s) + if !strings.HasPrefix(s, ";") { + break + } + var k string + k, s = nextToken(skipSpace(s[1:])) + if k == "" { + continue headers + } + s = skipSpace(s) + var v string + if strings.HasPrefix(s, "=") { + v, s = nextTokenOrQuoted(skipSpace(s[1:])) + s = skipSpace(s) + } + if s != "" && s[0] != ',' && s[0] != ';' { + continue headers + } + ext[k] = v + } + if s != "" && s[0] != ',' { + continue headers + } + result = append(result, ext) + if s == "" { + continue headers + } + s = s[1:] + } + } + return result +} + +// isValidChallengeKey checks if the argument meets RFC6455 specification. +func isValidChallengeKey(s string) bool { + // From RFC6455: + // + // A |Sec-WebSocket-Key| header field with a base64-encoded (see + // Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in + // length. + + if s == "" { + return false + } + decoded, err := base64.StdEncoding.DecodeString(s) + return err == nil && len(decoded) == 16 +} diff --git a/vendor/github.com/moby/spdystream/CONTRIBUTING.md b/vendor/github.com/moby/spdystream/CONTRIBUTING.md new file mode 100644 index 000000000..d4eddcc53 --- /dev/null +++ b/vendor/github.com/moby/spdystream/CONTRIBUTING.md @@ -0,0 +1,13 @@ +# Contributing to SpdyStream + +Want to hack on spdystream? Awesome! Here are instructions to get you +started. + +SpdyStream is a part of the [Docker](https://docker.io) project, and follows +the same rules and principles. If you're already familiar with the way +Docker does things, you'll feel right at home. + +Otherwise, go read +[Docker's contributions guidelines](https://github.com/dotcloud/docker/blob/master/CONTRIBUTING.md). + +Happy hacking! diff --git a/vendor/github.com/moby/spdystream/LICENSE b/vendor/github.com/moby/spdystream/LICENSE new file mode 100644 index 000000000..d64569567 --- /dev/null +++ b/vendor/github.com/moby/spdystream/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/moby/spdystream/MAINTAINERS b/vendor/github.com/moby/spdystream/MAINTAINERS new file mode 100644 index 000000000..26e5ec828 --- /dev/null +++ b/vendor/github.com/moby/spdystream/MAINTAINERS @@ -0,0 +1,40 @@ +# Spdystream maintainers file +# +# This file describes who runs the moby/spdystream project and how. +# This is a living document - if you see something out of date or missing, speak up! +# +# It is structured to be consumable by both humans and programs. +# To extract its contents programmatically, use any TOML-compliant parser. +# +# This file is compiled into the MAINTAINERS file in docker/opensource. +# +[Org] + [Org."Core maintainers"] + people = [ + "adisky", + "dims", + "dmcgowan", + ] + +[people] + +# A reference list of all people associated with the project. +# All other sections should refer to people by their canonical key +# in the people section. + + # ADD YOURSELF HERE IN ALPHABETICAL ORDER + + [people.adisky] + Name = "Aditi Sharma" + Email = "adi.sky17@gmail.com" + GitHub = "adisky" + + [people.dims] + Name = "Davanum Srinivas" + Email = "davanum@gmail.com" + GitHub = "dims" + + [people.dmcgowan] + Name = "Derek McGowan" + Email = "derek@mcg.dev" + GitHub = "dmcgowan" diff --git a/vendor/github.com/moby/spdystream/NOTICE b/vendor/github.com/moby/spdystream/NOTICE new file mode 100644 index 000000000..24e2e2aa3 --- /dev/null +++ b/vendor/github.com/moby/spdystream/NOTICE @@ -0,0 +1,17 @@ +SpdyStream +Copyright 2014-2021 Docker Inc. + +This product includes software developed at +Docker Inc. (https://www.docker.com/). + +SPDY implementation (spdy/) + +The spdy directory contains code derived from the Go project (golang.org/x/net). + +Copyright 2009-2013 The Go Authors. +Licensed under the BSD 3-Clause License. + +Modifications Copyright 2014-2021 Docker Inc. + +The BSD license text and Go patent grant are included in +spdy/LICENSE and spdy/PATENTS. diff --git a/vendor/github.com/moby/spdystream/README.md b/vendor/github.com/moby/spdystream/README.md new file mode 100644 index 000000000..b84e98343 --- /dev/null +++ b/vendor/github.com/moby/spdystream/README.md @@ -0,0 +1,77 @@ +# SpdyStream + +A multiplexed stream library using spdy + +## Usage + +Client example (connecting to mirroring server without auth) + +```go +package main + +import ( + "fmt" + "github.com/moby/spdystream" + "net" + "net/http" +) + +func main() { + conn, err := net.Dial("tcp", "localhost:8080") + if err != nil { + panic(err) + } + spdyConn, err := spdystream.NewConnection(conn, false) + if err != nil { + panic(err) + } + go spdyConn.Serve(spdystream.NoOpStreamHandler) + stream, err := spdyConn.CreateStream(http.Header{}, nil, false) + if err != nil { + panic(err) + } + + stream.Wait() + + fmt.Fprint(stream, "Writing to stream") + + buf := make([]byte, 25) + stream.Read(buf) + fmt.Println(string(buf)) + + stream.Close() +} +``` + +Server example (mirroring server without auth) + +```go +package main + +import ( + "github.com/moby/spdystream" + "net" +) + +func main() { + listener, err := net.Listen("tcp", "localhost:8080") + if err != nil { + panic(err) + } + for { + conn, err := listener.Accept() + if err != nil { + panic(err) + } + spdyConn, err := spdystream.NewConnection(conn, true) + if err != nil { + panic(err) + } + go spdyConn.Serve(spdystream.MirrorStreamHandler) + } +} +``` + +## Copyright and license + +Copyright 2013-2021 Docker, inc. Released under the [Apache 2.0 license](LICENSE). diff --git a/vendor/github.com/moby/spdystream/connection.go b/vendor/github.com/moby/spdystream/connection.go new file mode 100644 index 000000000..69ce4777e --- /dev/null +++ b/vendor/github.com/moby/spdystream/connection.go @@ -0,0 +1,1000 @@ +/* + Copyright 2014-2021 Docker Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package spdystream + +import ( + "errors" + "fmt" + "io" + "net" + "net/http" + "sync" + "time" + + "github.com/moby/spdystream/spdy" +) + +var ( + ErrInvalidStreamId = errors.New("Invalid stream id") + ErrTimeout = errors.New("Timeout occurred") + ErrReset = errors.New("Stream reset") + ErrWriteClosedStream = errors.New("Write on closed stream") +) + +const ( + FRAME_WORKERS = 5 + QUEUE_SIZE = 50 +) + +type StreamHandler func(stream *Stream) + +type AuthHandler func(header http.Header, slot uint8, parent uint32) bool + +type idleAwareFramer struct { + f *spdy.Framer + conn *Connection + writeLock sync.Mutex + resetChan chan struct{} + setTimeoutLock sync.Mutex + setTimeoutChan chan time.Duration + timeout time.Duration +} + +func newIdleAwareFramer(framer *spdy.Framer) *idleAwareFramer { + iaf := &idleAwareFramer{ + f: framer, + resetChan: make(chan struct{}, 2), + // setTimeoutChan needs to be buffered to avoid deadlocks when calling setIdleTimeout at about + // the same time the connection is being closed + setTimeoutChan: make(chan time.Duration, 1), + } + return iaf +} + +func (i *idleAwareFramer) monitor() { + var ( + timer *time.Timer + expired <-chan time.Time + resetChan = i.resetChan + setTimeoutChan = i.setTimeoutChan + ) +Loop: + for { + select { + case timeout := <-i.setTimeoutChan: + i.timeout = timeout + if timeout == 0 { + if timer != nil { + timer.Stop() + } + } else { + if timer == nil { + timer = time.NewTimer(timeout) + expired = timer.C + } else { + timer.Reset(timeout) + } + } + case <-resetChan: + if timer != nil && i.timeout > 0 { + timer.Reset(i.timeout) + } + case <-expired: + i.conn.streamCond.L.Lock() + streams := i.conn.streams + i.conn.streams = make(map[spdy.StreamId]*Stream) + i.conn.streamCond.Broadcast() + i.conn.streamCond.L.Unlock() + go func() { + for _, stream := range streams { + stream.resetStream() + } + i.conn.Close() + }() + case <-i.conn.closeChan: + if timer != nil { + timer.Stop() + } + + // Start a goroutine to drain resetChan. This is needed because we've seen + // some unit tests with large numbers of goroutines get into a situation + // where resetChan fills up, at least 1 call to Write() is still trying to + // send to resetChan, the connection gets closed, and this case statement + // attempts to grab the write lock that Write() already has, causing a + // deadlock. + // + // See https://github.com/moby/spdystream/issues/49 for more details. + go func() { + for range resetChan { + } + }() + + go func() { + for range setTimeoutChan { + } + }() + + i.writeLock.Lock() + close(resetChan) + i.resetChan = nil + i.writeLock.Unlock() + + i.setTimeoutLock.Lock() + close(i.setTimeoutChan) + i.setTimeoutChan = nil + i.setTimeoutLock.Unlock() + + break Loop + } + } + + // Drain resetChan + for range resetChan { + } +} + +func (i *idleAwareFramer) WriteFrame(frame spdy.Frame) error { + i.writeLock.Lock() + defer i.writeLock.Unlock() + if i.resetChan == nil { + return io.EOF + } + err := i.f.WriteFrame(frame) + if err != nil { + return err + } + + i.resetChan <- struct{}{} + + return nil +} + +func (i *idleAwareFramer) ReadFrame() (spdy.Frame, error) { + frame, err := i.f.ReadFrame() + if err != nil { + return nil, err + } + + // resetChan should never be closed since it is only closed + // when the connection has closed its closeChan. This closure + // only occurs after all Reads have finished + // TODO (dmcgowan): refactor relationship into connection + i.resetChan <- struct{}{} + + return frame, nil +} + +func (i *idleAwareFramer) setIdleTimeout(timeout time.Duration) { + i.setTimeoutLock.Lock() + defer i.setTimeoutLock.Unlock() + + if i.setTimeoutChan == nil { + return + } + + i.setTimeoutChan <- timeout +} + +type Connection struct { + conn net.Conn + framer *idleAwareFramer + + closeChan chan bool + goneAway bool + lastStreamChan chan<- *Stream + goAwayTimeout time.Duration + closeTimeout time.Duration + + streamLock *sync.RWMutex + streamCond *sync.Cond + streams map[spdy.StreamId]*Stream + + nextIdLock sync.Mutex + receiveIdLock sync.Mutex + nextStreamId spdy.StreamId + receivedStreamId spdy.StreamId + + // pingLock protects pingChans and pingId + pingLock sync.Mutex + pingId uint32 + pingChans map[uint32]chan error + + shutdownLock sync.Mutex + shutdownChan chan error + hasShutdown bool + + // for testing https://github.com/moby/spdystream/pull/56 + dataFrameHandler func(*spdy.DataFrame) error +} + +// NewConnection creates a new spdy connection from an existing +// network connection. +func NewConnection(conn net.Conn, server bool) (*Connection, error) { + return NewConnectionWithOptions(conn, server) +} + +// NewConnectionWithOptions creates a new spdy connection and applies frame +// parsing limits via options. +func NewConnectionWithOptions(conn net.Conn, server bool, opts ...spdy.FramerOption) (*Connection, error) { + framer, framerErr := spdy.NewFramerWithOptions(conn, conn, opts...) + if framerErr != nil { + return nil, framerErr + } + idleAwareFramer := newIdleAwareFramer(framer) + var sid spdy.StreamId + var rid spdy.StreamId + var pid uint32 + if server { + sid = 2 + rid = 1 + pid = 2 + } else { + sid = 1 + rid = 2 + pid = 1 + } + + streamLock := new(sync.RWMutex) + streamCond := sync.NewCond(streamLock) + + session := &Connection{ + conn: conn, + framer: idleAwareFramer, + + closeChan: make(chan bool), + goAwayTimeout: time.Duration(0), + closeTimeout: time.Duration(0), + + streamLock: streamLock, + streamCond: streamCond, + streams: make(map[spdy.StreamId]*Stream), + nextStreamId: sid, + receivedStreamId: rid, + + pingId: pid, + pingChans: make(map[uint32]chan error), + + shutdownChan: make(chan error), + } + session.dataFrameHandler = session.handleDataFrame + idleAwareFramer.conn = session + go idleAwareFramer.monitor() + + return session, nil +} + +// Ping sends a ping frame across the connection and +// returns the response time +func (s *Connection) Ping() (time.Duration, error) { + pid := s.pingId + s.pingLock.Lock() + if s.pingId > 0x7ffffffe { + s.pingId = s.pingId - 0x7ffffffe + } else { + s.pingId = s.pingId + 2 + } + pingChan := make(chan error) + s.pingChans[pid] = pingChan + s.pingLock.Unlock() + defer func() { + s.pingLock.Lock() + delete(s.pingChans, pid) + s.pingLock.Unlock() + }() + + frame := &spdy.PingFrame{Id: pid} + startTime := time.Now() + writeErr := s.framer.WriteFrame(frame) + if writeErr != nil { + return time.Duration(0), writeErr + } + select { + case <-s.closeChan: + return time.Duration(0), errors.New("connection closed") + case err, ok := <-pingChan: + if ok && err != nil { + return time.Duration(0), err + } + break + } + return time.Since(startTime), nil +} + +// Serve handles frames sent from the server, including reply frames +// which are needed to fully initiate connections. Both clients and servers +// should call Serve in a separate goroutine before creating streams. +func (s *Connection) Serve(newHandler StreamHandler) { + // use a WaitGroup to wait for all frames to be drained after receiving + // go-away. + var wg sync.WaitGroup + + // Parition queues to ensure stream frames are handled + // by the same worker, ensuring order is maintained + frameQueues := make([]*PriorityFrameQueue, FRAME_WORKERS) + for i := 0; i < FRAME_WORKERS; i++ { + frameQueues[i] = NewPriorityFrameQueue(QUEUE_SIZE) + + // Ensure frame queue is drained when connection is closed + go func(frameQueue *PriorityFrameQueue) { + <-s.closeChan + frameQueue.Drain() + }(frameQueues[i]) + + wg.Add(1) + go func(frameQueue *PriorityFrameQueue) { + // let the WaitGroup know this worker is done + defer wg.Done() + + s.frameHandler(frameQueue, newHandler) + }(frameQueues[i]) + } + + var ( + partitionRoundRobin int + goAwayFrame *spdy.GoAwayFrame + ) +Loop: + for { + readFrame, err := s.framer.ReadFrame() + if err != nil { + if err != io.EOF { + debugMessage("frame read error: %s", err) + } else { + debugMessage("(%p) EOF received", s) + } + if spdyErr, ok := err.(*spdy.Error); ok && spdyErr.Err == spdy.InvalidControlFrame { + _ = s.conn.Close() + } + break + } + var priority uint8 + var partition int + switch frame := readFrame.(type) { + case *spdy.SynStreamFrame: + if s.checkStreamFrame(frame) { + priority = frame.Priority + partition = int(frame.StreamId % FRAME_WORKERS) + debugMessage("(%p) Add stream frame: %d ", s, frame.StreamId) + s.addStreamFrame(frame) + } else { + debugMessage("(%p) Rejected stream frame: %d ", s, frame.StreamId) + continue + } + case *spdy.SynReplyFrame: + priority = s.getStreamPriority(frame.StreamId) + partition = int(frame.StreamId % FRAME_WORKERS) + case *spdy.DataFrame: + priority = s.getStreamPriority(frame.StreamId) + partition = int(frame.StreamId % FRAME_WORKERS) + case *spdy.RstStreamFrame: + priority = s.getStreamPriority(frame.StreamId) + partition = int(frame.StreamId % FRAME_WORKERS) + case *spdy.HeadersFrame: + priority = s.getStreamPriority(frame.StreamId) + partition = int(frame.StreamId % FRAME_WORKERS) + case *spdy.PingFrame: + priority = 0 + partition = partitionRoundRobin + partitionRoundRobin = (partitionRoundRobin + 1) % FRAME_WORKERS + case *spdy.GoAwayFrame: + // hold on to the go away frame and exit the loop + goAwayFrame = frame + break Loop + default: + priority = 7 + partition = partitionRoundRobin + partitionRoundRobin = (partitionRoundRobin + 1) % FRAME_WORKERS + } + frameQueues[partition].Push(readFrame, priority) + } + close(s.closeChan) + + // wait for all frame handler workers to indicate they've drained their queues + // before handling the go away frame + wg.Wait() + + if goAwayFrame != nil { + s.handleGoAwayFrame(goAwayFrame) + } + + // now it's safe to close remote channels and empty s.streams + s.streamCond.L.Lock() + // notify streams that they're now closed, which will + // unblock any stream Read() calls + for _, stream := range s.streams { + stream.closeRemoteChannels() + } + s.streams = make(map[spdy.StreamId]*Stream) + s.streamCond.Broadcast() + s.streamCond.L.Unlock() +} + +func (s *Connection) frameHandler(frameQueue *PriorityFrameQueue, newHandler StreamHandler) { + for { + popFrame := frameQueue.Pop() + if popFrame == nil { + return + } + + var frameErr error + switch frame := popFrame.(type) { + case *spdy.SynStreamFrame: + frameErr = s.handleStreamFrame(frame, newHandler) + case *spdy.SynReplyFrame: + frameErr = s.handleReplyFrame(frame) + case *spdy.DataFrame: + frameErr = s.dataFrameHandler(frame) + case *spdy.RstStreamFrame: + frameErr = s.handleResetFrame(frame) + case *spdy.HeadersFrame: + frameErr = s.handleHeaderFrame(frame) + case *spdy.PingFrame: + frameErr = s.handlePingFrame(frame) + case *spdy.GoAwayFrame: + frameErr = s.handleGoAwayFrame(frame) + default: + frameErr = fmt.Errorf("unhandled frame type: %T", frame) + } + + if frameErr != nil { + debugMessage("frame handling error: %s", frameErr) + } + } +} + +func (s *Connection) getStreamPriority(streamId spdy.StreamId) uint8 { + stream, streamOk := s.getStream(streamId) + if !streamOk { + return 7 + } + return stream.priority +} + +func (s *Connection) addStreamFrame(frame *spdy.SynStreamFrame) { + var parent *Stream + if frame.AssociatedToStreamId != spdy.StreamId(0) { + parent, _ = s.getStream(frame.AssociatedToStreamId) + } + + stream := &Stream{ + streamId: frame.StreamId, + parent: parent, + conn: s, + startChan: make(chan error), + headers: frame.Headers, + finished: (frame.CFHeader.Flags & spdy.ControlFlagUnidirectional) != 0x00, + replyCond: sync.NewCond(new(sync.Mutex)), + dataChan: make(chan []byte), + headerChan: make(chan http.Header), + closeChan: make(chan bool), + priority: frame.Priority, + } + if frame.CFHeader.Flags&spdy.ControlFlagFin != 0x00 { + stream.closeRemoteChannels() + } + + s.addStream(stream) +} + +// checkStreamFrame checks to see if a stream frame is allowed. +// If the stream is invalid, then a reset frame with protocol error +// will be returned. +func (s *Connection) checkStreamFrame(frame *spdy.SynStreamFrame) bool { + s.receiveIdLock.Lock() + defer s.receiveIdLock.Unlock() + if s.goneAway { + return false + } + validationErr := s.validateStreamId(frame.StreamId) + if validationErr != nil { + go func() { + resetErr := s.sendResetFrame(spdy.ProtocolError, frame.StreamId) + if resetErr != nil { + debugMessage("reset error: %s", resetErr) + } + }() + return false + } + return true +} + +func (s *Connection) handleStreamFrame(frame *spdy.SynStreamFrame, newHandler StreamHandler) error { + stream, ok := s.getStream(frame.StreamId) + if !ok { + return fmt.Errorf("Missing stream: %d", frame.StreamId) + } + + newHandler(stream) + + return nil +} + +func (s *Connection) handleReplyFrame(frame *spdy.SynReplyFrame) error { + debugMessage("(%p) Reply frame received for %d", s, frame.StreamId) + stream, streamOk := s.getStream(frame.StreamId) + if !streamOk { + debugMessage("Reply frame gone away for %d", frame.StreamId) + // Stream has already gone away + return nil + } + if stream.replied { + // Stream has already received reply + return nil + } + stream.replied = true + + // TODO Check for error + if (frame.CFHeader.Flags & spdy.ControlFlagFin) != 0x00 { + s.remoteStreamFinish(stream) + } + + close(stream.startChan) + + return nil +} + +func (s *Connection) handleResetFrame(frame *spdy.RstStreamFrame) error { + stream, streamOk := s.getStream(frame.StreamId) + if !streamOk { + // Stream has already been removed + return nil + } + s.removeStream(stream) + stream.closeRemoteChannels() + + if !stream.replied { + stream.replied = true + stream.startChan <- ErrReset + close(stream.startChan) + } + + stream.finishLock.Lock() + stream.finished = true + stream.finishLock.Unlock() + + return nil +} + +func (s *Connection) handleHeaderFrame(frame *spdy.HeadersFrame) error { + stream, streamOk := s.getStream(frame.StreamId) + if !streamOk { + // Stream has already gone away + return nil + } + if !stream.replied { + // No reply received...Protocol error? + return nil + } + + // TODO limit headers while not blocking (use buffered chan or goroutine?) + select { + case <-stream.closeChan: + return nil + case stream.headerChan <- frame.Headers: + } + + if (frame.CFHeader.Flags & spdy.ControlFlagFin) != 0x00 { + s.remoteStreamFinish(stream) + } + + return nil +} + +func (s *Connection) handleDataFrame(frame *spdy.DataFrame) error { + debugMessage("(%p) Data frame received for %d", s, frame.StreamId) + stream, streamOk := s.getStream(frame.StreamId) + if !streamOk { + debugMessage("(%p) Data frame gone away for %d", s, frame.StreamId) + // Stream has already gone away + return nil + } + if !stream.replied { + debugMessage("(%p) Data frame not replied %d", s, frame.StreamId) + // No reply received...Protocol error? + return nil + } + + debugMessage("(%p) (%d) Data frame handling", stream, stream.streamId) + if len(frame.Data) > 0 { + stream.dataLock.RLock() + select { + case <-stream.closeChan: + debugMessage("(%p) (%d) Data frame not sent (stream shut down)", stream, stream.streamId) + case stream.dataChan <- frame.Data: + debugMessage("(%p) (%d) Data frame sent", stream, stream.streamId) + } + stream.dataLock.RUnlock() + } + if (frame.Flags & spdy.DataFlagFin) != 0x00 { + s.remoteStreamFinish(stream) + } + return nil +} + +func (s *Connection) handlePingFrame(frame *spdy.PingFrame) error { + s.pingLock.Lock() + pingId := s.pingId + pingChan, pingOk := s.pingChans[frame.Id] + s.pingLock.Unlock() + + if pingId&0x01 != frame.Id&0x01 { + return s.framer.WriteFrame(frame) + } + if pingOk { + close(pingChan) + } + return nil +} + +func (s *Connection) handleGoAwayFrame(frame *spdy.GoAwayFrame) error { + debugMessage("(%p) Go away received", s) + s.receiveIdLock.Lock() + if s.goneAway { + s.receiveIdLock.Unlock() + return nil + } + s.goneAway = true + s.receiveIdLock.Unlock() + + if s.lastStreamChan != nil { + stream, _ := s.getStream(frame.LastGoodStreamId) + go func() { + s.lastStreamChan <- stream + }() + } + + // Do not block frame handler waiting for closure + go s.shutdown(s.goAwayTimeout) + + return nil +} + +func (s *Connection) remoteStreamFinish(stream *Stream) { + stream.closeRemoteChannels() + + stream.finishLock.Lock() + if stream.finished { + // Stream is fully closed, cleanup + s.removeStream(stream) + } + stream.finishLock.Unlock() +} + +// CreateStream creates a new spdy stream using the parameters for +// creating the stream frame. The stream frame will be sent upon +// calling this function, however this function does not wait for +// the reply frame. If waiting for the reply is desired, use +// the stream Wait or WaitTimeout function on the stream returned +// by this function. +func (s *Connection) CreateStream(headers http.Header, parent *Stream, fin bool) (*Stream, error) { + // MUST synchronize stream creation (all the way to writing the frame) + // as stream IDs **MUST** increase monotonically. + s.nextIdLock.Lock() + defer s.nextIdLock.Unlock() + + streamId := s.getNextStreamId() + if streamId == 0 { + return nil, fmt.Errorf("Unable to get new stream id") + } + + stream := &Stream{ + streamId: streamId, + parent: parent, + conn: s, + startChan: make(chan error), + headers: headers, + dataChan: make(chan []byte), + headerChan: make(chan http.Header), + closeChan: make(chan bool), + } + + debugMessage("(%p) (%p) Create stream", s, stream) + + s.addStream(stream) + + return stream, s.sendStream(stream, fin) +} + +func (s *Connection) shutdown(closeTimeout time.Duration) { + // TODO Ensure this isn't called multiple times + s.shutdownLock.Lock() + if s.hasShutdown { + s.shutdownLock.Unlock() + return + } + s.hasShutdown = true + s.shutdownLock.Unlock() + + var timeout <-chan time.Time + if closeTimeout > time.Duration(0) { + timer := time.NewTimer(closeTimeout) + defer timer.Stop() + timeout = timer.C + } + streamsClosed := make(chan bool) + + go func() { + s.streamCond.L.Lock() + for len(s.streams) > 0 { + debugMessage("Streams opened: %d, %#v", len(s.streams), s.streams) + s.streamCond.Wait() + } + s.streamCond.L.Unlock() + close(streamsClosed) + }() + + var err error + select { + case <-streamsClosed: + // No active streams, close should be safe + err = s.conn.Close() + case <-timeout: + // Force ungraceful close + err = s.conn.Close() + // Wait for cleanup to clear active streams + <-streamsClosed + } + + if err != nil { + // default to 1 second + duration := time.Second + // if a closeTimeout was given, use that, clipped to 1s-10m + if closeTimeout > time.Second { + duration = closeTimeout + } + if duration > 10*time.Minute { + duration = 10 * time.Minute + } + timer := time.NewTimer(duration) + defer timer.Stop() + select { + case s.shutdownChan <- err: + // error was handled + case <-timer.C: + debugMessage("Unhandled close error after %s: %s", duration, err) + } + } + close(s.shutdownChan) +} + +// Closes spdy connection by sending GoAway frame and initiating shutdown +func (s *Connection) Close() error { + s.receiveIdLock.Lock() + if s.goneAway { + s.receiveIdLock.Unlock() + return nil + } + s.goneAway = true + s.receiveIdLock.Unlock() + + var lastStreamId spdy.StreamId + if s.receivedStreamId > 2 { + lastStreamId = s.receivedStreamId - 2 + } + + goAwayFrame := &spdy.GoAwayFrame{ + LastGoodStreamId: lastStreamId, + Status: spdy.GoAwayOK, + } + + err := s.framer.WriteFrame(goAwayFrame) + go s.shutdown(s.closeTimeout) + if err != nil { + return err + } + + return nil +} + +// CloseWait closes the connection and waits for shutdown +// to finish. Note the underlying network Connection +// is not closed until the end of shutdown. +func (s *Connection) CloseWait() error { + closeErr := s.Close() + if closeErr != nil { + return closeErr + } + shutdownErr, ok := <-s.shutdownChan + if ok { + return shutdownErr + } + return nil +} + +// Wait waits for the connection to finish shutdown or for +// the wait timeout duration to expire. This needs to be +// called either after Close has been called or the GOAWAYFRAME +// has been received. If the wait timeout is 0, this function +// will block until shutdown finishes. If wait is never called +// and a shutdown error occurs, that error will be logged as an +// unhandled error. +func (s *Connection) Wait(waitTimeout time.Duration) error { + var timeout <-chan time.Time + if waitTimeout > time.Duration(0) { + timer := time.NewTimer(waitTimeout) + defer timer.Stop() + timeout = timer.C + } + + select { + case err, ok := <-s.shutdownChan: + if ok { + return err + } + case <-timeout: + return ErrTimeout + } + return nil +} + +// NotifyClose registers a channel to be called when the remote +// peer inidicates connection closure. The last stream to be +// received by the remote will be sent on the channel. The notify +// timeout will determine the duration between go away received +// and the connection being closed. +func (s *Connection) NotifyClose(c chan<- *Stream, timeout time.Duration) { + s.goAwayTimeout = timeout + s.lastStreamChan = c +} + +// SetCloseTimeout sets the amount of time close will wait for +// streams to finish before terminating the underlying network +// connection. Setting the timeout to 0 will cause close to +// wait forever, which is the default. +func (s *Connection) SetCloseTimeout(timeout time.Duration) { + s.closeTimeout = timeout +} + +// SetIdleTimeout sets the amount of time the connection may sit idle before +// it is forcefully terminated. +func (s *Connection) SetIdleTimeout(timeout time.Duration) { + s.framer.setIdleTimeout(timeout) +} + +func (s *Connection) sendHeaders(headers http.Header, stream *Stream, fin bool) error { + var flags spdy.ControlFlags + if fin { + flags = spdy.ControlFlagFin + } + + headerFrame := &spdy.HeadersFrame{ + StreamId: stream.streamId, + Headers: headers, + CFHeader: spdy.ControlFrameHeader{Flags: flags}, + } + + return s.framer.WriteFrame(headerFrame) +} + +func (s *Connection) sendReply(headers http.Header, stream *Stream, fin bool) error { + var flags spdy.ControlFlags + if fin { + flags = spdy.ControlFlagFin + } + + replyFrame := &spdy.SynReplyFrame{ + StreamId: stream.streamId, + Headers: headers, + CFHeader: spdy.ControlFrameHeader{Flags: flags}, + } + + return s.framer.WriteFrame(replyFrame) +} + +func (s *Connection) sendResetFrame(status spdy.RstStreamStatus, streamId spdy.StreamId) error { + resetFrame := &spdy.RstStreamFrame{ + StreamId: streamId, + Status: status, + } + + return s.framer.WriteFrame(resetFrame) +} + +func (s *Connection) sendReset(status spdy.RstStreamStatus, stream *Stream) error { + return s.sendResetFrame(status, stream.streamId) +} + +func (s *Connection) sendStream(stream *Stream, fin bool) error { + var flags spdy.ControlFlags + if fin { + flags = spdy.ControlFlagFin + stream.finished = true + } + + var parentId spdy.StreamId + if stream.parent != nil { + parentId = stream.parent.streamId + } + + streamFrame := &spdy.SynStreamFrame{ + StreamId: spdy.StreamId(stream.streamId), + AssociatedToStreamId: spdy.StreamId(parentId), + Headers: stream.headers, + CFHeader: spdy.ControlFrameHeader{Flags: flags}, + } + + return s.framer.WriteFrame(streamFrame) +} + +// getNextStreamId returns the next sequential id +// every call should produce a unique value or an error +func (s *Connection) getNextStreamId() spdy.StreamId { + sid := s.nextStreamId + if sid > 0x7fffffff { + return 0 + } + s.nextStreamId = s.nextStreamId + 2 + return sid +} + +// PeekNextStreamId returns the next sequential id and keeps the next id untouched +func (s *Connection) PeekNextStreamId() spdy.StreamId { + sid := s.nextStreamId + return sid +} + +func (s *Connection) validateStreamId(rid spdy.StreamId) error { + if rid > 0x7fffffff || rid < s.receivedStreamId { + return ErrInvalidStreamId + } + s.receivedStreamId = rid + 2 + return nil +} + +func (s *Connection) addStream(stream *Stream) { + s.streamCond.L.Lock() + s.streams[stream.streamId] = stream + debugMessage("(%p) (%p) Stream added, broadcasting: %d", s, stream, stream.streamId) + s.streamCond.Broadcast() + s.streamCond.L.Unlock() +} + +func (s *Connection) removeStream(stream *Stream) { + s.streamCond.L.Lock() + delete(s.streams, stream.streamId) + debugMessage("(%p) (%p) Stream removed, broadcasting: %d", s, stream, stream.streamId) + s.streamCond.Broadcast() + s.streamCond.L.Unlock() +} + +func (s *Connection) getStream(streamId spdy.StreamId) (stream *Stream, ok bool) { + s.streamLock.RLock() + stream, ok = s.streams[streamId] + s.streamLock.RUnlock() + return +} + +// FindStream looks up the given stream id and either waits for the +// stream to be found or returns nil if the stream id is no longer +// valid. +func (s *Connection) FindStream(streamId uint32) *Stream { + var stream *Stream + var ok bool + s.streamCond.L.Lock() + stream, ok = s.streams[spdy.StreamId(streamId)] + debugMessage("(%p) Found stream %d? %t", s, spdy.StreamId(streamId), ok) + for !ok && streamId >= uint32(s.receivedStreamId) { + s.streamCond.Wait() + stream, ok = s.streams[spdy.StreamId(streamId)] + } + s.streamCond.L.Unlock() + return stream +} + +func (s *Connection) CloseChan() <-chan bool { + return s.closeChan +} diff --git a/vendor/github.com/moby/spdystream/handlers.go b/vendor/github.com/moby/spdystream/handlers.go new file mode 100644 index 000000000..d68f61f81 --- /dev/null +++ b/vendor/github.com/moby/spdystream/handlers.go @@ -0,0 +1,52 @@ +/* + Copyright 2014-2021 Docker Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package spdystream + +import ( + "io" + "net/http" +) + +// MirrorStreamHandler mirrors all streams. +func MirrorStreamHandler(stream *Stream) { + replyErr := stream.SendReply(http.Header{}, false) + if replyErr != nil { + return + } + + go func() { + io.Copy(stream, stream) + stream.Close() + }() + go func() { + for { + header, receiveErr := stream.ReceiveHeader() + if receiveErr != nil { + return + } + sendErr := stream.SendHeader(header, false) + if sendErr != nil { + return + } + } + }() +} + +// NoopStreamHandler does nothing when stream connects. +func NoOpStreamHandler(stream *Stream) { + stream.SendReply(http.Header{}, false) +} diff --git a/vendor/github.com/moby/spdystream/priority.go b/vendor/github.com/moby/spdystream/priority.go new file mode 100644 index 000000000..d8eb3516c --- /dev/null +++ b/vendor/github.com/moby/spdystream/priority.go @@ -0,0 +1,114 @@ +/* + Copyright 2014-2021 Docker Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package spdystream + +import ( + "container/heap" + "sync" + + "github.com/moby/spdystream/spdy" +) + +type prioritizedFrame struct { + frame spdy.Frame + priority uint8 + insertId uint64 +} + +type frameQueue []*prioritizedFrame + +func (fq frameQueue) Len() int { + return len(fq) +} + +func (fq frameQueue) Less(i, j int) bool { + if fq[i].priority == fq[j].priority { + return fq[i].insertId < fq[j].insertId + } + return fq[i].priority < fq[j].priority +} + +func (fq frameQueue) Swap(i, j int) { + fq[i], fq[j] = fq[j], fq[i] +} + +func (fq *frameQueue) Push(x interface{}) { + *fq = append(*fq, x.(*prioritizedFrame)) +} + +func (fq *frameQueue) Pop() interface{} { + old := *fq + n := len(old) + *fq = old[0 : n-1] + return old[n-1] +} + +type PriorityFrameQueue struct { + queue *frameQueue + c *sync.Cond + size int + nextInsertId uint64 + drain bool +} + +func NewPriorityFrameQueue(size int) *PriorityFrameQueue { + queue := make(frameQueue, 0, size) + heap.Init(&queue) + + return &PriorityFrameQueue{ + queue: &queue, + size: size, + c: sync.NewCond(&sync.Mutex{}), + } +} + +func (q *PriorityFrameQueue) Push(frame spdy.Frame, priority uint8) { + q.c.L.Lock() + defer q.c.L.Unlock() + for q.queue.Len() >= q.size { + q.c.Wait() + } + pFrame := &prioritizedFrame{ + frame: frame, + priority: priority, + insertId: q.nextInsertId, + } + q.nextInsertId = q.nextInsertId + 1 + heap.Push(q.queue, pFrame) + q.c.Signal() +} + +func (q *PriorityFrameQueue) Pop() spdy.Frame { + q.c.L.Lock() + defer q.c.L.Unlock() + for q.queue.Len() == 0 { + if q.drain { + return nil + } + q.c.Wait() + } + frame := heap.Pop(q.queue).(*prioritizedFrame).frame + q.c.Signal() + return frame +} + +func (q *PriorityFrameQueue) Drain() { + q.c.L.Lock() + defer q.c.L.Unlock() + q.drain = true + q.c.Broadcast() +} diff --git a/vendor/github.com/moby/spdystream/spdy/LICENSE b/vendor/github.com/moby/spdystream/spdy/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/vendor/github.com/moby/spdystream/spdy/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/moby/spdystream/spdy/PATENTS b/vendor/github.com/moby/spdystream/spdy/PATENTS new file mode 100644 index 000000000..733099041 --- /dev/null +++ b/vendor/github.com/moby/spdystream/spdy/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/github.com/moby/spdystream/spdy/dictionary.go b/vendor/github.com/moby/spdystream/spdy/dictionary.go new file mode 100644 index 000000000..5a5ff0e14 --- /dev/null +++ b/vendor/github.com/moby/spdystream/spdy/dictionary.go @@ -0,0 +1,187 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package spdy + +// headerDictionary is the dictionary sent to the zlib compressor/decompressor. +var headerDictionary = []byte{ + 0x00, 0x00, 0x00, 0x07, 0x6f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x00, 0x00, 0x00, 0x04, 0x68, + 0x65, 0x61, 0x64, 0x00, 0x00, 0x00, 0x04, 0x70, + 0x6f, 0x73, 0x74, 0x00, 0x00, 0x00, 0x03, 0x70, + 0x75, 0x74, 0x00, 0x00, 0x00, 0x06, 0x64, 0x65, + 0x6c, 0x65, 0x74, 0x65, 0x00, 0x00, 0x00, 0x05, + 0x74, 0x72, 0x61, 0x63, 0x65, 0x00, 0x00, 0x00, + 0x06, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x00, + 0x00, 0x00, 0x0e, 0x61, 0x63, 0x63, 0x65, 0x70, + 0x74, 0x2d, 0x63, 0x68, 0x61, 0x72, 0x73, 0x65, + 0x74, 0x00, 0x00, 0x00, 0x0f, 0x61, 0x63, 0x63, + 0x65, 0x70, 0x74, 0x2d, 0x65, 0x6e, 0x63, 0x6f, + 0x64, 0x69, 0x6e, 0x67, 0x00, 0x00, 0x00, 0x0f, + 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x6c, + 0x61, 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x00, + 0x00, 0x00, 0x0d, 0x61, 0x63, 0x63, 0x65, 0x70, + 0x74, 0x2d, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x73, + 0x00, 0x00, 0x00, 0x03, 0x61, 0x67, 0x65, 0x00, + 0x00, 0x00, 0x05, 0x61, 0x6c, 0x6c, 0x6f, 0x77, + 0x00, 0x00, 0x00, 0x0d, 0x61, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x00, 0x00, 0x00, 0x0d, 0x63, 0x61, 0x63, + 0x68, 0x65, 0x2d, 0x63, 0x6f, 0x6e, 0x74, 0x72, + 0x6f, 0x6c, 0x00, 0x00, 0x00, 0x0a, 0x63, 0x6f, + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x00, 0x00, 0x00, 0x0c, 0x63, 0x6f, 0x6e, 0x74, + 0x65, 0x6e, 0x74, 0x2d, 0x62, 0x61, 0x73, 0x65, + 0x00, 0x00, 0x00, 0x10, 0x63, 0x6f, 0x6e, 0x74, + 0x65, 0x6e, 0x74, 0x2d, 0x65, 0x6e, 0x63, 0x6f, + 0x64, 0x69, 0x6e, 0x67, 0x00, 0x00, 0x00, 0x10, + 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, + 0x6c, 0x61, 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, + 0x00, 0x00, 0x00, 0x0e, 0x63, 0x6f, 0x6e, 0x74, + 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, + 0x74, 0x68, 0x00, 0x00, 0x00, 0x10, 0x63, 0x6f, + 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x6f, + 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x00, 0x00, + 0x00, 0x0b, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, + 0x74, 0x2d, 0x6d, 0x64, 0x35, 0x00, 0x00, 0x00, + 0x0d, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, + 0x2d, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x00, 0x00, + 0x00, 0x0c, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, + 0x74, 0x2d, 0x74, 0x79, 0x70, 0x65, 0x00, 0x00, + 0x00, 0x04, 0x64, 0x61, 0x74, 0x65, 0x00, 0x00, + 0x00, 0x04, 0x65, 0x74, 0x61, 0x67, 0x00, 0x00, + 0x00, 0x06, 0x65, 0x78, 0x70, 0x65, 0x63, 0x74, + 0x00, 0x00, 0x00, 0x07, 0x65, 0x78, 0x70, 0x69, + 0x72, 0x65, 0x73, 0x00, 0x00, 0x00, 0x04, 0x66, + 0x72, 0x6f, 0x6d, 0x00, 0x00, 0x00, 0x04, 0x68, + 0x6f, 0x73, 0x74, 0x00, 0x00, 0x00, 0x08, 0x69, + 0x66, 0x2d, 0x6d, 0x61, 0x74, 0x63, 0x68, 0x00, + 0x00, 0x00, 0x11, 0x69, 0x66, 0x2d, 0x6d, 0x6f, + 0x64, 0x69, 0x66, 0x69, 0x65, 0x64, 0x2d, 0x73, + 0x69, 0x6e, 0x63, 0x65, 0x00, 0x00, 0x00, 0x0d, + 0x69, 0x66, 0x2d, 0x6e, 0x6f, 0x6e, 0x65, 0x2d, + 0x6d, 0x61, 0x74, 0x63, 0x68, 0x00, 0x00, 0x00, + 0x08, 0x69, 0x66, 0x2d, 0x72, 0x61, 0x6e, 0x67, + 0x65, 0x00, 0x00, 0x00, 0x13, 0x69, 0x66, 0x2d, + 0x75, 0x6e, 0x6d, 0x6f, 0x64, 0x69, 0x66, 0x69, + 0x65, 0x64, 0x2d, 0x73, 0x69, 0x6e, 0x63, 0x65, + 0x00, 0x00, 0x00, 0x0d, 0x6c, 0x61, 0x73, 0x74, + 0x2d, 0x6d, 0x6f, 0x64, 0x69, 0x66, 0x69, 0x65, + 0x64, 0x00, 0x00, 0x00, 0x08, 0x6c, 0x6f, 0x63, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x00, 0x00, 0x00, + 0x0c, 0x6d, 0x61, 0x78, 0x2d, 0x66, 0x6f, 0x72, + 0x77, 0x61, 0x72, 0x64, 0x73, 0x00, 0x00, 0x00, + 0x06, 0x70, 0x72, 0x61, 0x67, 0x6d, 0x61, 0x00, + 0x00, 0x00, 0x12, 0x70, 0x72, 0x6f, 0x78, 0x79, + 0x2d, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, + 0x69, 0x63, 0x61, 0x74, 0x65, 0x00, 0x00, 0x00, + 0x13, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2d, 0x61, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x00, 0x00, 0x00, 0x05, + 0x72, 0x61, 0x6e, 0x67, 0x65, 0x00, 0x00, 0x00, + 0x07, 0x72, 0x65, 0x66, 0x65, 0x72, 0x65, 0x72, + 0x00, 0x00, 0x00, 0x0b, 0x72, 0x65, 0x74, 0x72, + 0x79, 0x2d, 0x61, 0x66, 0x74, 0x65, 0x72, 0x00, + 0x00, 0x00, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x00, 0x00, 0x00, 0x02, 0x74, 0x65, 0x00, + 0x00, 0x00, 0x07, 0x74, 0x72, 0x61, 0x69, 0x6c, + 0x65, 0x72, 0x00, 0x00, 0x00, 0x11, 0x74, 0x72, + 0x61, 0x6e, 0x73, 0x66, 0x65, 0x72, 0x2d, 0x65, + 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x00, + 0x00, 0x00, 0x07, 0x75, 0x70, 0x67, 0x72, 0x61, + 0x64, 0x65, 0x00, 0x00, 0x00, 0x0a, 0x75, 0x73, + 0x65, 0x72, 0x2d, 0x61, 0x67, 0x65, 0x6e, 0x74, + 0x00, 0x00, 0x00, 0x04, 0x76, 0x61, 0x72, 0x79, + 0x00, 0x00, 0x00, 0x03, 0x76, 0x69, 0x61, 0x00, + 0x00, 0x00, 0x07, 0x77, 0x61, 0x72, 0x6e, 0x69, + 0x6e, 0x67, 0x00, 0x00, 0x00, 0x10, 0x77, 0x77, + 0x77, 0x2d, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, + 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x00, 0x00, + 0x00, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, + 0x00, 0x00, 0x00, 0x03, 0x67, 0x65, 0x74, 0x00, + 0x00, 0x00, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x00, 0x00, 0x00, 0x06, 0x32, 0x30, 0x30, + 0x20, 0x4f, 0x4b, 0x00, 0x00, 0x00, 0x07, 0x76, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x00, 0x00, + 0x00, 0x08, 0x48, 0x54, 0x54, 0x50, 0x2f, 0x31, + 0x2e, 0x31, 0x00, 0x00, 0x00, 0x03, 0x75, 0x72, + 0x6c, 0x00, 0x00, 0x00, 0x06, 0x70, 0x75, 0x62, + 0x6c, 0x69, 0x63, 0x00, 0x00, 0x00, 0x0a, 0x73, + 0x65, 0x74, 0x2d, 0x63, 0x6f, 0x6f, 0x6b, 0x69, + 0x65, 0x00, 0x00, 0x00, 0x0a, 0x6b, 0x65, 0x65, + 0x70, 0x2d, 0x61, 0x6c, 0x69, 0x76, 0x65, 0x00, + 0x00, 0x00, 0x06, 0x6f, 0x72, 0x69, 0x67, 0x69, + 0x6e, 0x31, 0x30, 0x30, 0x31, 0x30, 0x31, 0x32, + 0x30, 0x31, 0x32, 0x30, 0x32, 0x32, 0x30, 0x35, + 0x32, 0x30, 0x36, 0x33, 0x30, 0x30, 0x33, 0x30, + 0x32, 0x33, 0x30, 0x33, 0x33, 0x30, 0x34, 0x33, + 0x30, 0x35, 0x33, 0x30, 0x36, 0x33, 0x30, 0x37, + 0x34, 0x30, 0x32, 0x34, 0x30, 0x35, 0x34, 0x30, + 0x36, 0x34, 0x30, 0x37, 0x34, 0x30, 0x38, 0x34, + 0x30, 0x39, 0x34, 0x31, 0x30, 0x34, 0x31, 0x31, + 0x34, 0x31, 0x32, 0x34, 0x31, 0x33, 0x34, 0x31, + 0x34, 0x34, 0x31, 0x35, 0x34, 0x31, 0x36, 0x34, + 0x31, 0x37, 0x35, 0x30, 0x32, 0x35, 0x30, 0x34, + 0x35, 0x30, 0x35, 0x32, 0x30, 0x33, 0x20, 0x4e, + 0x6f, 0x6e, 0x2d, 0x41, 0x75, 0x74, 0x68, 0x6f, + 0x72, 0x69, 0x74, 0x61, 0x74, 0x69, 0x76, 0x65, + 0x20, 0x49, 0x6e, 0x66, 0x6f, 0x72, 0x6d, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x32, 0x30, 0x34, 0x20, + 0x4e, 0x6f, 0x20, 0x43, 0x6f, 0x6e, 0x74, 0x65, + 0x6e, 0x74, 0x33, 0x30, 0x31, 0x20, 0x4d, 0x6f, + 0x76, 0x65, 0x64, 0x20, 0x50, 0x65, 0x72, 0x6d, + 0x61, 0x6e, 0x65, 0x6e, 0x74, 0x6c, 0x79, 0x34, + 0x30, 0x30, 0x20, 0x42, 0x61, 0x64, 0x20, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x34, 0x30, + 0x31, 0x20, 0x55, 0x6e, 0x61, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x34, 0x30, + 0x33, 0x20, 0x46, 0x6f, 0x72, 0x62, 0x69, 0x64, + 0x64, 0x65, 0x6e, 0x34, 0x30, 0x34, 0x20, 0x4e, + 0x6f, 0x74, 0x20, 0x46, 0x6f, 0x75, 0x6e, 0x64, + 0x35, 0x30, 0x30, 0x20, 0x49, 0x6e, 0x74, 0x65, + 0x72, 0x6e, 0x61, 0x6c, 0x20, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x20, 0x45, 0x72, 0x72, 0x6f, + 0x72, 0x35, 0x30, 0x31, 0x20, 0x4e, 0x6f, 0x74, + 0x20, 0x49, 0x6d, 0x70, 0x6c, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x65, 0x64, 0x35, 0x30, 0x33, 0x20, + 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x20, + 0x55, 0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, + 0x62, 0x6c, 0x65, 0x4a, 0x61, 0x6e, 0x20, 0x46, + 0x65, 0x62, 0x20, 0x4d, 0x61, 0x72, 0x20, 0x41, + 0x70, 0x72, 0x20, 0x4d, 0x61, 0x79, 0x20, 0x4a, + 0x75, 0x6e, 0x20, 0x4a, 0x75, 0x6c, 0x20, 0x41, + 0x75, 0x67, 0x20, 0x53, 0x65, 0x70, 0x74, 0x20, + 0x4f, 0x63, 0x74, 0x20, 0x4e, 0x6f, 0x76, 0x20, + 0x44, 0x65, 0x63, 0x20, 0x30, 0x30, 0x3a, 0x30, + 0x30, 0x3a, 0x30, 0x30, 0x20, 0x4d, 0x6f, 0x6e, + 0x2c, 0x20, 0x54, 0x75, 0x65, 0x2c, 0x20, 0x57, + 0x65, 0x64, 0x2c, 0x20, 0x54, 0x68, 0x75, 0x2c, + 0x20, 0x46, 0x72, 0x69, 0x2c, 0x20, 0x53, 0x61, + 0x74, 0x2c, 0x20, 0x53, 0x75, 0x6e, 0x2c, 0x20, + 0x47, 0x4d, 0x54, 0x63, 0x68, 0x75, 0x6e, 0x6b, + 0x65, 0x64, 0x2c, 0x74, 0x65, 0x78, 0x74, 0x2f, + 0x68, 0x74, 0x6d, 0x6c, 0x2c, 0x69, 0x6d, 0x61, + 0x67, 0x65, 0x2f, 0x70, 0x6e, 0x67, 0x2c, 0x69, + 0x6d, 0x61, 0x67, 0x65, 0x2f, 0x6a, 0x70, 0x67, + 0x2c, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x2f, 0x67, + 0x69, 0x66, 0x2c, 0x61, 0x70, 0x70, 0x6c, 0x69, + 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x78, + 0x6d, 0x6c, 0x2c, 0x61, 0x70, 0x70, 0x6c, 0x69, + 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x78, + 0x68, 0x74, 0x6d, 0x6c, 0x2b, 0x78, 0x6d, 0x6c, + 0x2c, 0x74, 0x65, 0x78, 0x74, 0x2f, 0x70, 0x6c, + 0x61, 0x69, 0x6e, 0x2c, 0x74, 0x65, 0x78, 0x74, + 0x2f, 0x6a, 0x61, 0x76, 0x61, 0x73, 0x63, 0x72, + 0x69, 0x70, 0x74, 0x2c, 0x70, 0x75, 0x62, 0x6c, + 0x69, 0x63, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, + 0x65, 0x6d, 0x61, 0x78, 0x2d, 0x61, 0x67, 0x65, + 0x3d, 0x67, 0x7a, 0x69, 0x70, 0x2c, 0x64, 0x65, + 0x66, 0x6c, 0x61, 0x74, 0x65, 0x2c, 0x73, 0x64, + 0x63, 0x68, 0x63, 0x68, 0x61, 0x72, 0x73, 0x65, + 0x74, 0x3d, 0x75, 0x74, 0x66, 0x2d, 0x38, 0x63, + 0x68, 0x61, 0x72, 0x73, 0x65, 0x74, 0x3d, 0x69, + 0x73, 0x6f, 0x2d, 0x38, 0x38, 0x35, 0x39, 0x2d, + 0x31, 0x2c, 0x75, 0x74, 0x66, 0x2d, 0x2c, 0x2a, + 0x2c, 0x65, 0x6e, 0x71, 0x3d, 0x30, 0x2e, +} diff --git a/vendor/github.com/moby/spdystream/spdy/options.go b/vendor/github.com/moby/spdystream/spdy/options.go new file mode 100644 index 000000000..ec03e0b9a --- /dev/null +++ b/vendor/github.com/moby/spdystream/spdy/options.go @@ -0,0 +1,25 @@ +package spdy + +// FramerOption allows callers to customize frame parsing limits. +type FramerOption func(*Framer) + +// WithMaxControlFramePayloadSize sets the control-frame payload limit. +func WithMaxControlFramePayloadSize(size uint32) FramerOption { + return func(f *Framer) { + f.maxFrameLength = size + } +} + +// WithMaxHeaderFieldSize sets the per-header name/value size limit. +func WithMaxHeaderFieldSize(size uint32) FramerOption { + return func(f *Framer) { + f.maxHeaderFieldSize = size + } +} + +// WithMaxHeaderCount sets the maximum number of headers in a frame. +func WithMaxHeaderCount(count uint32) FramerOption { + return func(f *Framer) { + f.maxHeaderCount = count + } +} diff --git a/vendor/github.com/moby/spdystream/spdy/read.go b/vendor/github.com/moby/spdystream/spdy/read.go new file mode 100644 index 000000000..2abb69433 --- /dev/null +++ b/vendor/github.com/moby/spdystream/spdy/read.go @@ -0,0 +1,382 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package spdy + +import ( + "compress/zlib" + "encoding/binary" + "io" + "io/ioutil" + "net/http" + "strings" +) + +func (frame *SynStreamFrame) read(h ControlFrameHeader, f *Framer) error { + return f.readSynStreamFrame(h, frame) +} + +func (frame *SynReplyFrame) read(h ControlFrameHeader, f *Framer) error { + return f.readSynReplyFrame(h, frame) +} + +func (frame *RstStreamFrame) read(h ControlFrameHeader, f *Framer) error { + frame.CFHeader = h + if err := binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { + return err + } + if err := binary.Read(f.r, binary.BigEndian, &frame.Status); err != nil { + return err + } + if frame.Status == 0 { + return &Error{InvalidControlFrame, frame.StreamId} + } + if frame.StreamId == 0 { + return &Error{ZeroStreamId, 0} + } + return nil +} + +func (frame *SettingsFrame) read(h ControlFrameHeader, f *Framer) error { + frame.CFHeader = h + var numSettings uint32 + if err := binary.Read(f.r, binary.BigEndian, &numSettings); err != nil { + return err + } + // Each setting is 8 bytes (4-byte id + 4-byte value). + // Payload is 4 bytes for numSettings + numSettings*8. + if h.length < 4 || numSettings > (h.length-4)/8 { + return &Error{InvalidControlFrame, 0} + } + frame.FlagIdValues = make([]SettingsFlagIdValue, numSettings) + for i := uint32(0); i < numSettings; i++ { + if err := binary.Read(f.r, binary.BigEndian, &frame.FlagIdValues[i].Id); err != nil { + return err + } + frame.FlagIdValues[i].Flag = SettingsFlag((frame.FlagIdValues[i].Id & 0xff000000) >> 24) + frame.FlagIdValues[i].Id &= 0xffffff + if err := binary.Read(f.r, binary.BigEndian, &frame.FlagIdValues[i].Value); err != nil { + return err + } + } + return nil +} + +func (frame *PingFrame) read(h ControlFrameHeader, f *Framer) error { + frame.CFHeader = h + if err := binary.Read(f.r, binary.BigEndian, &frame.Id); err != nil { + return err + } + if frame.Id == 0 { + return &Error{ZeroStreamId, 0} + } + if frame.CFHeader.Flags != 0 { + return &Error{InvalidControlFrame, StreamId(frame.Id)} + } + return nil +} + +func (frame *GoAwayFrame) read(h ControlFrameHeader, f *Framer) error { + frame.CFHeader = h + if err := binary.Read(f.r, binary.BigEndian, &frame.LastGoodStreamId); err != nil { + return err + } + if frame.CFHeader.Flags != 0 { + return &Error{InvalidControlFrame, frame.LastGoodStreamId} + } + if frame.CFHeader.length != 8 { + return &Error{InvalidControlFrame, frame.LastGoodStreamId} + } + if err := binary.Read(f.r, binary.BigEndian, &frame.Status); err != nil { + return err + } + return nil +} + +func (frame *HeadersFrame) read(h ControlFrameHeader, f *Framer) error { + return f.readHeadersFrame(h, frame) +} + +func (frame *WindowUpdateFrame) read(h ControlFrameHeader, f *Framer) error { + frame.CFHeader = h + if err := binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { + return err + } + if frame.CFHeader.Flags != 0 { + return &Error{InvalidControlFrame, frame.StreamId} + } + if frame.CFHeader.length != 8 { + return &Error{InvalidControlFrame, frame.StreamId} + } + if err := binary.Read(f.r, binary.BigEndian, &frame.DeltaWindowSize); err != nil { + return err + } + return nil +} + +func newControlFrame(frameType ControlFrameType) (controlFrame, error) { + ctor, ok := cframeCtor[frameType] + if !ok { + return nil, &Error{Err: InvalidControlFrame} + } + return ctor(), nil +} + +var cframeCtor = map[ControlFrameType]func() controlFrame{ + TypeSynStream: func() controlFrame { return new(SynStreamFrame) }, + TypeSynReply: func() controlFrame { return new(SynReplyFrame) }, + TypeRstStream: func() controlFrame { return new(RstStreamFrame) }, + TypeSettings: func() controlFrame { return new(SettingsFrame) }, + TypePing: func() controlFrame { return new(PingFrame) }, + TypeGoAway: func() controlFrame { return new(GoAwayFrame) }, + TypeHeaders: func() controlFrame { return new(HeadersFrame) }, + TypeWindowUpdate: func() controlFrame { return new(WindowUpdateFrame) }, +} + +func (f *Framer) uncorkHeaderDecompressor(payloadSize int64) error { + if f.headerDecompressor != nil { + f.headerReader.N = payloadSize + return nil + } + f.headerReader = io.LimitedReader{R: f.r, N: payloadSize} + decompressor, err := zlib.NewReaderDict(&f.headerReader, []byte(headerDictionary)) + if err != nil { + return err + } + f.headerDecompressor = decompressor + return nil +} + +// ReadFrame reads SPDY encoded data and returns a decompressed Frame. +func (f *Framer) ReadFrame() (Frame, error) { + var firstWord uint32 + if err := binary.Read(f.r, binary.BigEndian, &firstWord); err != nil { + return nil, err + } + if firstWord&0x80000000 != 0 { + frameType := ControlFrameType(firstWord & 0xffff) + version := uint16(firstWord >> 16 & 0x7fff) + return f.parseControlFrame(version, frameType) + } + return f.parseDataFrame(StreamId(firstWord & 0x7fffffff)) +} + +func (f *Framer) parseControlFrame(version uint16, frameType ControlFrameType) (Frame, error) { + var length uint32 + if err := binary.Read(f.r, binary.BigEndian, &length); err != nil { + return nil, err + } + maxControlFramePayload := uint32(MaxDataLength) + if f.maxFrameLength > 0 { + maxControlFramePayload = f.maxFrameLength + } + + flags := ControlFlags((length & 0xff000000) >> 24) + length &= 0xffffff + if length > maxControlFramePayload { + if _, err := io.CopyN(ioutil.Discard, f.r, int64(length)); err != nil { + return nil, err + } + return nil, &Error{InvalidControlFrame, 0} + } + header := ControlFrameHeader{version, frameType, flags, length} + cframe, err := newControlFrame(frameType) + if err != nil { + return nil, err + } + if err = cframe.read(header, f); err != nil { + return nil, err + } + return cframe, nil +} + +func (f *Framer) parseHeaderValueBlock(r io.Reader, streamId StreamId) (http.Header, error) { + var numHeaders uint32 + if err := binary.Read(r, binary.BigEndian, &numHeaders); err != nil { + return nil, err + } + maxHeaders := defaultMaxHeaderCount + if f.maxHeaderCount > 0 { + maxHeaders = f.maxHeaderCount + } + if numHeaders > maxHeaders { + return nil, &Error{InvalidControlFrame, streamId} + } + maxFieldSize := defaultMaxHeaderFieldSize + if f.maxHeaderFieldSize > 0 { + maxFieldSize = f.maxHeaderFieldSize + } + var e error + h := make(http.Header, int(numHeaders)) + for i := 0; i < int(numHeaders); i++ { + var length uint32 + if err := binary.Read(r, binary.BigEndian, &length); err != nil { + return nil, err + } + if length > maxFieldSize { + return nil, &Error{InvalidControlFrame, streamId} + } + nameBytes := make([]byte, length) + if _, err := io.ReadFull(r, nameBytes); err != nil { + return nil, err + } + name := string(nameBytes) + if name != strings.ToLower(name) { + e = &Error{UnlowercasedHeaderName, streamId} + name = strings.ToLower(name) + } + if h[name] != nil { + e = &Error{DuplicateHeaders, streamId} + } + if err := binary.Read(r, binary.BigEndian, &length); err != nil { + return nil, err + } + if length > maxFieldSize { + return nil, &Error{InvalidControlFrame, streamId} + } + value := make([]byte, length) + if _, err := io.ReadFull(r, value); err != nil { + return nil, err + } + valueList := strings.Split(string(value), headerValueSeparator) + for _, v := range valueList { + h.Add(name, v) + } + } + if e != nil { + return h, e + } + return h, nil +} + +func (f *Framer) readSynStreamFrame(h ControlFrameHeader, frame *SynStreamFrame) error { + frame.CFHeader = h + var err error + if err = binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { + return err + } + if err = binary.Read(f.r, binary.BigEndian, &frame.AssociatedToStreamId); err != nil { + return err + } + if err = binary.Read(f.r, binary.BigEndian, &frame.Priority); err != nil { + return err + } + frame.Priority >>= 5 + if err = binary.Read(f.r, binary.BigEndian, &frame.Slot); err != nil { + return err + } + reader := f.r + if !f.headerCompressionDisabled { + err := f.uncorkHeaderDecompressor(int64(h.length - 10)) + if err != nil { + return err + } + reader = f.headerDecompressor + } + frame.Headers, err = f.parseHeaderValueBlock(reader, frame.StreamId) + if !f.headerCompressionDisabled && (err == io.EOF && f.headerReader.N == 0 || f.headerReader.N != 0) { + err = &Error{WrongCompressedPayloadSize, 0} + } + if err != nil { + return err + } + for h := range frame.Headers { + if invalidReqHeaders[h] { + return &Error{InvalidHeaderPresent, frame.StreamId} + } + } + if frame.StreamId == 0 { + return &Error{ZeroStreamId, 0} + } + return nil +} + +func (f *Framer) readSynReplyFrame(h ControlFrameHeader, frame *SynReplyFrame) error { + frame.CFHeader = h + var err error + if err = binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { + return err + } + reader := f.r + if !f.headerCompressionDisabled { + err := f.uncorkHeaderDecompressor(int64(h.length - 4)) + if err != nil { + return err + } + reader = f.headerDecompressor + } + frame.Headers, err = f.parseHeaderValueBlock(reader, frame.StreamId) + if !f.headerCompressionDisabled && (err == io.EOF && f.headerReader.N == 0 || f.headerReader.N != 0) { + err = &Error{WrongCompressedPayloadSize, 0} + } + if err != nil { + return err + } + for h := range frame.Headers { + if invalidRespHeaders[h] { + return &Error{InvalidHeaderPresent, frame.StreamId} + } + } + if frame.StreamId == 0 { + return &Error{ZeroStreamId, 0} + } + return nil +} + +func (f *Framer) readHeadersFrame(h ControlFrameHeader, frame *HeadersFrame) error { + frame.CFHeader = h + var err error + if err = binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { + return err + } + reader := f.r + if !f.headerCompressionDisabled { + err := f.uncorkHeaderDecompressor(int64(h.length - 4)) + if err != nil { + return err + } + reader = f.headerDecompressor + } + frame.Headers, err = f.parseHeaderValueBlock(reader, frame.StreamId) + if !f.headerCompressionDisabled && (err == io.EOF && f.headerReader.N == 0 || f.headerReader.N != 0) { + err = &Error{WrongCompressedPayloadSize, 0} + } + if err != nil { + return err + } + var invalidHeaders map[string]bool + if frame.StreamId%2 == 0 { + invalidHeaders = invalidReqHeaders + } else { + invalidHeaders = invalidRespHeaders + } + for h := range frame.Headers { + if invalidHeaders[h] { + return &Error{InvalidHeaderPresent, frame.StreamId} + } + } + if frame.StreamId == 0 { + return &Error{ZeroStreamId, 0} + } + return nil +} + +func (f *Framer) parseDataFrame(streamId StreamId) (*DataFrame, error) { + var length uint32 + if err := binary.Read(f.r, binary.BigEndian, &length); err != nil { + return nil, err + } + var frame DataFrame + frame.StreamId = streamId + frame.Flags = DataFlags(length >> 24) + length &= 0xffffff + frame.Data = make([]byte, length) + if _, err := io.ReadFull(f.r, frame.Data); err != nil { + return nil, err + } + if frame.StreamId == 0 { + return nil, &Error{ZeroStreamId, 0} + } + return &frame, nil +} diff --git a/vendor/github.com/moby/spdystream/spdy/types.go b/vendor/github.com/moby/spdystream/spdy/types.go new file mode 100644 index 000000000..a5528618c --- /dev/null +++ b/vendor/github.com/moby/spdystream/spdy/types.go @@ -0,0 +1,308 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Modifications Copyright 2014-2021 Docker Inc. + +// Package spdy implements the SPDY protocol (currently SPDY/3), described in +// http://www.chromium.org/spdy/spdy-protocol/spdy-protocol-draft3. +package spdy + +import ( + "bytes" + "compress/zlib" + "io" + "net/http" +) + +// Version is the protocol version number that this package implements. +const Version = 3 + +// ControlFrameType stores the type field in a control frame header. +type ControlFrameType uint16 + +const ( + TypeSynStream ControlFrameType = 0x0001 + TypeSynReply ControlFrameType = 0x0002 + TypeRstStream ControlFrameType = 0x0003 + TypeSettings ControlFrameType = 0x0004 + TypePing ControlFrameType = 0x0006 + TypeGoAway ControlFrameType = 0x0007 + TypeHeaders ControlFrameType = 0x0008 + TypeWindowUpdate ControlFrameType = 0x0009 +) + +// ControlFlags are the flags that can be set on a control frame. +type ControlFlags uint8 + +const ( + ControlFlagFin ControlFlags = 0x01 + ControlFlagUnidirectional ControlFlags = 0x02 + ControlFlagSettingsClearSettings ControlFlags = 0x01 +) + +// DataFlags are the flags that can be set on a data frame. +type DataFlags uint8 + +const ( + DataFlagFin DataFlags = 0x01 +) + +// MaxDataLength is the maximum number of bytes that can be stored in one frame. +// +// SPDY frame headers encode the payload length using a 24-bit field, +// so the maximum representable size for both data and control frames +// is 2^24-1 bytes. +// +// See the SPDY/3 specification, "Frame Format": +// https://www.chromium.org/spdy/spdy-protocol/spdy-protocol-draft3-1/ +const MaxDataLength = 1<<24 - 1 + +const ( + defaultMaxHeaderFieldSize uint32 = 1 << 20 + defaultMaxHeaderCount uint32 = 1000 +) + +// headerValueSepator separates multiple header values. +const headerValueSeparator = "\x00" + +// Frame is a single SPDY frame in its unpacked in-memory representation. Use +// Framer to read and write it. +type Frame interface { + write(f *Framer) error +} + +// ControlFrameHeader contains all the fields in a control frame header, +// in its unpacked in-memory representation. +type ControlFrameHeader struct { + // Note, high bit is the "Control" bit. + version uint16 // spdy version number + frameType ControlFrameType + Flags ControlFlags + length uint32 // length of data field +} + +type controlFrame interface { + Frame + read(h ControlFrameHeader, f *Framer) error +} + +// StreamId represents a 31-bit value identifying the stream. +type StreamId uint32 + +// SynStreamFrame is the unpacked, in-memory representation of a SYN_STREAM +// frame. +type SynStreamFrame struct { + CFHeader ControlFrameHeader + StreamId StreamId + AssociatedToStreamId StreamId // stream id for a stream which this stream is associated to + Priority uint8 // priority of this frame (3-bit) + Slot uint8 // index in the server's credential vector of the client certificate + Headers http.Header +} + +// SynReplyFrame is the unpacked, in-memory representation of a SYN_REPLY frame. +type SynReplyFrame struct { + CFHeader ControlFrameHeader + StreamId StreamId + Headers http.Header +} + +// RstStreamStatus represents the status that led to a RST_STREAM. +type RstStreamStatus uint32 + +const ( + ProtocolError RstStreamStatus = iota + 1 + InvalidStream + RefusedStream + UnsupportedVersion + Cancel + InternalError + FlowControlError + StreamInUse + StreamAlreadyClosed + InvalidCredentials + FrameTooLarge +) + +// RstStreamFrame is the unpacked, in-memory representation of a RST_STREAM +// frame. +type RstStreamFrame struct { + CFHeader ControlFrameHeader + StreamId StreamId + Status RstStreamStatus +} + +// SettingsFlag represents a flag in a SETTINGS frame. +type SettingsFlag uint8 + +const ( + FlagSettingsPersistValue SettingsFlag = 0x1 + FlagSettingsPersisted SettingsFlag = 0x2 +) + +// SettingsFlag represents the id of an id/value pair in a SETTINGS frame. +type SettingsId uint32 + +const ( + SettingsUploadBandwidth SettingsId = iota + 1 + SettingsDownloadBandwidth + SettingsRoundTripTime + SettingsMaxConcurrentStreams + SettingsCurrentCwnd + SettingsDownloadRetransRate + SettingsInitialWindowSize + SettingsClientCretificateVectorSize +) + +// SettingsFlagIdValue is the unpacked, in-memory representation of the +// combined flag/id/value for a setting in a SETTINGS frame. +type SettingsFlagIdValue struct { + Flag SettingsFlag + Id SettingsId + Value uint32 +} + +// SettingsFrame is the unpacked, in-memory representation of a SPDY +// SETTINGS frame. +type SettingsFrame struct { + CFHeader ControlFrameHeader + FlagIdValues []SettingsFlagIdValue +} + +// PingFrame is the unpacked, in-memory representation of a PING frame. +type PingFrame struct { + CFHeader ControlFrameHeader + Id uint32 // unique id for this ping, from server is even, from client is odd. +} + +// GoAwayStatus represents the status in a GoAwayFrame. +type GoAwayStatus uint32 + +const ( + GoAwayOK GoAwayStatus = iota + GoAwayProtocolError + GoAwayInternalError +) + +// GoAwayFrame is the unpacked, in-memory representation of a GOAWAY frame. +type GoAwayFrame struct { + CFHeader ControlFrameHeader + LastGoodStreamId StreamId // last stream id which was accepted by sender + Status GoAwayStatus +} + +// HeadersFrame is the unpacked, in-memory representation of a HEADERS frame. +type HeadersFrame struct { + CFHeader ControlFrameHeader + StreamId StreamId + Headers http.Header +} + +// WindowUpdateFrame is the unpacked, in-memory representation of a +// WINDOW_UPDATE frame. +type WindowUpdateFrame struct { + CFHeader ControlFrameHeader + StreamId StreamId + DeltaWindowSize uint32 // additional number of bytes to existing window size +} + +// TODO: Implement credential frame and related methods. + +// DataFrame is the unpacked, in-memory representation of a DATA frame. +type DataFrame struct { + // Note, high bit is the "Control" bit. Should be 0 for data frames. + StreamId StreamId + Flags DataFlags + Data []byte // payload data of this frame +} + +// A SPDY specific error. +type ErrorCode string + +const ( + UnlowercasedHeaderName ErrorCode = "header was not lowercased" + DuplicateHeaders ErrorCode = "multiple headers with same name" + WrongCompressedPayloadSize ErrorCode = "compressed payload size was incorrect" + UnknownFrameType ErrorCode = "unknown frame type" + InvalidControlFrame ErrorCode = "invalid control frame" + InvalidDataFrame ErrorCode = "invalid data frame" + InvalidHeaderPresent ErrorCode = "frame contained invalid header" + ZeroStreamId ErrorCode = "stream id zero is disallowed" +) + +// Error contains both the type of error and additional values. StreamId is 0 +// if Error is not associated with a stream. +type Error struct { + Err ErrorCode + StreamId StreamId +} + +func (e *Error) Error() string { + return string(e.Err) +} + +var invalidReqHeaders = map[string]bool{ + "Connection": true, + "Host": true, + "Keep-Alive": true, + "Proxy-Connection": true, + "Transfer-Encoding": true, +} + +var invalidRespHeaders = map[string]bool{ + "Connection": true, + "Keep-Alive": true, + "Proxy-Connection": true, + "Transfer-Encoding": true, +} + +// Framer handles serializing/deserializing SPDY frames, including compressing/ +// decompressing payloads. +type Framer struct { + headerCompressionDisabled bool + w io.Writer + headerBuf *bytes.Buffer + headerCompressor *zlib.Writer + r io.Reader + headerReader io.LimitedReader + headerDecompressor io.ReadCloser + + maxFrameLength uint32 // overrides the default frame payload length limit. + maxHeaderFieldSize uint32 // overrides the default per-header name/value length limit. + maxHeaderCount uint32 // overrides the default header count limit. +} + +// NewFramer allocates a new Framer for a given SPDY connection, represented by +// a io.Writer and io.Reader. Note that Framer will read and write individual fields +// from/to the Reader and Writer, so the caller should pass in an appropriately +// buffered implementation to optimize performance. +func NewFramer(w io.Writer, r io.Reader) (*Framer, error) { + return newFramer(w, r) +} + +// NewFramerWithOptions allocates a new Framer for a given SPDY connection and +// applies frame parsing limits via options. +func NewFramerWithOptions(w io.Writer, r io.Reader, opts ...FramerOption) (*Framer, error) { + return newFramer(w, r, opts...) +} + +func newFramer(w io.Writer, r io.Reader, opts ...FramerOption) (*Framer, error) { + compressBuf := new(bytes.Buffer) + compressor, err := zlib.NewWriterLevelDict(compressBuf, zlib.BestCompression, []byte(headerDictionary)) + if err != nil { + return nil, err + } + framer := &Framer{ + w: w, + headerBuf: compressBuf, + headerCompressor: compressor, + r: r, + } + for _, opt := range opts { + if opt != nil { + opt(framer) + } + } + return framer, nil +} diff --git a/vendor/github.com/moby/spdystream/spdy/write.go b/vendor/github.com/moby/spdystream/spdy/write.go new file mode 100644 index 000000000..75084d35d --- /dev/null +++ b/vendor/github.com/moby/spdystream/spdy/write.go @@ -0,0 +1,355 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package spdy + +import ( + "encoding/binary" + "io" + "math" + "net/http" + "strings" +) + +func (frame *SynStreamFrame) write(f *Framer) error { + return f.writeSynStreamFrame(frame) +} + +func (frame *SynReplyFrame) write(f *Framer) error { + return f.writeSynReplyFrame(frame) +} + +func (frame *RstStreamFrame) write(f *Framer) (err error) { + if frame.StreamId == 0 { + return &Error{ZeroStreamId, 0} + } + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeRstStream + frame.CFHeader.Flags = 0 + frame.CFHeader.length = 8 + + // Serialize frame to Writer. + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { + return + } + if frame.Status == 0 { + return &Error{InvalidControlFrame, frame.StreamId} + } + if err = binary.Write(f.w, binary.BigEndian, frame.Status); err != nil { + return + } + return +} + +func (frame *SettingsFrame) write(f *Framer) (err error) { + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeSettings + payloadLen := len(frame.FlagIdValues)*8 + 4 + if payloadLen > MaxDataLength { + return &Error{InvalidControlFrame, 0} + } + frame.CFHeader.length = uint32(payloadLen) + + // Serialize frame to Writer. + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + n := len(frame.FlagIdValues) + if uint64(n) > math.MaxUint32 { + return &Error{InvalidControlFrame, 0} + } + if err = binary.Write(f.w, binary.BigEndian, uint32(n)); err != nil { + return + } + for _, flagIdValue := range frame.FlagIdValues { + flagId := uint32(flagIdValue.Flag)<<24 | uint32(flagIdValue.Id) + if err = binary.Write(f.w, binary.BigEndian, flagId); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, flagIdValue.Value); err != nil { + return + } + } + return +} + +func (frame *PingFrame) write(f *Framer) (err error) { + if frame.Id == 0 { + return &Error{ZeroStreamId, 0} + } + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypePing + frame.CFHeader.Flags = 0 + frame.CFHeader.length = 4 + + // Serialize frame to Writer. + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.Id); err != nil { + return + } + return +} + +func (frame *GoAwayFrame) write(f *Framer) (err error) { + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeGoAway + frame.CFHeader.Flags = 0 + frame.CFHeader.length = 8 + + // Serialize frame to Writer. + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.LastGoodStreamId); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.Status); err != nil { + return + } + return nil +} + +func (frame *HeadersFrame) write(f *Framer) error { + return f.writeHeadersFrame(frame) +} + +func (frame *WindowUpdateFrame) write(f *Framer) (err error) { + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeWindowUpdate + frame.CFHeader.Flags = 0 + frame.CFHeader.length = 8 + + // Serialize frame to Writer. + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.DeltaWindowSize); err != nil { + return + } + return nil +} + +func (frame *DataFrame) write(f *Framer) error { + return f.writeDataFrame(frame) +} + +// WriteFrame writes a frame. +func (f *Framer) WriteFrame(frame Frame) error { + return frame.write(f) +} + +func writeControlFrameHeader(w io.Writer, h ControlFrameHeader) error { + if err := binary.Write(w, binary.BigEndian, 0x8000|h.version); err != nil { + return err + } + if err := binary.Write(w, binary.BigEndian, h.frameType); err != nil { + return err + } + flagsAndLength := uint32(h.Flags)<<24 | h.length + if err := binary.Write(w, binary.BigEndian, flagsAndLength); err != nil { + return err + } + return nil +} + +func writeHeaderValueBlock(w io.Writer, h http.Header) (n int, err error) { + n = 0 + numHeaders := len(h) + if numHeaders > math.MaxInt32 { + return n, &Error{InvalidControlFrame, 0} + } + if err = binary.Write(w, binary.BigEndian, uint32(numHeaders)); err != nil { + return + } + n += 4 + for name, values := range h { + nameLen := len(name) + if nameLen > math.MaxInt32 { + return n, &Error{InvalidControlFrame, 0} + } + if err = binary.Write(w, binary.BigEndian, uint32(nameLen)); err != nil { + return + } + n += 4 + name = strings.ToLower(name) + if _, err = io.WriteString(w, name); err != nil { + return + } + n += nameLen + v := strings.Join(values, headerValueSeparator) + vLen := len(v) + if vLen > math.MaxInt32 { + return n, &Error{InvalidControlFrame, 0} + } + if err = binary.Write(w, binary.BigEndian, uint32(vLen)); err != nil { + return + } + n += 4 + if _, err = io.WriteString(w, v); err != nil { + return + } + n += vLen + } + return +} + +func (f *Framer) writeSynStreamFrame(frame *SynStreamFrame) (err error) { + if frame.StreamId == 0 { + return &Error{ZeroStreamId, 0} + } + // Marshal the headers. + var writer io.Writer = f.headerBuf + if !f.headerCompressionDisabled { + writer = f.headerCompressor + } + if _, err = writeHeaderValueBlock(writer, frame.Headers); err != nil { + return + } + if !f.headerCompressionDisabled { + f.headerCompressor.Flush() + } + + // Set ControlFrameHeader. + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeSynStream + hLen := len(f.headerBuf.Bytes()) + 10 + if hLen > MaxDataLength { + return &Error{InvalidControlFrame, 0} + } + frame.CFHeader.length = uint32(hLen) + + // Serialize frame to Writer. + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return err + } + if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { + return err + } + if err = binary.Write(f.w, binary.BigEndian, frame.AssociatedToStreamId); err != nil { + return err + } + if err = binary.Write(f.w, binary.BigEndian, frame.Priority<<5); err != nil { + return err + } + if err = binary.Write(f.w, binary.BigEndian, frame.Slot); err != nil { + return err + } + if _, err = f.w.Write(f.headerBuf.Bytes()); err != nil { + return err + } + f.headerBuf.Reset() + return nil +} + +func (f *Framer) writeSynReplyFrame(frame *SynReplyFrame) (err error) { + if frame.StreamId == 0 { + return &Error{ZeroStreamId, 0} + } + // Marshal the headers. + var writer io.Writer = f.headerBuf + if !f.headerCompressionDisabled { + writer = f.headerCompressor + } + if _, err = writeHeaderValueBlock(writer, frame.Headers); err != nil { + return + } + if !f.headerCompressionDisabled { + f.headerCompressor.Flush() + } + + // Set ControlFrameHeader. + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeSynReply + hLen := len(f.headerBuf.Bytes()) + 4 + if hLen > MaxDataLength { + return &Error{InvalidControlFrame, 0} + } + frame.CFHeader.length = uint32(hLen) + + // Serialize frame to Writer. + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { + return + } + if _, err = f.w.Write(f.headerBuf.Bytes()); err != nil { + return + } + f.headerBuf.Reset() + return +} + +func (f *Framer) writeHeadersFrame(frame *HeadersFrame) (err error) { + if frame.StreamId == 0 { + return &Error{ZeroStreamId, 0} + } + // Marshal the headers. + var writer io.Writer = f.headerBuf + if !f.headerCompressionDisabled { + writer = f.headerCompressor + } + if _, err = writeHeaderValueBlock(writer, frame.Headers); err != nil { + return + } + if !f.headerCompressionDisabled { + f.headerCompressor.Flush() + } + + // Set ControlFrameHeader. + frame.CFHeader.version = Version + frame.CFHeader.frameType = TypeHeaders + hLen := len(f.headerBuf.Bytes()) + 4 + if hLen > MaxDataLength { + return &Error{InvalidControlFrame, 0} + } + frame.CFHeader.length = uint32(hLen) + + // Serialize frame to Writer. + if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { + return + } + if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { + return + } + if _, err = f.w.Write(f.headerBuf.Bytes()); err != nil { + return + } + f.headerBuf.Reset() + return +} + +func (f *Framer) writeDataFrame(frame *DataFrame) (err error) { + if frame.StreamId == 0 { + return &Error{ZeroStreamId, 0} + } + if frame.StreamId&0x80000000 != 0 || len(frame.Data) > MaxDataLength { + return &Error{InvalidDataFrame, frame.StreamId} + } + + // Serialize frame to Writer. + if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { + return + } + dLen := len(frame.Data) + if dLen > MaxDataLength { + return &Error{InvalidDataFrame, frame.StreamId} + } + flagsAndLength := uint32(frame.Flags)<<24 | uint32(dLen) + if err = binary.Write(f.w, binary.BigEndian, flagsAndLength); err != nil { + return + } + if _, err = f.w.Write(frame.Data); err != nil { + return + } + return nil +} diff --git a/vendor/github.com/moby/spdystream/stream.go b/vendor/github.com/moby/spdystream/stream.go new file mode 100644 index 000000000..171c1e9e3 --- /dev/null +++ b/vendor/github.com/moby/spdystream/stream.go @@ -0,0 +1,345 @@ +/* + Copyright 2014-2021 Docker Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package spdystream + +import ( + "errors" + "fmt" + "io" + "net" + "net/http" + "sync" + "time" + + "github.com/moby/spdystream/spdy" +) + +var ( + ErrUnreadPartialData = errors.New("unread partial data") +) + +type Stream struct { + streamId spdy.StreamId + parent *Stream + conn *Connection + startChan chan error + + dataLock sync.RWMutex + dataChan chan []byte + unread []byte + + priority uint8 + headers http.Header + headerChan chan http.Header + finishLock sync.Mutex + finished bool + replyCond *sync.Cond + replied bool + closeLock sync.Mutex + closeChan chan bool +} + +// WriteData writes data to stream, sending a dataframe per call +func (s *Stream) WriteData(data []byte, fin bool) error { + s.waitWriteReply() + var flags spdy.DataFlags + + if fin { + flags = spdy.DataFlagFin + s.finishLock.Lock() + if s.finished { + s.finishLock.Unlock() + return ErrWriteClosedStream + } + s.finished = true + s.finishLock.Unlock() + } + + dataFrame := &spdy.DataFrame{ + StreamId: s.streamId, + Flags: flags, + Data: data, + } + + debugMessage("(%p) (%d) Writing data frame", s, s.streamId) + return s.conn.framer.WriteFrame(dataFrame) +} + +// Write writes bytes to a stream, calling write data for each call. +func (s *Stream) Write(data []byte) (n int, err error) { + err = s.WriteData(data, false) + if err == nil { + n = len(data) + } + return +} + +// Read reads bytes from a stream, a single read will never get more +// than what is sent on a single data frame, but a multiple calls to +// read may get data from the same data frame. +func (s *Stream) Read(p []byte) (n int, err error) { + if s.unread == nil { + select { + case <-s.closeChan: + return 0, io.EOF + case read, ok := <-s.dataChan: + if !ok { + return 0, io.EOF + } + s.unread = read + } + } + n = copy(p, s.unread) + if n < len(s.unread) { + s.unread = s.unread[n:] + } else { + s.unread = nil + } + return +} + +// ReadData reads an entire data frame and returns the byte array +// from the data frame. If there is unread data from the result +// of a Read call, this function will return an ErrUnreadPartialData. +func (s *Stream) ReadData() ([]byte, error) { + debugMessage("(%p) Reading data from %d", s, s.streamId) + if s.unread != nil { + return nil, ErrUnreadPartialData + } + select { + case <-s.closeChan: + return nil, io.EOF + case read, ok := <-s.dataChan: + if !ok { + return nil, io.EOF + } + return read, nil + } +} + +func (s *Stream) waitWriteReply() { + if s.replyCond != nil { + s.replyCond.L.Lock() + for !s.replied { + s.replyCond.Wait() + } + s.replyCond.L.Unlock() + } +} + +// Wait waits for the stream to receive a reply. +func (s *Stream) Wait() error { + return s.WaitTimeout(time.Duration(0)) +} + +// WaitTimeout waits for the stream to receive a reply or for timeout. +// When the timeout is reached, ErrTimeout will be returned. +func (s *Stream) WaitTimeout(timeout time.Duration) error { + var timeoutChan <-chan time.Time + if timeout > time.Duration(0) { + timeoutChan = time.After(timeout) + } + + select { + case err := <-s.startChan: + if err != nil { + return err + } + break + case <-timeoutChan: + return ErrTimeout + } + return nil +} + +// Close closes the stream by sending an empty data frame with the +// finish flag set, indicating this side is finished with the stream. +func (s *Stream) Close() error { + select { + case <-s.closeChan: + // Stream is now fully closed + s.conn.removeStream(s) + default: + break + } + return s.WriteData([]byte{}, true) +} + +// Reset sends a reset frame, putting the stream into the fully closed state. +func (s *Stream) Reset() error { + s.conn.removeStream(s) + return s.resetStream() +} + +func (s *Stream) resetStream() error { + // Always call closeRemoteChannels, even if s.finished is already true. + // This makes it so that stream.Close() followed by stream.Reset() allows + // stream.Read() to unblock. + s.closeRemoteChannels() + + s.finishLock.Lock() + if s.finished { + s.finishLock.Unlock() + return nil + } + s.finished = true + s.finishLock.Unlock() + + resetFrame := &spdy.RstStreamFrame{ + StreamId: s.streamId, + Status: spdy.Cancel, + } + return s.conn.framer.WriteFrame(resetFrame) +} + +// CreateSubStream creates a stream using the current as the parent +func (s *Stream) CreateSubStream(headers http.Header, fin bool) (*Stream, error) { + return s.conn.CreateStream(headers, s, fin) +} + +// SetPriority sets the stream priority, does not affect the +// remote priority of this stream after Open has been called. +// Valid values are 0 through 7, 0 being the highest priority +// and 7 the lowest. +func (s *Stream) SetPriority(priority uint8) { + s.priority = priority +} + +// SendHeader sends a header frame across the stream +func (s *Stream) SendHeader(headers http.Header, fin bool) error { + return s.conn.sendHeaders(headers, s, fin) +} + +// SendReply sends a reply on a stream, only valid to be called once +// when handling a new stream +func (s *Stream) SendReply(headers http.Header, fin bool) error { + if s.replyCond == nil { + return errors.New("cannot reply on initiated stream") + } + s.replyCond.L.Lock() + defer s.replyCond.L.Unlock() + if s.replied { + return nil + } + + err := s.conn.sendReply(headers, s, fin) + if err != nil { + return err + } + + s.replied = true + s.replyCond.Broadcast() + return nil +} + +// Refuse sends a reset frame with the status refuse, only +// valid to be called once when handling a new stream. This +// may be used to indicate that a stream is not allowed +// when http status codes are not being used. +func (s *Stream) Refuse() error { + if s.replied { + return nil + } + s.replied = true + return s.conn.sendReset(spdy.RefusedStream, s) +} + +// Cancel sends a reset frame with the status canceled. This +// can be used at any time by the creator of the Stream to +// indicate the stream is no longer needed. +func (s *Stream) Cancel() error { + return s.conn.sendReset(spdy.Cancel, s) +} + +// ReceiveHeader receives a header sent on the other side +// of the stream. This function will block until a header +// is received or stream is closed. +func (s *Stream) ReceiveHeader() (http.Header, error) { + select { + case <-s.closeChan: + break + case header, ok := <-s.headerChan: + if !ok { + return nil, fmt.Errorf("header chan closed") + } + return header, nil + } + return nil, fmt.Errorf("stream closed") +} + +// Parent returns the parent stream +func (s *Stream) Parent() *Stream { + return s.parent +} + +// Headers returns the headers used to create the stream +func (s *Stream) Headers() http.Header { + return s.headers +} + +// String returns the string version of stream using the +// streamId to uniquely identify the stream +func (s *Stream) String() string { + return fmt.Sprintf("stream:%d", s.streamId) +} + +// Identifier returns a 32 bit identifier for the stream +func (s *Stream) Identifier() uint32 { + return uint32(s.streamId) +} + +// IsFinished returns whether the stream has finished +// sending data +func (s *Stream) IsFinished() bool { + s.finishLock.Lock() + defer s.finishLock.Unlock() + return s.finished +} + +// Implement net.Conn interface + +func (s *Stream) LocalAddr() net.Addr { + return s.conn.conn.LocalAddr() +} + +func (s *Stream) RemoteAddr() net.Addr { + return s.conn.conn.RemoteAddr() +} + +// TODO set per stream values instead of connection-wide + +func (s *Stream) SetDeadline(t time.Time) error { + return s.conn.conn.SetDeadline(t) +} + +func (s *Stream) SetReadDeadline(t time.Time) error { + return s.conn.conn.SetReadDeadline(t) +} + +func (s *Stream) SetWriteDeadline(t time.Time) error { + return s.conn.conn.SetWriteDeadline(t) +} + +func (s *Stream) closeRemoteChannels() { + s.closeLock.Lock() + defer s.closeLock.Unlock() + select { + case <-s.closeChan: + default: + close(s.closeChan) + } +} diff --git a/vendor/github.com/moby/spdystream/utils.go b/vendor/github.com/moby/spdystream/utils.go new file mode 100644 index 000000000..e9f7fffd6 --- /dev/null +++ b/vendor/github.com/moby/spdystream/utils.go @@ -0,0 +1,32 @@ +/* + Copyright 2014-2021 Docker Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package spdystream + +import ( + "log" + "os" +) + +var ( + DEBUG = os.Getenv("DEBUG") +) + +func debugMessage(fmt string, args ...interface{}) { + if DEBUG != "" { + log.Printf(fmt, args...) + } +} diff --git a/vendor/golang.org/x/net/internal/socks/client.go b/vendor/golang.org/x/net/internal/socks/client.go new file mode 100644 index 000000000..3d6f516a5 --- /dev/null +++ b/vendor/golang.org/x/net/internal/socks/client.go @@ -0,0 +1,168 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package socks + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "time" +) + +var ( + noDeadline = time.Time{} + aLongTimeAgo = time.Unix(1, 0) +) + +func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) { + host, port, err := splitHostPort(address) + if err != nil { + return nil, err + } + if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { + c.SetDeadline(deadline) + defer c.SetDeadline(noDeadline) + } + if ctx != context.Background() { + errCh := make(chan error, 1) + done := make(chan struct{}) + defer func() { + close(done) + if ctxErr == nil { + ctxErr = <-errCh + } + }() + go func() { + select { + case <-ctx.Done(): + c.SetDeadline(aLongTimeAgo) + errCh <- ctx.Err() + case <-done: + errCh <- nil + } + }() + } + + b := make([]byte, 0, 6+len(host)) // the size here is just an estimate + b = append(b, Version5) + if len(d.AuthMethods) == 0 || d.Authenticate == nil { + b = append(b, 1, byte(AuthMethodNotRequired)) + } else { + ams := d.AuthMethods + if len(ams) > 255 { + return nil, errors.New("too many authentication methods") + } + b = append(b, byte(len(ams))) + for _, am := range ams { + b = append(b, byte(am)) + } + } + if _, ctxErr = c.Write(b); ctxErr != nil { + return + } + + if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil { + return + } + if b[0] != Version5 { + return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) + } + am := AuthMethod(b[1]) + if am == AuthMethodNoAcceptableMethods { + return nil, errors.New("no acceptable authentication methods") + } + if d.Authenticate != nil { + if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil { + return + } + } + + b = b[:0] + b = append(b, Version5, byte(d.cmd), 0) + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + b = append(b, AddrTypeIPv4) + b = append(b, ip4...) + } else if ip6 := ip.To16(); ip6 != nil { + b = append(b, AddrTypeIPv6) + b = append(b, ip6...) + } else { + return nil, errors.New("unknown address type") + } + } else { + if len(host) > 255 { + return nil, errors.New("FQDN too long") + } + b = append(b, AddrTypeFQDN) + b = append(b, byte(len(host))) + b = append(b, host...) + } + b = append(b, byte(port>>8), byte(port)) + if _, ctxErr = c.Write(b); ctxErr != nil { + return + } + + if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil { + return + } + if b[0] != Version5 { + return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) + } + if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded { + return nil, errors.New("unknown error " + cmdErr.String()) + } + if b[2] != 0 { + return nil, errors.New("non-zero reserved field") + } + l := 2 + var a Addr + switch b[3] { + case AddrTypeIPv4: + l += net.IPv4len + a.IP = make(net.IP, net.IPv4len) + case AddrTypeIPv6: + l += net.IPv6len + a.IP = make(net.IP, net.IPv6len) + case AddrTypeFQDN: + if _, err := io.ReadFull(c, b[:1]); err != nil { + return nil, err + } + l += int(b[0]) + default: + return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3]))) + } + if cap(b) < l { + b = make([]byte, l) + } else { + b = b[:l] + } + if _, ctxErr = io.ReadFull(c, b); ctxErr != nil { + return + } + if a.IP != nil { + copy(a.IP, b) + } else { + a.Name = string(b[:len(b)-2]) + } + a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1]) + return &a, nil +} + +func splitHostPort(address string) (string, int, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return "", 0, err + } + portnum, err := strconv.Atoi(port) + if err != nil { + return "", 0, err + } + if 1 > portnum || portnum > 0xffff { + return "", 0, errors.New("port number out of range " + port) + } + return host, portnum, nil +} diff --git a/vendor/golang.org/x/net/internal/socks/socks.go b/vendor/golang.org/x/net/internal/socks/socks.go new file mode 100644 index 000000000..8eedb84ce --- /dev/null +++ b/vendor/golang.org/x/net/internal/socks/socks.go @@ -0,0 +1,317 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package socks provides a SOCKS version 5 client implementation. +// +// SOCKS protocol version 5 is defined in RFC 1928. +// Username/Password authentication for SOCKS version 5 is defined in +// RFC 1929. +package socks + +import ( + "context" + "errors" + "io" + "net" + "strconv" +) + +// A Command represents a SOCKS command. +type Command int + +func (cmd Command) String() string { + switch cmd { + case CmdConnect: + return "socks connect" + case cmdBind: + return "socks bind" + default: + return "socks " + strconv.Itoa(int(cmd)) + } +} + +// An AuthMethod represents a SOCKS authentication method. +type AuthMethod int + +// A Reply represents a SOCKS command reply code. +type Reply int + +func (code Reply) String() string { + switch code { + case StatusSucceeded: + return "succeeded" + case 0x01: + return "general SOCKS server failure" + case 0x02: + return "connection not allowed by ruleset" + case 0x03: + return "network unreachable" + case 0x04: + return "host unreachable" + case 0x05: + return "connection refused" + case 0x06: + return "TTL expired" + case 0x07: + return "command not supported" + case 0x08: + return "address type not supported" + default: + return "unknown code: " + strconv.Itoa(int(code)) + } +} + +// Wire protocol constants. +const ( + Version5 = 0x05 + + AddrTypeIPv4 = 0x01 + AddrTypeFQDN = 0x03 + AddrTypeIPv6 = 0x04 + + CmdConnect Command = 0x01 // establishes an active-open forward proxy connection + cmdBind Command = 0x02 // establishes a passive-open forward proxy connection + + AuthMethodNotRequired AuthMethod = 0x00 // no authentication required + AuthMethodUsernamePassword AuthMethod = 0x02 // use username/password + AuthMethodNoAcceptableMethods AuthMethod = 0xff // no acceptable authentication methods + + StatusSucceeded Reply = 0x00 +) + +// An Addr represents a SOCKS-specific address. +// Either Name or IP is used exclusively. +type Addr struct { + Name string // fully-qualified domain name + IP net.IP + Port int +} + +func (a *Addr) Network() string { return "socks" } + +func (a *Addr) String() string { + if a == nil { + return "" + } + port := strconv.Itoa(a.Port) + if a.IP == nil { + return net.JoinHostPort(a.Name, port) + } + return net.JoinHostPort(a.IP.String(), port) +} + +// A Conn represents a forward proxy connection. +type Conn struct { + net.Conn + + boundAddr net.Addr +} + +// BoundAddr returns the address assigned by the proxy server for +// connecting to the command target address from the proxy server. +func (c *Conn) BoundAddr() net.Addr { + if c == nil { + return nil + } + return c.boundAddr +} + +// A Dialer holds SOCKS-specific options. +type Dialer struct { + cmd Command // either CmdConnect or cmdBind + proxyNetwork string // network between a proxy server and a client + proxyAddress string // proxy server address + + // ProxyDial specifies the optional dial function for + // establishing the transport connection. + ProxyDial func(context.Context, string, string) (net.Conn, error) + + // AuthMethods specifies the list of request authentication + // methods. + // If empty, SOCKS client requests only AuthMethodNotRequired. + AuthMethods []AuthMethod + + // Authenticate specifies the optional authentication + // function. It must be non-nil when AuthMethods is not empty. + // It must return an error when the authentication is failed. + Authenticate func(context.Context, io.ReadWriter, AuthMethod) error +} + +// DialContext connects to the provided address on the provided +// network. +// +// The returned error value may be a net.OpError. When the Op field of +// net.OpError contains "socks", the Source field contains a proxy +// server address and the Addr field contains a command target +// address. +// +// See func Dial of the net package of standard library for a +// description of the network and address parameters. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if ctx == nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} + } + var err error + var c net.Conn + if d.ProxyDial != nil { + c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress) + } else { + var dd net.Dialer + c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress) + } + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + a, err := d.connect(ctx, c, address) + if err != nil { + c.Close() + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + return &Conn{Conn: c, boundAddr: a}, nil +} + +// DialWithConn initiates a connection from SOCKS server to the target +// network and address using the connection c that is already +// connected to the SOCKS server. +// +// It returns the connection's local address assigned by the SOCKS +// server. +func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if ctx == nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} + } + a, err := d.connect(ctx, c, address) + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + return a, nil +} + +// Dial connects to the provided address on the provided network. +// +// Unlike DialContext, it returns a raw transport connection instead +// of a forward proxy connection. +// +// Deprecated: Use DialContext or DialWithConn instead. +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + if err := d.validateTarget(network, address); err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + var err error + var c net.Conn + if d.ProxyDial != nil { + c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress) + } else { + c, err = net.Dial(d.proxyNetwork, d.proxyAddress) + } + if err != nil { + proxy, dst, _ := d.pathAddrs(address) + return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + } + if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil { + c.Close() + return nil, err + } + return c, nil +} + +func (d *Dialer) validateTarget(network, address string) error { + switch network { + case "tcp", "tcp6", "tcp4": + default: + return errors.New("network not implemented") + } + switch d.cmd { + case CmdConnect, cmdBind: + default: + return errors.New("command not implemented") + } + return nil +} + +func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) { + for i, s := range []string{d.proxyAddress, address} { + host, port, err := splitHostPort(s) + if err != nil { + return nil, nil, err + } + a := &Addr{Port: port} + a.IP = net.ParseIP(host) + if a.IP == nil { + a.Name = host + } + if i == 0 { + proxy = a + } else { + dst = a + } + } + return +} + +// NewDialer returns a new Dialer that dials through the provided +// proxy server's network and address. +func NewDialer(network, address string) *Dialer { + return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect} +} + +const ( + authUsernamePasswordVersion = 0x01 + authStatusSucceeded = 0x00 +) + +// UsernamePassword are the credentials for the username/password +// authentication method. +type UsernamePassword struct { + Username string + Password string +} + +// Authenticate authenticates a pair of username and password with the +// proxy server. +func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth AuthMethod) error { + switch auth { + case AuthMethodNotRequired: + return nil + case AuthMethodUsernamePassword: + if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) > 255 { + return errors.New("invalid username/password") + } + b := []byte{authUsernamePasswordVersion} + b = append(b, byte(len(up.Username))) + b = append(b, up.Username...) + b = append(b, byte(len(up.Password))) + b = append(b, up.Password...) + // TODO(mikio): handle IO deadlines and cancellation if + // necessary + if _, err := rw.Write(b); err != nil { + return err + } + if _, err := io.ReadFull(rw, b[:2]); err != nil { + return err + } + if b[0] != authUsernamePasswordVersion { + return errors.New("invalid username/password version") + } + if b[1] != authStatusSucceeded { + return errors.New("username/password authentication failed") + } + return nil + } + return errors.New("unsupported authentication method " + strconv.Itoa(int(auth))) +} diff --git a/vendor/golang.org/x/net/proxy/dial.go b/vendor/golang.org/x/net/proxy/dial.go new file mode 100644 index 000000000..811c2e4e9 --- /dev/null +++ b/vendor/golang.org/x/net/proxy/dial.go @@ -0,0 +1,54 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "context" + "net" +) + +// A ContextDialer dials using a context. +type ContextDialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment. +// +// The passed ctx is only used for returning the Conn, not the lifetime of the Conn. +// +// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer +// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout. +// +// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed. +func Dial(ctx context.Context, network, address string) (net.Conn, error) { + d := FromEnvironment() + if xd, ok := d.(ContextDialer); ok { + return xd.DialContext(ctx, network, address) + } + return dialContext(ctx, d, network, address) +} + +// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout +// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed. +func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) { + var ( + conn net.Conn + done = make(chan struct{}, 1) + err error + ) + go func() { + conn, err = d.Dial(network, address) + close(done) + if conn != nil && ctx.Err() != nil { + conn.Close() + } + }() + select { + case <-ctx.Done(): + err = ctx.Err() + case <-done: + } + return conn, err +} diff --git a/vendor/golang.org/x/net/proxy/direct.go b/vendor/golang.org/x/net/proxy/direct.go new file mode 100644 index 000000000..3d66bdef9 --- /dev/null +++ b/vendor/golang.org/x/net/proxy/direct.go @@ -0,0 +1,31 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "context" + "net" +) + +type direct struct{} + +// Direct implements Dialer by making network connections directly using net.Dial or net.DialContext. +var Direct = direct{} + +var ( + _ Dialer = Direct + _ ContextDialer = Direct +) + +// Dial directly invokes net.Dial with the supplied parameters. +func (direct) Dial(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) +} + +// DialContext instantiates a net.Dialer and invokes its DialContext receiver with the supplied parameters. +func (direct) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, addr) +} diff --git a/vendor/golang.org/x/net/proxy/per_host.go b/vendor/golang.org/x/net/proxy/per_host.go new file mode 100644 index 000000000..32bdf435e --- /dev/null +++ b/vendor/golang.org/x/net/proxy/per_host.go @@ -0,0 +1,153 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "context" + "net" + "net/netip" + "strings" +) + +// A PerHost directs connections to a default Dialer unless the host name +// requested matches one of a number of exceptions. +type PerHost struct { + def, bypass Dialer + + bypassNetworks []*net.IPNet + bypassIPs []net.IP + bypassZones []string + bypassHosts []string +} + +// NewPerHost returns a PerHost Dialer that directs connections to either +// defaultDialer or bypass, depending on whether the connection matches one of +// the configured rules. +func NewPerHost(defaultDialer, bypass Dialer) *PerHost { + return &PerHost{ + def: defaultDialer, + bypass: bypass, + } +} + +// Dial connects to the address addr on the given network through either +// defaultDialer or bypass. +func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + return p.dialerForRequest(host).Dial(network, addr) +} + +// DialContext connects to the address addr on the given network through either +// defaultDialer or bypass. +func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + d := p.dialerForRequest(host) + if x, ok := d.(ContextDialer); ok { + return x.DialContext(ctx, network, addr) + } + return dialContext(ctx, d, network, addr) +} + +func (p *PerHost) dialerForRequest(host string) Dialer { + if nip, err := netip.ParseAddr(host); err == nil { + ip := net.IP(nip.AsSlice()) + for _, net := range p.bypassNetworks { + if net.Contains(ip) { + return p.bypass + } + } + for _, bypassIP := range p.bypassIPs { + if bypassIP.Equal(ip) { + return p.bypass + } + } + return p.def + } + + for _, zone := range p.bypassZones { + if strings.HasSuffix(host, zone) { + return p.bypass + } + if host == zone[1:] { + // For a zone ".example.com", we match "example.com" + // too. + return p.bypass + } + } + for _, bypassHost := range p.bypassHosts { + if bypassHost == host { + return p.bypass + } + } + return p.def +} + +// AddFromString parses a string that contains comma-separated values +// specifying hosts that should use the bypass proxy. Each value is either an +// IP address, a CIDR range, a zone (*.example.com) or a host name +// (localhost). A best effort is made to parse the string and errors are +// ignored. +func (p *PerHost) AddFromString(s string) { + hosts := strings.Split(s, ",") + for _, host := range hosts { + host = strings.TrimSpace(host) + if len(host) == 0 { + continue + } + if strings.Contains(host, "/") { + // We assume that it's a CIDR address like 127.0.0.0/8 + if _, net, err := net.ParseCIDR(host); err == nil { + p.AddNetwork(net) + } + continue + } + if nip, err := netip.ParseAddr(host); err == nil { + p.AddIP(net.IP(nip.AsSlice())) + continue + } + if strings.HasPrefix(host, "*.") { + p.AddZone(host[1:]) + continue + } + p.AddHost(host) + } +} + +// AddIP specifies an IP address that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match an IP. +func (p *PerHost) AddIP(ip net.IP) { + p.bypassIPs = append(p.bypassIPs, ip) +} + +// AddNetwork specifies an IP range that will use the bypass proxy. Note that +// this will only take effect if a literal IP address is dialed. A connection +// to a named host will never match. +func (p *PerHost) AddNetwork(net *net.IPNet) { + p.bypassNetworks = append(p.bypassNetworks, net) +} + +// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of +// "example.com" matches "example.com" and all of its subdomains. +func (p *PerHost) AddZone(zone string) { + zone = strings.TrimSuffix(zone, ".") + if !strings.HasPrefix(zone, ".") { + zone = "." + zone + } + p.bypassZones = append(p.bypassZones, zone) +} + +// AddHost specifies a host name that will use the bypass proxy. +func (p *PerHost) AddHost(host string) { + host = strings.TrimSuffix(host, ".") + p.bypassHosts = append(p.bypassHosts, host) +} diff --git a/vendor/golang.org/x/net/proxy/proxy.go b/vendor/golang.org/x/net/proxy/proxy.go new file mode 100644 index 000000000..9ff4b9a77 --- /dev/null +++ b/vendor/golang.org/x/net/proxy/proxy.go @@ -0,0 +1,149 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package proxy provides support for a variety of protocols to proxy network +// data. +package proxy // import "golang.org/x/net/proxy" + +import ( + "errors" + "net" + "net/url" + "os" + "sync" +) + +// A Dialer is a means to establish a connection. +// Custom dialers should also implement ContextDialer. +type Dialer interface { + // Dial connects to the given address via the proxy. + Dial(network, addr string) (c net.Conn, err error) +} + +// Auth contains authentication parameters that specific Dialers may require. +type Auth struct { + User, Password string +} + +// FromEnvironment returns the dialer specified by the proxy-related +// variables in the environment and makes underlying connections +// directly. +func FromEnvironment() Dialer { + return FromEnvironmentUsing(Direct) +} + +// FromEnvironmentUsing returns the dialer specify by the proxy-related +// variables in the environment and makes underlying connections +// using the provided forwarding Dialer (for instance, a *net.Dialer +// with desired configuration). +func FromEnvironmentUsing(forward Dialer) Dialer { + allProxy := allProxyEnv.Get() + if len(allProxy) == 0 { + return forward + } + + proxyURL, err := url.Parse(allProxy) + if err != nil { + return forward + } + proxy, err := FromURL(proxyURL, forward) + if err != nil { + return forward + } + + noProxy := noProxyEnv.Get() + if len(noProxy) == 0 { + return proxy + } + + perHost := NewPerHost(proxy, forward) + perHost.AddFromString(noProxy) + return perHost +} + +// proxySchemes is a map from URL schemes to a function that creates a Dialer +// from a URL with such a scheme. +var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error) + +// RegisterDialerType takes a URL scheme and a function to generate Dialers from +// a URL with that scheme and a forwarding Dialer. Registered schemes are used +// by FromURL. +func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) { + if proxySchemes == nil { + proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error)) + } + proxySchemes[scheme] = f +} + +// FromURL returns a Dialer given a URL specification and an underlying +// Dialer for it to make network requests. +func FromURL(u *url.URL, forward Dialer) (Dialer, error) { + var auth *Auth + if u.User != nil { + auth = new(Auth) + auth.User = u.User.Username() + if p, ok := u.User.Password(); ok { + auth.Password = p + } + } + + switch u.Scheme { + case "socks5", "socks5h": + addr := u.Hostname() + port := u.Port() + if port == "" { + port = "1080" + } + return SOCKS5("tcp", net.JoinHostPort(addr, port), auth, forward) + } + + // If the scheme doesn't match any of the built-in schemes, see if it + // was registered by another package. + if proxySchemes != nil { + if f, ok := proxySchemes[u.Scheme]; ok { + return f(u, forward) + } + } + + return nil, errors.New("proxy: unknown scheme: " + u.Scheme) +} + +var ( + allProxyEnv = &envOnce{ + names: []string{"ALL_PROXY", "all_proxy"}, + } + noProxyEnv = &envOnce{ + names: []string{"NO_PROXY", "no_proxy"}, + } +) + +// envOnce looks up an environment variable (optionally by multiple +// names) once. It mitigates expensive lookups on some platforms +// (e.g. Windows). +// (Borrowed from net/http/transport.go) +type envOnce struct { + names []string + once sync.Once + val string +} + +func (e *envOnce) Get() string { + e.once.Do(e.init) + return e.val +} + +func (e *envOnce) init() { + for _, n := range e.names { + e.val = os.Getenv(n) + if e.val != "" { + return + } + } +} + +// reset is used by tests +func (e *envOnce) reset() { + e.once = sync.Once{} + e.val = "" +} diff --git a/vendor/golang.org/x/net/proxy/socks5.go b/vendor/golang.org/x/net/proxy/socks5.go new file mode 100644 index 000000000..c91651f96 --- /dev/null +++ b/vendor/golang.org/x/net/proxy/socks5.go @@ -0,0 +1,42 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package proxy + +import ( + "context" + "net" + + "golang.org/x/net/internal/socks" +) + +// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given +// address with an optional username and password. +// See RFC 1928 and RFC 1929. +func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) { + d := socks.NewDialer(network, address) + if forward != nil { + if f, ok := forward.(ContextDialer); ok { + d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) { + return f.DialContext(ctx, network, address) + } + } else { + d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) { + return dialContext(ctx, forward, network, address) + } + } + } + if auth != nil { + up := socks.UsernamePassword{ + Username: auth.User, + Password: auth.Password, + } + d.AuthMethods = []socks.AuthMethod{ + socks.AuthMethodNotRequired, + socks.AuthMethodUsernamePassword, + } + d.Authenticate = up.Authenticate + } + return d, nil +} diff --git a/vendor/golang.org/x/net/websocket/client.go b/vendor/golang.org/x/net/websocket/client.go new file mode 100644 index 000000000..1e64157f3 --- /dev/null +++ b/vendor/golang.org/x/net/websocket/client.go @@ -0,0 +1,139 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "context" + "io" + "net" + "net/http" + "net/url" + "time" +) + +// DialError is an error that occurs while dialling a websocket server. +type DialError struct { + *Config + Err error +} + +func (e *DialError) Error() string { + return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error() +} + +// NewConfig creates a new WebSocket config for client connection. +func NewConfig(server, origin string) (config *Config, err error) { + config = new(Config) + config.Version = ProtocolVersionHybi13 + config.Location, err = url.ParseRequestURI(server) + if err != nil { + return + } + config.Origin, err = url.ParseRequestURI(origin) + if err != nil { + return + } + config.Header = http.Header(make(map[string][]string)) + return +} + +// NewClient creates a new WebSocket client connection over rwc. +func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) { + br := bufio.NewReader(rwc) + bw := bufio.NewWriter(rwc) + err = hybiClientHandshake(config, br, bw) + if err != nil { + return + } + buf := bufio.NewReadWriter(br, bw) + ws = newHybiClientConn(config, buf, rwc) + return +} + +// Dial opens a new client connection to a WebSocket. +func Dial(url_, protocol, origin string) (ws *Conn, err error) { + config, err := NewConfig(url_, origin) + if err != nil { + return nil, err + } + if protocol != "" { + config.Protocol = []string{protocol} + } + return DialConfig(config) +} + +var portMap = map[string]string{ + "ws": "80", + "wss": "443", +} + +func parseAuthority(location *url.URL) string { + if _, ok := portMap[location.Scheme]; ok { + if _, _, err := net.SplitHostPort(location.Host); err != nil { + return net.JoinHostPort(location.Host, portMap[location.Scheme]) + } + } + return location.Host +} + +// DialConfig opens a new client connection to a WebSocket with a config. +func DialConfig(config *Config) (ws *Conn, err error) { + return config.DialContext(context.Background()) +} + +// DialContext opens a new client connection to a WebSocket, with context support for timeouts/cancellation. +func (config *Config) DialContext(ctx context.Context) (*Conn, error) { + if config.Location == nil { + return nil, &DialError{config, ErrBadWebSocketLocation} + } + if config.Origin == nil { + return nil, &DialError{config, ErrBadWebSocketOrigin} + } + + dialer := config.Dialer + if dialer == nil { + dialer = &net.Dialer{} + } + + client, err := dialWithDialer(ctx, dialer, config) + if err != nil { + return nil, &DialError{config, err} + } + + // Cleanup the connection if we fail to create the websocket successfully + success := false + defer func() { + if !success { + _ = client.Close() + } + }() + + var ws *Conn + var wsErr error + doneConnecting := make(chan struct{}) + go func() { + defer close(doneConnecting) + ws, err = NewClient(config, client) + if err != nil { + wsErr = &DialError{config, err} + } + }() + + // The websocket.NewClient() function can block indefinitely, make sure that we + // respect the deadlines specified by the context. + select { + case <-ctx.Done(): + // Force the pending operations to fail, terminating the pending connection attempt + _ = client.SetDeadline(time.Now()) + <-doneConnecting // Wait for the goroutine that tries to establish the connection to finish + return nil, &DialError{config, ctx.Err()} + case <-doneConnecting: + if wsErr == nil { + success = true // Disarm the deferred connection cleanup + } + return ws, wsErr + } +} diff --git a/vendor/golang.org/x/net/websocket/dial.go b/vendor/golang.org/x/net/websocket/dial.go new file mode 100644 index 000000000..8a2d83c47 --- /dev/null +++ b/vendor/golang.org/x/net/websocket/dial.go @@ -0,0 +1,29 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "context" + "crypto/tls" + "net" +) + +func dialWithDialer(ctx context.Context, dialer *net.Dialer, config *Config) (conn net.Conn, err error) { + switch config.Location.Scheme { + case "ws": + conn, err = dialer.DialContext(ctx, "tcp", parseAuthority(config.Location)) + + case "wss": + tlsDialer := &tls.Dialer{ + NetDialer: dialer, + Config: config.TlsConfig, + } + + conn, err = tlsDialer.DialContext(ctx, "tcp", parseAuthority(config.Location)) + default: + err = ErrBadScheme + } + return +} diff --git a/vendor/golang.org/x/net/websocket/hybi.go b/vendor/golang.org/x/net/websocket/hybi.go new file mode 100644 index 000000000..c7e76cd91 --- /dev/null +++ b/vendor/golang.org/x/net/websocket/hybi.go @@ -0,0 +1,583 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +// This file implements a protocol of hybi draft. +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17 + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "net/http" + "net/url" + "strings" +) + +const ( + websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + closeStatusNormal = 1000 + closeStatusGoingAway = 1001 + closeStatusProtocolError = 1002 + closeStatusUnsupportedData = 1003 + closeStatusFrameTooLarge = 1004 + closeStatusNoStatusRcvd = 1005 + closeStatusAbnormalClosure = 1006 + closeStatusBadMessageData = 1007 + closeStatusPolicyViolation = 1008 + closeStatusTooBigData = 1009 + closeStatusExtensionMismatch = 1010 + + maxControlFramePayloadLength = 125 +) + +var ( + ErrBadMaskingKey = &ProtocolError{"bad masking key"} + ErrBadPongMessage = &ProtocolError{"bad pong message"} + ErrBadClosingStatus = &ProtocolError{"bad closing status"} + ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"} + ErrNotImplemented = &ProtocolError{"not implemented"} + + handshakeHeader = map[string]bool{ + "Host": true, + "Upgrade": true, + "Connection": true, + "Sec-Websocket-Key": true, + "Sec-Websocket-Origin": true, + "Sec-Websocket-Version": true, + "Sec-Websocket-Protocol": true, + "Sec-Websocket-Accept": true, + } +) + +// A hybiFrameHeader is a frame header as defined in hybi draft. +type hybiFrameHeader struct { + Fin bool + Rsv [3]bool + OpCode byte + Length int64 + MaskingKey []byte + + data *bytes.Buffer +} + +// A hybiFrameReader is a reader for hybi frame. +type hybiFrameReader struct { + reader io.Reader + + header hybiFrameHeader + pos int64 + length int +} + +func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) { + n, err = frame.reader.Read(msg) + if frame.header.MaskingKey != nil { + for i := 0; i < n; i++ { + msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4] + frame.pos++ + } + } + return n, err +} + +func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode } + +func (frame *hybiFrameReader) HeaderReader() io.Reader { + if frame.header.data == nil { + return nil + } + if frame.header.data.Len() == 0 { + return nil + } + return frame.header.data +} + +func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil } + +func (frame *hybiFrameReader) Len() (n int) { return frame.length } + +// A hybiFrameReaderFactory creates new frame reader based on its frame type. +type hybiFrameReaderFactory struct { + *bufio.Reader +} + +// NewFrameReader reads a frame header from the connection, and creates new reader for the frame. +// See Section 5.2 Base Framing protocol for detail. +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2 +func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) { + hybiFrame := new(hybiFrameReader) + frame = hybiFrame + var header []byte + var b byte + // First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits) + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0 + for i := 0; i < 3; i++ { + j := uint(6 - i) + hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0 + } + hybiFrame.header.OpCode = header[0] & 0x0f + + // Second byte. Mask/Payload len(7bits) + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + mask := (b & 0x80) != 0 + b &= 0x7f + lengthFields := 0 + switch { + case b <= 125: // Payload length 7bits. + hybiFrame.header.Length = int64(b) + case b == 126: // Payload length 7+16bits + lengthFields = 2 + case b == 127: // Payload length 7+64bits + lengthFields = 8 + } + for i := 0; i < lengthFields; i++ { + b, err = buf.ReadByte() + if err != nil { + return + } + if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits + b &= 0x7f + } + header = append(header, b) + hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b) + } + if mask { + // Masking key. 4 bytes. + for i := 0; i < 4; i++ { + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b) + } + } + hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length) + hybiFrame.header.data = bytes.NewBuffer(header) + hybiFrame.length = len(header) + int(hybiFrame.header.Length) + return +} + +// A HybiFrameWriter is a writer for hybi frame. +type hybiFrameWriter struct { + writer *bufio.Writer + + header *hybiFrameHeader +} + +func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) { + var header []byte + var b byte + if frame.header.Fin { + b |= 0x80 + } + for i := 0; i < 3; i++ { + if frame.header.Rsv[i] { + j := uint(6 - i) + b |= 1 << j + } + } + b |= frame.header.OpCode + header = append(header, b) + if frame.header.MaskingKey != nil { + b = 0x80 + } else { + b = 0 + } + lengthFields := 0 + length := len(msg) + switch { + case length <= 125: + b |= byte(length) + case length < 65536: + b |= 126 + lengthFields = 2 + default: + b |= 127 + lengthFields = 8 + } + header = append(header, b) + for i := 0; i < lengthFields; i++ { + j := uint((lengthFields - i - 1) * 8) + b = byte((length >> j) & 0xff) + header = append(header, b) + } + if frame.header.MaskingKey != nil { + if len(frame.header.MaskingKey) != 4 { + return 0, ErrBadMaskingKey + } + header = append(header, frame.header.MaskingKey...) + frame.writer.Write(header) + data := make([]byte, length) + for i := range data { + data[i] = msg[i] ^ frame.header.MaskingKey[i%4] + } + frame.writer.Write(data) + err = frame.writer.Flush() + return length, err + } + frame.writer.Write(header) + frame.writer.Write(msg) + err = frame.writer.Flush() + return length, err +} + +func (frame *hybiFrameWriter) Close() error { return nil } + +type hybiFrameWriterFactory struct { + *bufio.Writer + needMaskingKey bool +} + +func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) { + frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType} + if buf.needMaskingKey { + frameHeader.MaskingKey, err = generateMaskingKey() + if err != nil { + return nil, err + } + } + return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil +} + +type hybiFrameHandler struct { + conn *Conn + payloadType byte +} + +func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) { + if handler.conn.IsServerConn() { + // The client MUST mask all frames sent to the server. + if frame.(*hybiFrameReader).header.MaskingKey == nil { + handler.WriteClose(closeStatusProtocolError) + return nil, io.EOF + } + } else { + // The server MUST NOT mask all frames. + if frame.(*hybiFrameReader).header.MaskingKey != nil { + handler.WriteClose(closeStatusProtocolError) + return nil, io.EOF + } + } + if header := frame.HeaderReader(); header != nil { + io.Copy(io.Discard, header) + } + switch frame.PayloadType() { + case ContinuationFrame: + frame.(*hybiFrameReader).header.OpCode = handler.payloadType + case TextFrame, BinaryFrame: + handler.payloadType = frame.PayloadType() + case CloseFrame: + return nil, io.EOF + case PingFrame, PongFrame: + b := make([]byte, maxControlFramePayloadLength) + n, err := io.ReadFull(frame, b) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + return nil, err + } + io.Copy(io.Discard, frame) + if frame.PayloadType() == PingFrame { + if _, err := handler.WritePong(b[:n]); err != nil { + return nil, err + } + } + return nil, nil + } + return frame, nil +} + +func (handler *hybiFrameHandler) WriteClose(status int) (err error) { + handler.conn.wio.Lock() + defer handler.conn.wio.Unlock() + w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame) + if err != nil { + return err + } + msg := make([]byte, 2) + binary.BigEndian.PutUint16(msg, uint16(status)) + _, err = w.Write(msg) + w.Close() + return err +} + +func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) { + handler.conn.wio.Lock() + defer handler.conn.wio.Unlock() + w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame) + if err != nil { + return 0, err + } + n, err = w.Write(msg) + w.Close() + return n, err +} + +// newHybiConn creates a new WebSocket connection speaking hybi draft protocol. +func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + if buf == nil { + br := bufio.NewReader(rwc) + bw := bufio.NewWriter(rwc) + buf = bufio.NewReadWriter(br, bw) + } + ws := &Conn{config: config, request: request, buf: buf, rwc: rwc, + frameReaderFactory: hybiFrameReaderFactory{buf.Reader}, + frameWriterFactory: hybiFrameWriterFactory{ + buf.Writer, request == nil}, + PayloadType: TextFrame, + defaultCloseStatus: closeStatusNormal} + ws.frameHandler = &hybiFrameHandler{conn: ws} + return ws +} + +// generateMaskingKey generates a masking key for a frame. +func generateMaskingKey() (maskingKey []byte, err error) { + maskingKey = make([]byte, 4) + if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil { + return + } + return +} + +// generateNonce generates a nonce consisting of a randomly selected 16-byte +// value that has been base64-encoded. +func generateNonce() (nonce []byte) { + key := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + panic(err) + } + nonce = make([]byte, 24) + base64.StdEncoding.Encode(nonce, key) + return +} + +// removeZone removes IPv6 zone identifier from host. +// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080" +func removeZone(host string) string { + if !strings.HasPrefix(host, "[") { + return host + } + i := strings.LastIndex(host, "]") + if i < 0 { + return host + } + j := strings.LastIndex(host[:i], "%") + if j < 0 { + return host + } + return host[:j] + host[i:] +} + +// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of +// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string. +func getNonceAccept(nonce []byte) (expected []byte, err error) { + h := sha1.New() + if _, err = h.Write(nonce); err != nil { + return + } + if _, err = h.Write([]byte(websocketGUID)); err != nil { + return + } + expected = make([]byte, 28) + base64.StdEncoding.Encode(expected, h.Sum(nil)) + return +} + +// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17 +func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) { + bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n") + + // According to RFC 6874, an HTTP client, proxy, or other + // intermediary must remove any IPv6 zone identifier attached + // to an outgoing URI. + bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n") + bw.WriteString("Upgrade: websocket\r\n") + bw.WriteString("Connection: Upgrade\r\n") + nonce := generateNonce() + if config.handshakeData != nil { + nonce = []byte(config.handshakeData["key"]) + } + bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n") + bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n") + + if config.Version != ProtocolVersionHybi13 { + return ErrBadProtocolVersion + } + + bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n") + if len(config.Protocol) > 0 { + bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n") + } + // TODO(ukai): send Sec-WebSocket-Extensions. + err = config.Header.WriteSubset(bw, handshakeHeader) + if err != nil { + return err + } + + bw.WriteString("\r\n") + if err = bw.Flush(); err != nil { + return err + } + + resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != 101 { + return ErrBadStatus + } + if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" || + strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { + return ErrBadUpgrade + } + expectedAccept, err := getNonceAccept(nonce) + if err != nil { + return err + } + if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) { + return ErrChallengeResponse + } + if resp.Header.Get("Sec-WebSocket-Extensions") != "" { + return ErrUnsupportedExtensions + } + offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol") + if offeredProtocol != "" { + protocolMatched := false + for i := 0; i < len(config.Protocol); i++ { + if config.Protocol[i] == offeredProtocol { + protocolMatched = true + break + } + } + if !protocolMatched { + return ErrBadWebSocketProtocol + } + config.Protocol = []string{offeredProtocol} + } + + return nil +} + +// newHybiClientConn creates a client WebSocket connection after handshake. +func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn { + return newHybiConn(config, buf, rwc, nil) +} + +// A HybiServerHandshaker performs a server handshake using hybi draft protocol. +type hybiServerHandshaker struct { + *Config + accept []byte +} + +func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) { + c.Version = ProtocolVersionHybi13 + if req.Method != "GET" { + return http.StatusMethodNotAllowed, ErrBadRequestMethod + } + // HTTP version can be safely ignored. + + if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || + !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { + return http.StatusBadRequest, ErrNotWebSocket + } + + key := req.Header.Get("Sec-Websocket-Key") + if key == "" { + return http.StatusBadRequest, ErrChallengeResponse + } + version := req.Header.Get("Sec-Websocket-Version") + switch version { + case "13": + c.Version = ProtocolVersionHybi13 + default: + return http.StatusBadRequest, ErrBadWebSocketVersion + } + var scheme string + if req.TLS != nil { + scheme = "wss" + } else { + scheme = "ws" + } + c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI()) + if err != nil { + return http.StatusBadRequest, err + } + protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol")) + if protocol != "" { + protocols := strings.Split(protocol, ",") + for i := 0; i < len(protocols); i++ { + c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i])) + } + } + c.accept, err = getNonceAccept([]byte(key)) + if err != nil { + return http.StatusInternalServerError, err + } + return http.StatusSwitchingProtocols, nil +} + +// Origin parses the Origin header in req. +// If the Origin header is not set, it returns nil and nil. +func Origin(config *Config, req *http.Request) (*url.URL, error) { + var origin string + switch config.Version { + case ProtocolVersionHybi13: + origin = req.Header.Get("Origin") + } + if origin == "" { + return nil, nil + } + return url.ParseRequestURI(origin) +} + +func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) { + if len(c.Protocol) > 0 { + if len(c.Protocol) != 1 { + // You need choose a Protocol in Handshake func in Server. + return ErrBadWebSocketProtocol + } + } + buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n") + buf.WriteString("Upgrade: websocket\r\n") + buf.WriteString("Connection: Upgrade\r\n") + buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n") + if len(c.Protocol) > 0 { + buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n") + } + // TODO(ukai): send Sec-WebSocket-Extensions. + if c.Header != nil { + err := c.Header.WriteSubset(buf, handshakeHeader) + if err != nil { + return err + } + } + buf.WriteString("\r\n") + return buf.Flush() +} + +func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + return newHybiServerConn(c.Config, buf, rwc, request) +} + +// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol. +func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + return newHybiConn(config, buf, rwc, request) +} diff --git a/vendor/golang.org/x/net/websocket/server.go b/vendor/golang.org/x/net/websocket/server.go new file mode 100644 index 000000000..0895dea19 --- /dev/null +++ b/vendor/golang.org/x/net/websocket/server.go @@ -0,0 +1,113 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "fmt" + "io" + "net/http" +) + +func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) { + var hs serverHandshaker = &hybiServerHandshaker{Config: config} + code, err := hs.ReadHandshake(buf.Reader, req) + if err == ErrBadWebSocketVersion { + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion) + buf.WriteString("\r\n") + buf.WriteString(err.Error()) + buf.Flush() + return + } + if err != nil { + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + buf.WriteString("\r\n") + buf.WriteString(err.Error()) + buf.Flush() + return + } + if handshake != nil { + err = handshake(config, req) + if err != nil { + code = http.StatusForbidden + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + buf.WriteString("\r\n") + buf.Flush() + return + } + } + err = hs.AcceptHandshake(buf.Writer) + if err != nil { + code = http.StatusBadRequest + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + buf.WriteString("\r\n") + buf.Flush() + return + } + conn = hs.NewServerConn(buf, rwc, req) + return +} + +// Server represents a server of a WebSocket. +type Server struct { + // Config is a WebSocket configuration for new WebSocket connection. + Config + + // Handshake is an optional function in WebSocket handshake. + // For example, you can check, or don't check Origin header. + // Another example, you can select config.Protocol. + Handshake func(*Config, *http.Request) error + + // Handler handles a WebSocket connection. + Handler +} + +// ServeHTTP implements the http.Handler interface for a WebSocket +func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s.serveWebSocket(w, req) +} + +func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) { + rwc, buf, err := w.(http.Hijacker).Hijack() + if err != nil { + panic("Hijack failed: " + err.Error()) + } + // The server should abort the WebSocket connection if it finds + // the client did not send a handshake that matches with protocol + // specification. + defer rwc.Close() + conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake) + if err != nil { + return + } + if conn == nil { + panic("unexpected nil conn") + } + s.Handler(conn) +} + +// Handler is a simple interface to a WebSocket browser client. +// It checks if Origin header is valid URL by default. +// You might want to verify websocket.Conn.Config().Origin in the func. +// If you use Server instead of Handler, you could call websocket.Origin and +// check the origin in your Handshake func. So, if you want to accept +// non-browser clients, which do not send an Origin header, set a +// Server.Handshake that does not check the origin. +type Handler func(*Conn) + +func checkOrigin(config *Config, req *http.Request) (err error) { + config.Origin, err = Origin(config, req) + if err == nil && config.Origin == nil { + return fmt.Errorf("null origin") + } + return err +} + +// ServeHTTP implements the http.Handler interface for a WebSocket +func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + s := Server{Handler: h, Handshake: checkOrigin} + s.serveWebSocket(w, req) +} diff --git a/vendor/golang.org/x/net/websocket/websocket.go b/vendor/golang.org/x/net/websocket/websocket.go new file mode 100644 index 000000000..3448d2039 --- /dev/null +++ b/vendor/golang.org/x/net/websocket/websocket.go @@ -0,0 +1,449 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package websocket implements a client and server for the WebSocket protocol +// as specified in RFC 6455. +// +// This package currently lacks some features found in an alternative +// and more actively maintained WebSocket packages: +// +// - [github.com/gorilla/websocket] +// - [github.com/coder/websocket] +package websocket // import "golang.org/x/net/websocket" + +import ( + "bufio" + "crypto/tls" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "net/url" + "sync" + "time" +) + +const ( + ProtocolVersionHybi13 = 13 + ProtocolVersionHybi = ProtocolVersionHybi13 + SupportedProtocolVersion = "13" + + ContinuationFrame = 0 + TextFrame = 1 + BinaryFrame = 2 + CloseFrame = 8 + PingFrame = 9 + PongFrame = 10 + UnknownFrame = 255 + + DefaultMaxPayloadBytes = 32 << 20 // 32MB +) + +// ProtocolError represents WebSocket protocol errors. +type ProtocolError struct { + ErrorString string +} + +func (err *ProtocolError) Error() string { return err.ErrorString } + +var ( + ErrBadProtocolVersion = &ProtocolError{"bad protocol version"} + ErrBadScheme = &ProtocolError{"bad scheme"} + ErrBadStatus = &ProtocolError{"bad status"} + ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"} + ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"} + ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"} + ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"} + ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"} + ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"} + ErrBadFrame = &ProtocolError{"bad frame"} + ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"} + ErrNotWebSocket = &ProtocolError{"not websocket protocol"} + ErrBadRequestMethod = &ProtocolError{"bad method"} + ErrNotSupported = &ProtocolError{"not supported"} +) + +// ErrFrameTooLarge is returned by Codec's Receive method if payload size +// exceeds limit set by Conn.MaxPayloadBytes +var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit") + +// Addr is an implementation of net.Addr for WebSocket. +type Addr struct { + *url.URL +} + +// Network returns the network type for a WebSocket, "websocket". +func (addr *Addr) Network() string { return "websocket" } + +// Config is a WebSocket configuration +type Config struct { + // A WebSocket server address. + Location *url.URL + + // A Websocket client origin. + Origin *url.URL + + // WebSocket subprotocols. + Protocol []string + + // WebSocket protocol version. + Version int + + // TLS config for secure WebSocket (wss). + TlsConfig *tls.Config + + // Additional header fields to be sent in WebSocket opening handshake. + Header http.Header + + // Dialer used when opening websocket connections. + Dialer *net.Dialer + + handshakeData map[string]string +} + +// serverHandshaker is an interface to handle WebSocket server side handshake. +type serverHandshaker interface { + // ReadHandshake reads handshake request message from client. + // Returns http response code and error if any. + ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) + + // AcceptHandshake accepts the client handshake request and sends + // handshake response back to client. + AcceptHandshake(buf *bufio.Writer) (err error) + + // NewServerConn creates a new WebSocket connection. + NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn) +} + +// frameReader is an interface to read a WebSocket frame. +type frameReader interface { + // Reader is to read payload of the frame. + io.Reader + + // PayloadType returns payload type. + PayloadType() byte + + // HeaderReader returns a reader to read header of the frame. + HeaderReader() io.Reader + + // TrailerReader returns a reader to read trailer of the frame. + // If it returns nil, there is no trailer in the frame. + TrailerReader() io.Reader + + // Len returns total length of the frame, including header and trailer. + Len() int +} + +// frameReaderFactory is an interface to creates new frame reader. +type frameReaderFactory interface { + NewFrameReader() (r frameReader, err error) +} + +// frameWriter is an interface to write a WebSocket frame. +type frameWriter interface { + // Writer is to write payload of the frame. + io.WriteCloser +} + +// frameWriterFactory is an interface to create new frame writer. +type frameWriterFactory interface { + NewFrameWriter(payloadType byte) (w frameWriter, err error) +} + +type frameHandler interface { + HandleFrame(frame frameReader) (r frameReader, err error) + WriteClose(status int) (err error) +} + +// Conn represents a WebSocket connection. +// +// Multiple goroutines may invoke methods on a Conn simultaneously. +type Conn struct { + config *Config + request *http.Request + + buf *bufio.ReadWriter + rwc io.ReadWriteCloser + + rio sync.Mutex + frameReaderFactory + frameReader + + wio sync.Mutex + frameWriterFactory + + frameHandler + PayloadType byte + defaultCloseStatus int + + // MaxPayloadBytes limits the size of frame payload received over Conn + // by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used. + MaxPayloadBytes int +} + +// Read implements the io.Reader interface: +// it reads data of a frame from the WebSocket connection. +// if msg is not large enough for the frame data, it fills the msg and next Read +// will read the rest of the frame data. +// it reads Text frame or Binary frame. +func (ws *Conn) Read(msg []byte) (n int, err error) { + ws.rio.Lock() + defer ws.rio.Unlock() +again: + if ws.frameReader == nil { + frame, err := ws.frameReaderFactory.NewFrameReader() + if err != nil { + return 0, err + } + ws.frameReader, err = ws.frameHandler.HandleFrame(frame) + if err != nil { + return 0, err + } + if ws.frameReader == nil { + goto again + } + } + n, err = ws.frameReader.Read(msg) + if err == io.EOF { + if trailer := ws.frameReader.TrailerReader(); trailer != nil { + io.Copy(io.Discard, trailer) + } + ws.frameReader = nil + goto again + } + return n, err +} + +// Write implements the io.Writer interface: +// it writes data as a frame to the WebSocket connection. +func (ws *Conn) Write(msg []byte) (n int, err error) { + ws.wio.Lock() + defer ws.wio.Unlock() + w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType) + if err != nil { + return 0, err + } + n, err = w.Write(msg) + w.Close() + return n, err +} + +// Close implements the io.Closer interface. +func (ws *Conn) Close() error { + err := ws.frameHandler.WriteClose(ws.defaultCloseStatus) + err1 := ws.rwc.Close() + if err != nil { + return err + } + return err1 +} + +// IsClientConn reports whether ws is a client-side connection. +func (ws *Conn) IsClientConn() bool { return ws.request == nil } + +// IsServerConn reports whether ws is a server-side connection. +func (ws *Conn) IsServerConn() bool { return ws.request != nil } + +// LocalAddr returns the WebSocket Origin for the connection for client, or +// the WebSocket location for server. +func (ws *Conn) LocalAddr() net.Addr { + if ws.IsClientConn() { + return &Addr{ws.config.Origin} + } + return &Addr{ws.config.Location} +} + +// RemoteAddr returns the WebSocket location for the connection for client, or +// the Websocket Origin for server. +func (ws *Conn) RemoteAddr() net.Addr { + if ws.IsClientConn() { + return &Addr{ws.config.Location} + } + return &Addr{ws.config.Origin} +} + +var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn") + +// SetDeadline sets the connection's network read & write deadlines. +func (ws *Conn) SetDeadline(t time.Time) error { + if conn, ok := ws.rwc.(net.Conn); ok { + return conn.SetDeadline(t) + } + return errSetDeadline +} + +// SetReadDeadline sets the connection's network read deadline. +func (ws *Conn) SetReadDeadline(t time.Time) error { + if conn, ok := ws.rwc.(net.Conn); ok { + return conn.SetReadDeadline(t) + } + return errSetDeadline +} + +// SetWriteDeadline sets the connection's network write deadline. +func (ws *Conn) SetWriteDeadline(t time.Time) error { + if conn, ok := ws.rwc.(net.Conn); ok { + return conn.SetWriteDeadline(t) + } + return errSetDeadline +} + +// Config returns the WebSocket config. +func (ws *Conn) Config() *Config { return ws.config } + +// Request returns the http request upgraded to the WebSocket. +// It is nil for client side. +func (ws *Conn) Request() *http.Request { return ws.request } + +// Codec represents a symmetric pair of functions that implement a codec. +type Codec struct { + Marshal func(v interface{}) (data []byte, payloadType byte, err error) + Unmarshal func(data []byte, payloadType byte, v interface{}) (err error) +} + +// Send sends v marshaled by cd.Marshal as single frame to ws. +func (cd Codec) Send(ws *Conn, v interface{}) (err error) { + data, payloadType, err := cd.Marshal(v) + if err != nil { + return err + } + ws.wio.Lock() + defer ws.wio.Unlock() + w, err := ws.frameWriterFactory.NewFrameWriter(payloadType) + if err != nil { + return err + } + _, err = w.Write(data) + w.Close() + return err +} + +// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores +// in v. The whole frame payload is read to an in-memory buffer; max size of +// payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds +// limit, ErrFrameTooLarge is returned; in this case frame is not read off wire +// completely. The next call to Receive would read and discard leftover data of +// previous oversized frame before processing next frame. +func (cd Codec) Receive(ws *Conn, v interface{}) (err error) { + ws.rio.Lock() + defer ws.rio.Unlock() + if ws.frameReader != nil { + _, err = io.Copy(io.Discard, ws.frameReader) + if err != nil { + return err + } + ws.frameReader = nil + } +again: + frame, err := ws.frameReaderFactory.NewFrameReader() + if err != nil { + return err + } + frame, err = ws.frameHandler.HandleFrame(frame) + if err != nil { + return err + } + if frame == nil { + goto again + } + maxPayloadBytes := ws.MaxPayloadBytes + if maxPayloadBytes == 0 { + maxPayloadBytes = DefaultMaxPayloadBytes + } + if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) { + // payload size exceeds limit, no need to call Unmarshal + // + // set frameReader to current oversized frame so that + // the next call to this function can drain leftover + // data before processing the next frame + ws.frameReader = frame + return ErrFrameTooLarge + } + payloadType := frame.PayloadType() + data, err := io.ReadAll(frame) + if err != nil { + return err + } + return cd.Unmarshal(data, payloadType, v) +} + +func marshal(v interface{}) (msg []byte, payloadType byte, err error) { + switch data := v.(type) { + case string: + return []byte(data), TextFrame, nil + case []byte: + return data, BinaryFrame, nil + } + return nil, UnknownFrame, ErrNotSupported +} + +func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) { + switch data := v.(type) { + case *string: + *data = string(msg) + return nil + case *[]byte: + *data = msg + return nil + } + return ErrNotSupported +} + +/* +Message is a codec to send/receive text/binary data in a frame on WebSocket connection. +To send/receive text frame, use string type. +To send/receive binary frame, use []byte type. + +Trivial usage: + + import "websocket" + + // receive text frame + var message string + websocket.Message.Receive(ws, &message) + + // send text frame + message = "hello" + websocket.Message.Send(ws, message) + + // receive binary frame + var data []byte + websocket.Message.Receive(ws, &data) + + // send binary frame + data = []byte{0, 1, 2} + websocket.Message.Send(ws, data) +*/ +var Message = Codec{marshal, unmarshal} + +func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) { + msg, err = json.Marshal(v) + return msg, TextFrame, err +} + +func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) { + return json.Unmarshal(msg, v) +} + +/* +JSON is a codec to send/receive JSON data in a frame from a WebSocket connection. + +Trivial usage: + + import "websocket" + + type T struct { + Msg string + Count int + } + + // receive JSON type T + var data T + websocket.JSON.Receive(ws, &data) + + // send JSON type T + websocket.JSON.Send(ws, data) +*/ +var JSON = Codec{jsonMarshal, jsonUnmarshal} diff --git a/vendor/k8s.io/apimachinery/pkg/util/httpstream/doc.go b/vendor/k8s.io/apimachinery/pkg/util/httpstream/doc.go new file mode 100644 index 000000000..5fdc7955f --- /dev/null +++ b/vendor/k8s.io/apimachinery/pkg/util/httpstream/doc.go @@ -0,0 +1,20 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package httpstream contains compatibility wrappers for streaming transport APIs. +// +// Deprecated: use k8s.io/streaming/pkg/httpstream directly. +package httpstream diff --git a/vendor/k8s.io/apimachinery/pkg/util/httpstream/httpstream.go b/vendor/k8s.io/apimachinery/pkg/util/httpstream/httpstream.go new file mode 100644 index 000000000..a7c8d897d --- /dev/null +++ b/vendor/k8s.io/apimachinery/pkg/util/httpstream/httpstream.go @@ -0,0 +1,201 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package httpstream + +import ( + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +const ( + HeaderConnection = "Connection" + HeaderUpgrade = "Upgrade" + HeaderProtocolVersion = "X-Stream-Protocol-Version" + HeaderAcceptedProtocolVersions = "X-Accepted-Stream-Protocol-Versions" +) + +// NewStreamHandler defines a function that is called when a new Stream is +// received. If no error is returned, the Stream is accepted; otherwise, +// the stream is rejected. After the reply frame has been sent, replySent is closed. +type NewStreamHandler func(stream Stream, replySent <-chan struct{}) error + +// NoOpNewStreamHandler is a stream handler that accepts a new stream and +// performs no other logic. +func NoOpNewStreamHandler(stream Stream, replySent <-chan struct{}) error { return nil } + +// Dialer knows how to open a streaming connection to a server. +type Dialer interface { + + // Dial opens a streaming connection to a server using one of the protocols + // specified (in order of most preferred to least preferred). + Dial(protocols ...string) (Connection, string, error) +} + +// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade +// HTTP requests to support multiplexed bidirectional streams. After RoundTrip() +// is invoked, if the upgrade is successful, clients may retrieve the upgraded +// connection by calling UpgradeRoundTripper.Connection(). +type UpgradeRoundTripper interface { + http.RoundTripper + // NewConnection validates the response and creates a new Connection. + NewConnection(resp *http.Response) (Connection, error) +} + +// ResponseUpgrader knows how to upgrade HTTP requests and responses to +// add streaming support to them. +type ResponseUpgrader interface { + // UpgradeResponse upgrades an HTTP response to one that supports multiplexed + // streams. newStreamHandler will be called asynchronously whenever the + // other end of the upgraded connection creates a new stream. + UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection +} + +// Connection represents an upgraded HTTP connection. +type Connection interface { + // CreateStream creates a new Stream with the supplied headers. + CreateStream(headers http.Header) (Stream, error) + // Close resets all streams and closes the connection. + Close() error + // CloseChan returns a channel that is closed when the underlying connection is closed. + CloseChan() <-chan bool + // SetIdleTimeout sets the amount of time the connection may remain idle before + // it is automatically closed. + SetIdleTimeout(timeout time.Duration) + // RemoveStreams can be used to remove a set of streams from the Connection. + RemoveStreams(streams ...Stream) +} + +// Stream represents a bidirectional communications channel that is part of an +// upgraded connection. +type Stream interface { + io.ReadWriteCloser + // Reset closes both directions of the stream, indicating that neither client + // or server can use it any more. + Reset() error + // Headers returns the headers used to create the stream. + Headers() http.Header + // Identifier returns the stream's ID. + Identifier() uint32 +} + +// UpgradeFailureError encapsulates the cause for why the streaming +// upgrade request failed. Implements error interface. +type UpgradeFailureError struct { + Cause error +} + +func (u *UpgradeFailureError) Error() string { + return fmt.Sprintf("unable to upgrade streaming request: %s", u.Cause) +} + +// IsUpgradeFailure returns true if the passed error is (or wrapped error contains) +// the UpgradeFailureError. +func IsUpgradeFailure(err error) bool { + if err == nil { + return false + } + var upgradeErr *UpgradeFailureError + return errors.As(err, &upgradeErr) +} + +// isHTTPSProxyError returns true if error is Gorilla/Websockets HTTPS Proxy dial error; +// false otherwise (see https://github.com/kubernetes/kubernetes/issues/126134). +func IsHTTPSProxyError(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "proxy: unknown scheme: https") +} + +// IsUpgradeRequest returns true if the given request is a connection upgrade request +func IsUpgradeRequest(req *http.Request) bool { + for _, h := range req.Header[http.CanonicalHeaderKey(HeaderConnection)] { + if strings.Contains(strings.ToLower(h), strings.ToLower(HeaderUpgrade)) { + return true + } + } + return false +} + +func negotiateProtocol(clientProtocols, serverProtocols []string) string { + for i := range clientProtocols { + for j := range serverProtocols { + if clientProtocols[i] == serverProtocols[j] { + return clientProtocols[i] + } + } + } + return "" +} + +func commaSeparatedHeaderValues(header []string) []string { + var parsedClientProtocols []string + for i := range header { + for _, clientProtocol := range strings.Split(header[i], ",") { + if proto := strings.Trim(clientProtocol, " "); len(proto) > 0 { + parsedClientProtocols = append(parsedClientProtocols, proto) + } + } + } + return parsedClientProtocols +} + +// Handshake performs a subprotocol negotiation. If the client did request a +// subprotocol, Handshake will select the first common value found in +// serverProtocols, otherwise it will return an error and write an HTTP BadRequest to the response. +// If a match is found, Handshake adds a response header indicating the chosen subprotocol. +// If no match is found, HTTP forbidden is returned, along with a response header containing +// the list of protocols the server can accept. +func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string) (string, error) { + if len(serverProtocols) == 0 { + panic(fmt.Errorf("unable to upgrade: serverProtocols is required")) + } + values, ok := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)] + if !ok { + err := fmt.Errorf("unable to upgrade: header %s does not exist in request with %d headers", HeaderProtocolVersion, len(req.Header)) + http.Error(w, err.Error(), http.StatusBadRequest) + return "", err + } + if len(values) == 0 { + err := fmt.Errorf("unable to upgrade: header %s is empty", HeaderProtocolVersion) + http.Error(w, err.Error(), http.StatusBadRequest) + return "", err + } + clientProtocols := commaSeparatedHeaderValues(values) + if len(clientProtocols) == 0 { + err := fmt.Errorf("unable to upgrade: header %s contains %s, but no valid protocols", HeaderProtocolVersion, values) + http.Error(w, err.Error(), http.StatusBadRequest) + return "", err + } + + negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols) + if len(negotiatedProtocol) == 0 { + for i := range serverProtocols { + w.Header().Add(HeaderAcceptedProtocolVersions, serverProtocols[i]) + } + err := fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols) + http.Error(w, err.Error(), http.StatusForbidden) + return "", err + } + + w.Header().Add(HeaderProtocolVersion, negotiatedProtocol) + return negotiatedProtocol, nil +} diff --git a/vendor/k8s.io/apimachinery/pkg/util/httpstream/spdy/doc.go b/vendor/k8s.io/apimachinery/pkg/util/httpstream/spdy/doc.go new file mode 100644 index 000000000..d03acb0ee --- /dev/null +++ b/vendor/k8s.io/apimachinery/pkg/util/httpstream/spdy/doc.go @@ -0,0 +1,20 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package spdy contains compatibility wrappers for the SPDY transport stack. +// +// Deprecated: use k8s.io/streaming/pkg/httpstream/spdy directly. +package spdy diff --git a/vendor/k8s.io/apimachinery/pkg/util/httpstream/spdy/spdy.go b/vendor/k8s.io/apimachinery/pkg/util/httpstream/spdy/spdy.go new file mode 100644 index 000000000..37dfe8189 --- /dev/null +++ b/vendor/k8s.io/apimachinery/pkg/util/httpstream/spdy/spdy.go @@ -0,0 +1,236 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spdy + +import ( + "crypto/tls" + "net" + "net/http" + "net/url" + "time" + + apihttpstream "k8s.io/apimachinery/pkg/util/httpstream" + streamhttp "k8s.io/streaming/pkg/httpstream" + streamspdy "k8s.io/streaming/pkg/httpstream/spdy" +) + +const HeaderSpdy31 = streamspdy.HeaderSpdy31 + +// SpdyRoundTripper is a compatibility wrapper around the streaming module's +// SPDY round tripper. +type SpdyRoundTripper struct { + delegate *streamspdy.SpdyRoundTripper +} + +func NewRoundTripper(tlsConfig *tls.Config) (*SpdyRoundTripper, error) { + delegate, err := streamspdy.NewRoundTripper(tlsConfig) + if err != nil { + return nil, err + } + return &SpdyRoundTripper{delegate: delegate}, nil +} + +func NewRoundTripperWithProxy(tlsConfig *tls.Config, proxier func(*http.Request) (*url.URL, error)) (*SpdyRoundTripper, error) { + delegate, err := streamspdy.NewRoundTripperWithProxy(tlsConfig, proxier) + if err != nil { + return nil, err + } + return &SpdyRoundTripper{delegate: delegate}, nil +} + +// RoundTripperConfig is a set of options for an SpdyRoundTripper. +type RoundTripperConfig struct { + // TLS configuration used by the round tripper if UpgradeTransport not present. + TLS *tls.Config + // Proxier is a proxy function invoked on each request. Optional. + Proxier func(*http.Request) (*url.URL, error) + // PingPeriod is a period for sending SPDY Pings on the connection. + // Optional. + PingPeriod time.Duration + // UpgradeTransport is a subtitute transport used for dialing. If set, + // this field will be used instead of "TLS" and "Proxier" for connection creation. + // Optional. + UpgradeTransport http.RoundTripper +} + +func NewRoundTripperWithConfig(cfg RoundTripperConfig) (*SpdyRoundTripper, error) { + delegate, err := streamspdy.NewRoundTripperWithConfig(streamspdy.RoundTripperConfig{ + TLS: cfg.TLS, + Proxier: cfg.Proxier, + PingPeriod: cfg.PingPeriod, + UpgradeTransport: cfg.UpgradeTransport, + }) + if err != nil { + return nil, err + } + return &SpdyRoundTripper{delegate: delegate}, nil +} + +// TLSClientConfig implements pkg/util/net.TLSClientConfigHolder for proper TLS checking during +// proxying with a spdy roundtripper. +func (s *SpdyRoundTripper) TLSClientConfig() *tls.Config { + return s.delegate.TLSClientConfig() +} + +// Dial opens a network connection for an upgrade request. +func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) { + return s.delegate.Dial(req) +} + +// RoundTrip executes a request and upgrades the connection. +func (s *SpdyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return s.delegate.RoundTrip(req) +} + +// NewConnection validates a server upgrade response and prepares the transport. +func (s *SpdyRoundTripper) NewConnection(resp *http.Response) (apihttpstream.Connection, error) { + conn, err := s.delegate.NewConnection(resp) + if err != nil { + return nil, err + } + return wrapConnection(conn), nil +} + +type responseUpgraderAdapter struct { + delegate streamhttp.ResponseUpgrader +} + +func (r *responseUpgraderAdapter) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler apihttpstream.NewStreamHandler) apihttpstream.Connection { + conn := r.delegate.UpgradeResponse(w, req, wrapNewStreamHandler(newStreamHandler)) + return wrapConnection(conn) +} + +func NewResponseUpgrader() apihttpstream.ResponseUpgrader { + return &responseUpgraderAdapter{delegate: streamspdy.NewResponseUpgrader()} +} + +func NewResponseUpgraderWithPings(pingPeriod time.Duration) apihttpstream.ResponseUpgrader { + return &responseUpgraderAdapter{delegate: streamspdy.NewResponseUpgraderWithPings(pingPeriod)} +} + +func NewClientConnection(conn net.Conn) (apihttpstream.Connection, error) { + c, err := streamspdy.NewClientConnection(conn) + if err != nil { + return nil, err + } + return wrapConnection(c), nil +} + +func NewClientConnectionWithPings(conn net.Conn, pingPeriod time.Duration) (apihttpstream.Connection, error) { + c, err := streamspdy.NewClientConnectionWithPings(conn, pingPeriod) + if err != nil { + return nil, err + } + return wrapConnection(c), nil +} + +func NewServerConnection(conn net.Conn, newStreamHandler apihttpstream.NewStreamHandler) (apihttpstream.Connection, error) { + c, err := streamspdy.NewServerConnection(conn, wrapNewStreamHandler(newStreamHandler)) + if err != nil { + return nil, err + } + return wrapConnection(c), nil +} + +func NewServerConnectionWithPings(conn net.Conn, newStreamHandler apihttpstream.NewStreamHandler, pingPeriod time.Duration) (apihttpstream.Connection, error) { + c, err := streamspdy.NewServerConnectionWithPings(conn, wrapNewStreamHandler(newStreamHandler), pingPeriod) + if err != nil { + return nil, err + } + return wrapConnection(c), nil +} + +type streamAdapter struct { + delegate streamhttp.Stream +} + +func (s *streamAdapter) Read(p []byte) (int, error) { + return s.delegate.Read(p) +} + +func (s *streamAdapter) Write(p []byte) (int, error) { + return s.delegate.Write(p) +} + +func (s *streamAdapter) Close() error { + return s.delegate.Close() +} + +func (s *streamAdapter) Reset() error { + return s.delegate.Reset() +} + +func (s *streamAdapter) Headers() http.Header { + return s.delegate.Headers() +} + +func (s *streamAdapter) Identifier() uint32 { + return s.delegate.Identifier() +} + +type connectionAdapter struct { + delegate streamhttp.Connection +} + +func (c *connectionAdapter) CreateStream(headers http.Header) (apihttpstream.Stream, error) { + stream, err := c.delegate.CreateStream(headers) + if err != nil { + return nil, err + } + return &streamAdapter{delegate: stream}, nil +} + +func (c *connectionAdapter) Close() error { + return c.delegate.Close() +} + +func (c *connectionAdapter) CloseChan() <-chan bool { + return c.delegate.CloseChan() +} + +func (c *connectionAdapter) SetIdleTimeout(timeout time.Duration) { + c.delegate.SetIdleTimeout(timeout) +} + +func (c *connectionAdapter) RemoveStreams(streams ...apihttpstream.Stream) { + streamingStreams := make([]streamhttp.Stream, 0, len(streams)) + for _, stream := range streams { + if stream == nil { + continue + } + if s, ok := stream.(streamhttp.Stream); ok { + streamingStreams = append(streamingStreams, s) + } + } + c.delegate.RemoveStreams(streamingStreams...) +} + +func wrapConnection(conn streamhttp.Connection) apihttpstream.Connection { + if conn == nil { + return nil + } + return &connectionAdapter{delegate: conn} +} + +func wrapNewStreamHandler(newStreamHandler apihttpstream.NewStreamHandler) streamhttp.NewStreamHandler { + if newStreamHandler == nil { + return nil + } + return func(stream streamhttp.Stream, replySent <-chan struct{}) error { + return newStreamHandler(&streamAdapter{delegate: stream}, replySent) + } +} diff --git a/vendor/k8s.io/apimachinery/pkg/util/remotecommand/constants.go b/vendor/k8s.io/apimachinery/pkg/util/remotecommand/constants.go new file mode 100644 index 000000000..ba153ee24 --- /dev/null +++ b/vendor/k8s.io/apimachinery/pkg/util/remotecommand/constants.go @@ -0,0 +1,67 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "time" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const ( + DefaultStreamCreationTimeout = 30 * time.Second + + // The SPDY subprotocol "channel.k8s.io" is used for remote command + // attachment/execution. This represents the initial unversioned subprotocol, + // which has the known bugs https://issues.k8s.io/13394 and + // https://issues.k8s.io/13395. + StreamProtocolV1Name = "channel.k8s.io" + + // The SPDY subprotocol "v2.channel.k8s.io" is used for remote command + // attachment/execution. It is the second version of the subprotocol and + // resolves the issues present in the first version. + StreamProtocolV2Name = "v2.channel.k8s.io" + + // The SPDY subprotocol "v3.channel.k8s.io" is used for remote command + // attachment/execution. It is the third version of the subprotocol and + // adds support for resizing container terminals. + StreamProtocolV3Name = "v3.channel.k8s.io" + + // The SPDY subprotocol "v4.channel.k8s.io" is used for remote command + // attachment/execution. It is the 4th version of the subprotocol and + // adds support for exit codes. + StreamProtocolV4Name = "v4.channel.k8s.io" + + // The subprotocol "v5.channel.k8s.io" is used for remote command + // attachment/execution. It is the 5th version of the subprotocol and + // adds support for a CLOSE signal. + StreamProtocolV5Name = "v5.channel.k8s.io" + + NonZeroExitCodeReason = metav1.StatusReason("NonZeroExitCode") + ExitCodeCauseType = metav1.CauseType("ExitCode") + + // RemoteCommand stream identifiers. The first three identifiers (for STDIN, + // STDOUT, STDERR) are the same as their file descriptors. + StreamStdIn = 0 + StreamStdOut = 1 + StreamStdErr = 2 + StreamErr = 3 + StreamResize = 4 + StreamClose = 255 +) + +var SupportedStreamingProtocols = []string{StreamProtocolV4Name, StreamProtocolV3Name, StreamProtocolV2Name, StreamProtocolV1Name} diff --git a/vendor/k8s.io/client-go/tools/remotecommand/OWNERS b/vendor/k8s.io/client-go/tools/remotecommand/OWNERS new file mode 100644 index 000000000..307848307 --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/OWNERS @@ -0,0 +1,10 @@ +# See the OWNERS docs at https://go.k8s.io/owners + +approvers: + - aojea + - liggitt + - seans3 +reviewers: + - aojea + - liggitt + - seans3 diff --git a/vendor/k8s.io/client-go/tools/remotecommand/doc.go b/vendor/k8s.io/client-go/tools/remotecommand/doc.go new file mode 100644 index 000000000..b9f0db2d9 --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/doc.go @@ -0,0 +1,20 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package remotecommand adds support for executing commands in containers, +// with support for separate stdin, stdout, and stderr streams, as well as +// TTY. +package remotecommand diff --git a/vendor/k8s.io/client-go/tools/remotecommand/errorstream.go b/vendor/k8s.io/client-go/tools/remotecommand/errorstream.go new file mode 100644 index 000000000..23dd50778 --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/errorstream.go @@ -0,0 +1,55 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "fmt" + "io" + + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/klog/v2" +) + +// errorStreamDecoder interprets the data on the error channel and creates a go error object from it. +type errorStreamDecoder interface { + decode(message []byte) error +} + +// watchErrorStream watches the errorStream for remote command error data, +// decodes it with the given errorStreamDecoder, sends the decoded error (or nil if the remote +// command exited successfully) to the returned error channel, and closes it. +// This function returns immediately. +func watchErrorStream(logger klog.Logger, errorStream io.Reader, d errorStreamDecoder) chan error { + errorChan := make(chan error) + + go func() { + defer runtime.HandleCrashWithLogger(logger) + + message, err := io.ReadAll(errorStream) + switch { + case err != nil && err != io.EOF: + errorChan <- fmt.Errorf("error reading from error stream: %w", err) + case len(message) > 0: + errorChan <- d.decode(message) + default: + errorChan <- nil + } + close(errorChan) + }() + + return errorChan +} diff --git a/vendor/k8s.io/client-go/tools/remotecommand/fallback.go b/vendor/k8s.io/client-go/tools/remotecommand/fallback.go new file mode 100644 index 000000000..bcd5fd313 --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/fallback.go @@ -0,0 +1,60 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "context" + + "k8s.io/klog/v2" +) + +var _ Executor = &FallbackExecutor{} + +type FallbackExecutor struct { + primary Executor + secondary Executor + shouldFallback func(error) bool +} + +// NewFallbackExecutor creates an Executor that first attempts to use the +// WebSocketExecutor, falling back to the legacy SPDYExecutor if the initial +// websocket "StreamWithContext" call fails. +// func NewFallbackExecutor(config *restclient.Config, method string, url *url.URL) (Executor, error) { +func NewFallbackExecutor(primary, secondary Executor, shouldFallback func(error) bool) (Executor, error) { + return &FallbackExecutor{ + primary: primary, + secondary: secondary, + shouldFallback: shouldFallback, + }, nil +} + +// Stream is deprecated. Please use "StreamWithContext". +func (f *FallbackExecutor) Stream(options StreamOptions) error { + return f.StreamWithContext(context.Background(), options) +} + +// StreamWithContext initially attempts to call "StreamWithContext" using the +// primary executor, falling back to calling the secondary executor if the +// initial primary call to upgrade to a websocket connection fails. +func (f *FallbackExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error { + err := f.primary.StreamWithContext(ctx, options) + if err != nil && f.shouldFallback(err) { + klog.FromContext(ctx).V(4).Info("RemoteCommand fallback", "err", err) + return f.secondary.StreamWithContext(ctx, options) + } + return err +} diff --git a/vendor/k8s.io/client-go/tools/remotecommand/reader.go b/vendor/k8s.io/client-go/tools/remotecommand/reader.go new file mode 100644 index 000000000..d1f1be34c --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/reader.go @@ -0,0 +1,41 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "io" +) + +// readerWrapper delegates to an io.Reader so that only the io.Reader interface is implemented, +// to keep io.Copy from doing things we don't want when copying from the reader to the data stream. +// +// If the Stdin io.Reader provided to remotecommand implements a WriteTo function (like bytes.Buffer does[1]), +// io.Copy calls that method[2] to attempt to write the entire buffer to the stream in one call. +// That results in an oversized call to spdystream.Stream#Write [3], +// which results in a single oversized data frame[4] that is too large. +// +// [1] https://golang.org/pkg/bytes/#Buffer.WriteTo +// [2] https://golang.org/pkg/io/#Copy +// [3] https://github.com/kubernetes/kubernetes/blob/90295640ef87db9daa0144c5617afe889e7992b2/vendor/github.com/docker/spdystream/stream.go#L66-L73 +// [4] https://github.com/kubernetes/kubernetes/blob/90295640ef87db9daa0144c5617afe889e7992b2/vendor/github.com/docker/spdystream/spdy/write.go#L302-L304 +type readerWrapper struct { + reader io.Reader +} + +func (r readerWrapper) Read(p []byte) (int, error) { + return r.reader.Read(p) +} diff --git a/vendor/k8s.io/client-go/tools/remotecommand/remotecommand.go b/vendor/k8s.io/client-go/tools/remotecommand/remotecommand.go new file mode 100644 index 000000000..ca892f9b7 --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/remotecommand.go @@ -0,0 +1,59 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "context" + "io" + "net/http" + + "k8s.io/klog/v2" + "k8s.io/streaming/pkg/httpstream" +) + +// StreamOptions holds information pertaining to the current streaming session: +// input/output streams, if the client is requesting a TTY, and a terminal size queue to +// support terminal resizing. +type StreamOptions struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer + Tty bool + TerminalSizeQueue TerminalSizeQueue +} + +// Executor is an interface for transporting shell-style streams. +type Executor interface { + // Deprecated: use StreamWithContext instead to avoid possible resource leaks. + // See https://github.com/kubernetes/kubernetes/pull/103177 for details. + Stream(options StreamOptions) error + + // StreamWithContext initiates the transport of the standard shell streams. It will + // transport any non-nil stream to a remote system, and return an error if a problem + // occurs. If tty is set, the stderr stream is not used (raw TTY manages stdout and + // stderr over the stdout stream). + // The context controls the entire lifetime of stream execution. + StreamWithContext(ctx context.Context, options StreamOptions) error +} + +type streamCreator interface { + CreateStream(headers http.Header) (httpstream.Stream, error) +} + +type streamProtocolHandler interface { + stream(logger klog.Logger, conn streamCreator, ready chan<- struct{}) error +} diff --git a/vendor/k8s.io/client-go/tools/remotecommand/resize.go b/vendor/k8s.io/client-go/tools/remotecommand/resize.go new file mode 100644 index 000000000..2815112a5 --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/resize.go @@ -0,0 +1,34 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +// TerminalSize represents the width and height of a terminal. +// It is the same as staging/src/k8s.io/kubectl/pkg/util/term.TerminalSize. +// Copied to decouple the packages. Terminal-related package should not depend on API client and vice versa. +type TerminalSize struct { + Width uint16 + Height uint16 +} + +// TerminalSizeQueue is capable of returning terminal resize events as they occur. +// It is the same as staging/src/k8s.io/kubectl/pkg/util/term.TerminalSizeQueue. +// Copied to decouple the packages. Terminal-related package should not depend on API client and vice versa. +type TerminalSizeQueue interface { + // Next returns the new terminal size after the terminal has been resized. It returns nil when + // monitoring has been stopped. + Next() *TerminalSize +} diff --git a/vendor/k8s.io/client-go/tools/remotecommand/spdy.go b/vendor/k8s.io/client-go/tools/remotecommand/spdy.go new file mode 100644 index 000000000..ebf3c53c3 --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/spdy.go @@ -0,0 +1,176 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "k8s.io/apimachinery/pkg/util/remotecommand" + restclient "k8s.io/client-go/rest" + "k8s.io/client-go/transport/spdy" + "k8s.io/klog/v2" + "k8s.io/streaming/pkg/httpstream" +) + +// spdyStreamExecutor handles transporting standard shell streams over an httpstream connection. +type spdyStreamExecutor struct { + upgrader spdy.Upgrader + transport http.RoundTripper + + method string + url *url.URL + protocols []string + rejectRedirects bool // if true, receiving redirect from upstream is an error +} + +// NewSPDYExecutor connects to the provided server and upgrades the connection to +// multiplexed bidirectional streams. +func NewSPDYExecutor(config *restclient.Config, method string, url *url.URL) (Executor, error) { + wrapper, upgradeRoundTripper, err := spdy.RoundTripperFor(config) + if err != nil { + return nil, err + } + return NewSPDYExecutorForTransports(wrapper, upgradeRoundTripper, method, url) +} + +// NewSPDYExecutorRejectRedirects returns an Executor that will upgrade the future +// connection to a SPDY bi-directional streaming connection when calling "Stream" (deprecated) +// or "StreamWithContext" (preferred). Additionally, if the upstream server returns a redirect +// during the attempted upgrade in these "Stream" calls, an error is returned. +func NewSPDYExecutorRejectRedirects(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL) (Executor, error) { + executor, err := NewSPDYExecutorForTransports(transport, upgrader, method, url) + if err != nil { + return nil, err + } + spdyExecutor := executor.(*spdyStreamExecutor) + spdyExecutor.rejectRedirects = true + return spdyExecutor, nil +} + +// NewSPDYExecutorForTransports connects to the provided server using the given transport, +// upgrades the response using the given upgrader to multiplexed bidirectional streams. +func NewSPDYExecutorForTransports(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL) (Executor, error) { + return NewSPDYExecutorForProtocols( + transport, upgrader, method, url, + remotecommand.StreamProtocolV5Name, + remotecommand.StreamProtocolV4Name, + remotecommand.StreamProtocolV3Name, + remotecommand.StreamProtocolV2Name, + remotecommand.StreamProtocolV1Name, + ) +} + +// NewSPDYExecutorForProtocols connects to the provided server and upgrades the connection to +// multiplexed bidirectional streams using only the provided protocols. Exposed for testing, most +// callers should use NewSPDYExecutor or NewSPDYExecutorForTransports. +func NewSPDYExecutorForProtocols(transport http.RoundTripper, upgrader spdy.Upgrader, method string, url *url.URL, protocols ...string) (Executor, error) { + return &spdyStreamExecutor{ + upgrader: upgrader, + transport: transport, + method: method, + url: url, + protocols: protocols, + }, nil +} + +// Stream opens a protocol streamer to the server and streams until a client closes +// the connection or the server disconnects. +func (e *spdyStreamExecutor) Stream(options StreamOptions) error { + return e.StreamWithContext(context.Background(), options) +} + +// newConnectionAndStream creates a new SPDY connection and a stream protocol handler upon it. +func (e *spdyStreamExecutor) newConnectionAndStream(ctx context.Context, options StreamOptions) (httpstream.Connection, streamProtocolHandler, error) { + req, err := http.NewRequestWithContext(ctx, e.method, e.url.String(), nil) + if err != nil { + return nil, nil, fmt.Errorf("error creating request: %v", err) + } + + client := http.Client{Transport: e.transport} + if e.rejectRedirects { + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return fmt.Errorf("redirect not allowed") + } + } + conn, protocol, err := spdy.NegotiateStreaming( + e.upgrader, + &client, + req, + e.protocols..., + ) + if err != nil { + return nil, nil, err + } + + var streamer streamProtocolHandler + + logger := klog.FromContext(ctx) + switch protocol { + case remotecommand.StreamProtocolV5Name: + streamer = newStreamProtocolV5(options) + case remotecommand.StreamProtocolV4Name: + streamer = newStreamProtocolV4(options) + case remotecommand.StreamProtocolV3Name: + streamer = newStreamProtocolV3(options) + case remotecommand.StreamProtocolV2Name: + streamer = newStreamProtocolV2(options) + case "": + logger.V(4).Info("The server did not negotiate a streaming protocol version, falling back", "protocol", remotecommand.StreamProtocolV1Name) + fallthrough + case remotecommand.StreamProtocolV1Name: + streamer = newStreamProtocolV1(options) + } + + return conn, streamer, nil +} + +// StreamWithContext opens a protocol streamer to the server and streams until a client closes +// the connection or the server disconnects or the context is done. +func (e *spdyStreamExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error { + conn, streamer, err := e.newConnectionAndStream(ctx, options) + if err != nil { + return err + } + defer conn.Close() + + panicChan := make(chan any, 1) + errorChan := make(chan error, 1) + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + + // The SPDY executor does not need to synchronize stream creation, so we pass a nil + // ready channel. The underlying spdystream library handles stream multiplexing + // without a race condition. + errorChan <- streamer.stream(klog.FromContext(ctx), conn, nil) + }() + + select { + case p := <-panicChan: + panic(p) + case err := <-errorChan: + return err + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/vendor/k8s.io/client-go/tools/remotecommand/v1.go b/vendor/k8s.io/client-go/tools/remotecommand/v1.go new file mode 100644 index 000000000..5d903b67a --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/v1.go @@ -0,0 +1,164 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "fmt" + "io" + "net/http" + + v1 "k8s.io/api/core/v1" + "k8s.io/klog/v2" + "k8s.io/streaming/pkg/httpstream" +) + +// streamProtocolV1 implements the first version of the streaming exec & attach +// protocol. This version has some bugs, such as not being able to detect when +// non-interactive stdin data has ended. See https://issues.k8s.io/13394 and +// https://issues.k8s.io/13395 for more details. +type streamProtocolV1 struct { + StreamOptions + + errorStream httpstream.Stream + remoteStdin httpstream.Stream + remoteStdout httpstream.Stream + remoteStderr httpstream.Stream +} + +var _ streamProtocolHandler = &streamProtocolV1{} + +func newStreamProtocolV1(options StreamOptions) streamProtocolHandler { + return &streamProtocolV1{ + StreamOptions: options, + } +} + +func (p *streamProtocolV1) stream(logger klog.Logger, conn streamCreator, ready chan<- struct{}) error { + doneChan := make(chan struct{}, 2) + errorChan := make(chan error) + + cp := func(s string, dst io.Writer, src io.Reader) { + logger.V(6).Info("Copying", "data", s) + defer logger.V(6).Info("Done copying", "data", s) + if _, err := io.Copy(dst, src); err != nil && err != io.EOF { + logger.Error(err, "Error copying", "data", s) + } + if s == v1.StreamTypeStdout || s == v1.StreamTypeStderr { + doneChan <- struct{}{} + } + } + + // set up all the streams first + var err error + headers := http.Header{} + headers.Set(v1.StreamType, v1.StreamTypeError) + p.errorStream, err = conn.CreateStream(headers) + if err != nil { + return err + } + defer p.errorStream.Reset() + + // Create all the streams first, then start the copy goroutines. The server doesn't start its copy + // goroutines until it's received all of the streams. If the client creates the stdin stream and + // immediately begins copying stdin data to the server, it's possible to overwhelm and wedge the + // spdy frame handler in the server so that it is full of unprocessed frames. The frames aren't + // getting processed because the server hasn't started its copying, and it won't do that until it + // gets all the streams. By creating all the streams first, we ensure that the server is ready to + // process data before the client starts sending any. See https://issues.k8s.io/16373 for more info. + if p.Stdin != nil { + headers.Set(v1.StreamType, v1.StreamTypeStdin) + p.remoteStdin, err = conn.CreateStream(headers) + if err != nil { + return err + } + defer p.remoteStdin.Reset() + } + + if p.Stdout != nil { + headers.Set(v1.StreamType, v1.StreamTypeStdout) + p.remoteStdout, err = conn.CreateStream(headers) + if err != nil { + return err + } + defer p.remoteStdout.Reset() + } + + if p.Stderr != nil && !p.Tty { + headers.Set(v1.StreamType, v1.StreamTypeStderr) + p.remoteStderr, err = conn.CreateStream(headers) + if err != nil { + return err + } + defer p.remoteStderr.Reset() + } + + // Signal that all streams have been created. + if ready != nil { + close(ready) + } + + // now that all the streams have been created, proceed with reading & copying + + // always read from errorStream + go func() { + message, err := io.ReadAll(p.errorStream) + if err != nil && err != io.EOF { + errorChan <- fmt.Errorf("Error reading from error stream: %s", err) + return + } + if len(message) > 0 { + errorChan <- fmt.Errorf("Error executing remote command: %s", message) + return + } + }() + + if p.Stdin != nil { + // TODO this goroutine will never exit cleanly (the io.Copy never unblocks) + // because stdin is not closed until the process exits. If we try to call + // stdin.Close(), it returns no error but doesn't unblock the copy. It will + // exit when the process exits, instead. + go cp(v1.StreamTypeStdin, p.remoteStdin, readerWrapper{p.Stdin}) + } + + waitCount := 0 + completedStreams := 0 + + if p.Stdout != nil { + waitCount++ + go cp(v1.StreamTypeStdout, p.Stdout, p.remoteStdout) + } + + if p.Stderr != nil && !p.Tty { + waitCount++ + go cp(v1.StreamTypeStderr, p.Stderr, p.remoteStderr) + } + +Loop: + for { + select { + case <-doneChan: + completedStreams++ + if completedStreams == waitCount { + break Loop + } + case err := <-errorChan: + return err + } + } + + return nil +} diff --git a/vendor/k8s.io/client-go/tools/remotecommand/v2.go b/vendor/k8s.io/client-go/tools/remotecommand/v2.go new file mode 100644 index 000000000..75286a12f --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/v2.go @@ -0,0 +1,205 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "fmt" + "io" + "net/http" + "sync" + + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/klog/v2" +) + +// streamProtocolV2 implements version 2 of the streaming protocol for attach +// and exec. The original streaming protocol was metav1. As a result, this +// version is referred to as version 2, even though it is the first actual +// numbered version. +type streamProtocolV2 struct { + StreamOptions + + errorStream io.Reader + remoteStdin io.ReadWriteCloser + remoteStdout io.Reader + remoteStderr io.Reader +} + +var _ streamProtocolHandler = &streamProtocolV2{} + +func newStreamProtocolV2(options StreamOptions) streamProtocolHandler { + return &streamProtocolV2{ + StreamOptions: options, + } +} + +func (p *streamProtocolV2) createStreams(conn streamCreator) error { + var err error + headers := http.Header{} + + // set up error stream + headers.Set(v1.StreamType, v1.StreamTypeError) + p.errorStream, err = conn.CreateStream(headers) + if err != nil { + return err + } + + // set up stdin stream + if p.Stdin != nil { + headers.Set(v1.StreamType, v1.StreamTypeStdin) + p.remoteStdin, err = conn.CreateStream(headers) + if err != nil { + return err + } + } + + // set up stdout stream + if p.Stdout != nil { + headers.Set(v1.StreamType, v1.StreamTypeStdout) + p.remoteStdout, err = conn.CreateStream(headers) + if err != nil { + return err + } + } + + // set up stderr stream + if p.Stderr != nil && !p.Tty { + headers.Set(v1.StreamType, v1.StreamTypeStderr) + p.remoteStderr, err = conn.CreateStream(headers) + if err != nil { + return err + } + } + return nil +} + +func (p *streamProtocolV2) copyStdin(logger klog.Logger) { + if p.Stdin != nil { + var once sync.Once + + // copy from client's stdin to container's stdin + go func() { + defer runtime.HandleCrashWithLogger(logger) + + // if p.stdin is noninteractive, p.g. `echo abc | kubectl exec -i -- cat`, make sure + // we close remoteStdin as soon as the copy from p.stdin to remoteStdin finishes. Otherwise + // the executed command will remain running. + defer once.Do(func() { p.remoteStdin.Close() }) + + if _, err := io.Copy(p.remoteStdin, readerWrapper{p.Stdin}); err != nil { + runtime.HandleErrorWithLogger(logger, err, "Copying stdin failed") + } + }() + + // read from remoteStdin until the stream is closed. this is essential to + // be able to exit interactive sessions cleanly and not leak goroutines or + // hang the client's terminal. + // + // TODO we aren't using go-dockerclient any more; revisit this to determine if it's still + // required by engine-api. + // + // go-dockerclient's current hijack implementation + // (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564) + // waits for all three streams (stdin/stdout/stderr) to finish copying + // before returning. When hijack finishes copying stdout/stderr, it calls + // Close() on its side of remoteStdin, which allows this copy to complete. + // When that happens, we must Close() on our side of remoteStdin, to + // allow the copy in hijack to complete, and hijack to return. + go func() { + defer runtime.HandleCrashWithLogger(logger) + defer once.Do(func() { p.remoteStdin.Close() }) + + // this "copy" doesn't actually read anything - it's just here to wait for + // the server to close remoteStdin. + if _, err := io.Copy(io.Discard, p.remoteStdin); err != nil { + runtime.HandleErrorWithLogger(logger, err, "Waiting for server to close stdin failed") + } + }() + } +} + +func (p *streamProtocolV2) copyStdout(logger klog.Logger, wg *sync.WaitGroup) { + if p.Stdout == nil { + return + } + + wg.Add(1) + go func() { + defer runtime.HandleCrashWithLogger(logger) + defer wg.Done() + // make sure, packet in queue can be consumed. + // block in queue may lead to deadlock in conn.server + // issue: https://github.com/kubernetes/kubernetes/issues/96339 + defer io.Copy(io.Discard, p.remoteStdout) + + if _, err := io.Copy(p.Stdout, p.remoteStdout); err != nil { + runtime.HandleErrorWithLogger(logger, err, "Copying stdout failed") + } + }() +} + +func (p *streamProtocolV2) copyStderr(logger klog.Logger, wg *sync.WaitGroup) { + if p.Stderr == nil || p.Tty { + return + } + + wg.Add(1) + go func() { + defer runtime.HandleCrashWithLogger(logger) + defer wg.Done() + defer io.Copy(io.Discard, p.remoteStderr) + + if _, err := io.Copy(p.Stderr, p.remoteStderr); err != nil { + runtime.HandleErrorWithLogger(logger, err, "Copying stderr failed") + } + }() +} + +func (p *streamProtocolV2) stream(logger klog.Logger, conn streamCreator, ready chan<- struct{}) error { + if err := p.createStreams(conn); err != nil { + return err + } + + // Signal that all streams have been created. + if ready != nil { + close(ready) + } + + // now that all the streams have been created, proceed with reading & copying + + errorChan := watchErrorStream(logger, p.errorStream, &errorDecoderV2{}) + + p.copyStdin(logger) + + var wg sync.WaitGroup + p.copyStdout(logger, &wg) + p.copyStderr(logger, &wg) + + // we're waiting for stdout/stderr to finish copying + wg.Wait() + + // waits for errorStream to finish reading with an error or nil + return <-errorChan +} + +// errorDecoderV2 interprets the error channel data as plain text. +type errorDecoderV2 struct{} + +func (d *errorDecoderV2) decode(message []byte) error { + return fmt.Errorf("error executing remote command: %s", message) +} diff --git a/vendor/k8s.io/client-go/tools/remotecommand/v3.go b/vendor/k8s.io/client-go/tools/remotecommand/v3.go new file mode 100644 index 000000000..b1e533a8a --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/v3.go @@ -0,0 +1,117 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "encoding/json" + "io" + "net/http" + "sync" + + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/klog/v2" +) + +// streamProtocolV3 implements version 3 of the streaming protocol for attach +// and exec. This version adds support for resizing the container's terminal. +type streamProtocolV3 struct { + *streamProtocolV2 + + resizeStream io.Writer +} + +var _ streamProtocolHandler = &streamProtocolV3{} + +func newStreamProtocolV3(options StreamOptions) streamProtocolHandler { + return &streamProtocolV3{ + streamProtocolV2: newStreamProtocolV2(options).(*streamProtocolV2), + } +} + +func (p *streamProtocolV3) createStreams(conn streamCreator) error { + // set up the streams from v2 + if err := p.streamProtocolV2.createStreams(conn); err != nil { + return err + } + + // set up resize stream + if p.Tty { + headers := http.Header{} + headers.Set(v1.StreamType, v1.StreamTypeResize) + var err error + p.resizeStream, err = conn.CreateStream(headers) + if err != nil { + return err + } + } + + return nil +} + +func (p *streamProtocolV3) handleResizes(logger klog.Logger) { + if p.resizeStream == nil || p.TerminalSizeQueue == nil { + return + } + go func() { + defer runtime.HandleCrashWithLogger(logger) + + encoder := json.NewEncoder(p.resizeStream) + for { + size := p.TerminalSizeQueue.Next() + if size == nil { + return + } + if err := encoder.Encode(&size); err != nil { + runtime.HandleErrorWithLogger(logger, err, "Encoding terminal size failed") + } + } + }() +} + +func (p *streamProtocolV3) stream(logger klog.Logger, conn streamCreator, ready chan<- struct{}) error { + if err := p.createStreams(conn); err != nil { + return err + } + + // Signal that all streams have been created. + if ready != nil { + close(ready) + } + + // now that all the streams have been created, proceed with reading & copying + + errorChan := watchErrorStream(logger, p.errorStream, &errorDecoderV3{}) + + p.handleResizes(logger) + + p.copyStdin(logger) + + var wg sync.WaitGroup + p.copyStdout(logger, &wg) + p.copyStderr(logger, &wg) + + // we're waiting for stdout/stderr to finish copying + wg.Wait() + + // waits for errorStream to finish reading with an error or nil + return <-errorChan +} + +type errorDecoderV3 struct { + errorDecoderV2 +} diff --git a/vendor/k8s.io/client-go/tools/remotecommand/v4.go b/vendor/k8s.io/client-go/tools/remotecommand/v4.go new file mode 100644 index 000000000..47018ba5f --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/v4.go @@ -0,0 +1,125 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "encoding/json" + "errors" + "fmt" + "strconv" + "sync" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/remotecommand" + "k8s.io/client-go/util/exec" + "k8s.io/klog/v2" +) + +// streamProtocolV4 implements version 4 of the streaming protocol for attach +// and exec. This version adds support for exit codes on the error stream through +// the use of metav1.Status instead of plain text messages. +type streamProtocolV4 struct { + *streamProtocolV3 +} + +var _ streamProtocolHandler = &streamProtocolV4{} + +func newStreamProtocolV4(options StreamOptions) streamProtocolHandler { + return &streamProtocolV4{ + streamProtocolV3: newStreamProtocolV3(options).(*streamProtocolV3), + } +} + +func (p *streamProtocolV4) createStreams(conn streamCreator) error { + return p.streamProtocolV3.createStreams(conn) +} + +func (p *streamProtocolV4) handleResizes(logger klog.Logger) { + p.streamProtocolV3.handleResizes(logger) +} + +func (p *streamProtocolV4) stream(logger klog.Logger, conn streamCreator, ready chan<- struct{}) error { + if err := p.createStreams(conn); err != nil { + return err + } + + // Signal that all streams have been created. + if ready != nil { + close(ready) + } + + // now that all the streams have been created, proceed with reading & copying + + errorChan := watchErrorStream(logger, p.errorStream, &errorDecoderV4{}) + + p.handleResizes(logger) + + p.copyStdin(logger) + + var wg sync.WaitGroup + p.copyStdout(logger, &wg) + p.copyStderr(logger, &wg) + + // we're waiting for stdout/stderr to finish copying + wg.Wait() + + // waits for errorStream to finish reading with an error or nil + return <-errorChan +} + +// errorDecoderV4 interprets the json-marshaled metav1.Status on the error channel +// and creates an exec.ExitError from it. +type errorDecoderV4 struct{} + +func (d *errorDecoderV4) decode(message []byte) error { + status := metav1.Status{} + err := json.Unmarshal(message, &status) + if err != nil { + return fmt.Errorf("error stream protocol error: %v in %q", err, string(message)) + } + switch status.Status { + case metav1.StatusSuccess: + return nil + case metav1.StatusFailure: + if status.Reason == remotecommand.NonZeroExitCodeReason { + if status.Details == nil { + return errors.New("error stream protocol error: details must be set") + } + for i := range status.Details.Causes { + c := &status.Details.Causes[i] + if c.Type != remotecommand.ExitCodeCauseType { + continue + } + + rc, err := strconv.ParseUint(c.Message, 10, 8) + if err != nil { + return fmt.Errorf("error stream protocol error: invalid exit code value %q", c.Message) + } + return exec.CodeExitError{ + Err: fmt.Errorf("command terminated with exit code %d", rc), + Code: int(rc), + } + } + + return fmt.Errorf("error stream protocol error: no %s cause given", remotecommand.ExitCodeCauseType) + } + default: + return errors.New("error stream protocol error: unknown error") + } + + return errors.New(status.Message) +} diff --git a/vendor/k8s.io/client-go/tools/remotecommand/v5.go b/vendor/k8s.io/client-go/tools/remotecommand/v5.go new file mode 100644 index 000000000..ca79a8828 --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/v5.go @@ -0,0 +1,37 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import "k8s.io/klog/v2" + +// streamProtocolV5 add support for V5 of the remote command subprotocol. +// For the streamProtocolHandler, this version is the same as V4. +type streamProtocolV5 struct { + *streamProtocolV4 +} + +var _ streamProtocolHandler = &streamProtocolV5{} + +func newStreamProtocolV5(options StreamOptions) streamProtocolHandler { + return &streamProtocolV5{ + streamProtocolV4: newStreamProtocolV4(options).(*streamProtocolV4), + } +} + +func (p *streamProtocolV5) stream(logger klog.Logger, conn streamCreator, ready chan<- struct{}) error { + return p.streamProtocolV4.stream(logger, conn, ready) +} diff --git a/vendor/k8s.io/client-go/tools/remotecommand/websocket.go b/vendor/k8s.io/client-go/tools/remotecommand/websocket.go new file mode 100644 index 000000000..f531e7ccf --- /dev/null +++ b/vendor/k8s.io/client-go/tools/remotecommand/websocket.go @@ -0,0 +1,537 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "sync" + "time" + + gwebsocket "github.com/gorilla/websocket" + + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/remotecommand" + restclient "k8s.io/client-go/rest" + "k8s.io/client-go/transport/websocket" + "k8s.io/klog/v2" + "k8s.io/streaming/pkg/httpstream" +) + +// writeDeadline defines the time that a client-side write to the websocket +// connection must complete before an i/o timeout occurs. +const writeDeadline = 60 * time.Second + +var ( + _ Executor = &wsStreamExecutor{} + _ streamCreator = &wsStreamCreator{} + _ httpstream.Stream = &stream{} + + streamType2streamID = map[string]byte{ + v1.StreamTypeStdin: remotecommand.StreamStdIn, + v1.StreamTypeStdout: remotecommand.StreamStdOut, + v1.StreamTypeStderr: remotecommand.StreamStdErr, + v1.StreamTypeError: remotecommand.StreamErr, + v1.StreamTypeResize: remotecommand.StreamResize, + } +) + +const ( + // pingPeriod defines how often a heartbeat "ping" message is sent. + pingPeriod = 5 * time.Second + // pingReadDeadline defines the time waiting for a response heartbeat + // "pong" message before a timeout error occurs for websocket reading. + // This duration must always be greater than the "pingPeriod". By defining + // this deadline in terms of the ping period, we are essentially saying + // we can drop "X" (e.g. 12) pings before firing the timeout. + pingReadDeadline = (pingPeriod * 12) + (1 * time.Second) +) + +// wsStreamExecutor handles transporting standard shell streams over an httpstream connection. +type wsStreamExecutor struct { + transport http.RoundTripper + upgrader websocket.ConnectionHolder + method string + url string + // requested protocols in priority order (e.g. v5.channel.k8s.io before v4.channel.k8s.io). + protocols []string + // selected protocol from the handshake process; could be empty string if handshake fails. + negotiated string + // period defines how often a "ping" heartbeat message is sent to the other endpoint. + heartbeatPeriod time.Duration + // deadline defines the amount of time before "pong" response must be received. + heartbeatDeadline time.Duration +} + +func NewWebSocketExecutor(config *restclient.Config, method, url string) (Executor, error) { + // Only supports V5 protocol for correct version skew functionality. + // Previous api servers will proxy upgrade requests to legacy websocket + // servers on container runtimes which support V1-V4. These legacy + // websocket servers will not handle the new CLOSE signal. + return NewWebSocketExecutorForProtocols(config, method, url, remotecommand.StreamProtocolV5Name) +} + +// NewWebSocketExecutorForProtocols allows to execute commands via a WebSocket connection. +func NewWebSocketExecutorForProtocols(config *restclient.Config, method, url string, protocols ...string) (Executor, error) { + transport, upgrader, err := websocket.RoundTripperFor(config) + if err != nil { + return nil, fmt.Errorf("error creating websocket transports: %v", err) + } + return &wsStreamExecutor{ + transport: transport, + upgrader: upgrader, + method: method, + url: url, + protocols: protocols, + heartbeatPeriod: pingPeriod, + heartbeatDeadline: pingReadDeadline, + }, nil +} + +// Deprecated: use StreamWithContext instead to avoid possible resource leaks. +// See https://github.com/kubernetes/kubernetes/pull/103177 for details. +func (e *wsStreamExecutor) Stream(options StreamOptions) error { + return e.StreamWithContext(context.Background(), options) +} + +// StreamWithContext upgrades an HTTPRequest to a WebSocket connection, and starts the various +// goroutines to implement the necessary streams over the connection. The "options" parameter +// defines which streams are requested. Returns an error if one occurred. This method is NOT +// safe to run concurrently with the same executor (because of the state stored in the upgrader). +func (e *wsStreamExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error { + req, err := http.NewRequestWithContext(ctx, e.method, e.url, nil) + if err != nil { + return err + } + conn, err := websocket.Negotiate(e.transport, e.upgrader, req, e.protocols...) + if err != nil { + return err + } + if conn == nil { + panic(fmt.Errorf("websocket connection is nil")) + } + defer conn.Close() + e.negotiated = conn.Subprotocol() + logger := klog.FromContext(ctx) + logger.V(4).Info("Subprotocol negotiated", "protocol", e.negotiated) + + var streamer streamProtocolHandler + switch e.negotiated { + case remotecommand.StreamProtocolV5Name: + streamer = newStreamProtocolV5(options) + case remotecommand.StreamProtocolV4Name: + streamer = newStreamProtocolV4(options) + case remotecommand.StreamProtocolV3Name: + streamer = newStreamProtocolV3(options) + case remotecommand.StreamProtocolV2Name: + streamer = newStreamProtocolV2(options) + case "": + logger.V(4).Info("The server did not negotiate a streaming protocol version, falling back", "protocol", remotecommand.StreamProtocolV1Name) + fallthrough + case remotecommand.StreamProtocolV1Name: + streamer = newStreamProtocolV1(options) + } + + panicChan := make(chan any, 1) + errorChan := make(chan error, 1) + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + + readyChan := make(chan struct{}) + creator := newWSStreamCreator(logger, conn) + go func() { + select { + // Wait until all streams have been created before starting the readDemuxLoop. + // This is to avoid a race condition where the readDemuxLoop receives a message + // for a stream that has not yet been created. + case <-readyChan: + case <-ctx.Done(): + creator.closeAllStreamReaders(ctx.Err()) + return + } + + creator.readDemuxLoop( + e.upgrader.DataBufferSize(), + e.heartbeatPeriod, + e.heartbeatDeadline, + ) + }() + errorChan <- streamer.stream(logger, creator, readyChan) + }() + + select { + case p := <-panicChan: + panic(p) + case err := <-errorChan: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +type wsStreamCreator struct { + logger klog.Logger + conn *gwebsocket.Conn + // Protects writing to websocket connection; reading is lock-free + connWriteLock sync.Mutex + // map of stream id to stream; multiple streams read/write the connection + streams map[byte]*stream + streamsMu sync.Mutex + // setStreamErr holds the error to return to anyone calling setStreams. + // this is populated in closeAllStreamReaders + setStreamErr error +} + +func newWSStreamCreator(logger klog.Logger, conn *gwebsocket.Conn) *wsStreamCreator { + return &wsStreamCreator{ + logger: logger, + conn: conn, + streams: map[byte]*stream{}, + } +} + +func (c *wsStreamCreator) getStream(id byte) *stream { + c.streamsMu.Lock() + defer c.streamsMu.Unlock() + return c.streams[id] +} + +func (c *wsStreamCreator) setStream(id byte, s *stream) error { + c.streamsMu.Lock() + defer c.streamsMu.Unlock() + if c.setStreamErr != nil { + return c.setStreamErr + } + c.streams[id] = s + return nil +} + +// CreateStream uses id from passed headers to create a stream over "c.conn" connection. +// Returns a Stream structure or nil and an error if one occurred. +func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, error) { + streamType := headers.Get(v1.StreamType) + id, ok := streamType2streamID[streamType] + if !ok { + return nil, fmt.Errorf("unknown stream type: %s", streamType) + } + if s := c.getStream(id); s != nil { + return nil, fmt.Errorf("duplicate stream for type %s", streamType) + } + reader, writer := io.Pipe() + s := &stream{ + logger: klog.LoggerWithValues(c.logger, "id", id), + headers: headers, + readPipe: reader, + writePipe: writer, + conn: c.conn, + connWriteLock: &c.connWriteLock, + id: id, + } + if err := c.setStream(id, s); err != nil { + _ = s.writePipe.Close() + _ = s.readPipe.Close() + return nil, err + } + return s, nil +} + +// readDemuxLoop is the lock-free reading processor for this endpoint of the websocket +// connection. This loop reads the connection, and demultiplexes the data +// into one of the individual stream pipes (by checking the stream id). This +// loop can *not* be run concurrently, because there can only be one websocket +// connection reader at a time (a read mutex would provide no benefit). +func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, deadline time.Duration) { + // Initialize and start the ping/pong heartbeat. + h := newHeartbeat(c.logger, c.conn, period, deadline) + // Set initial timeout for websocket connection reading. + c.logger.V(5).Info("Websocket read starts", "deadline", deadline) + if err := c.conn.SetReadDeadline(time.Now().Add(deadline)); err != nil { + c.logger.Error(err, "Websocket initial setting read deadline failed") + return + } + go h.start() + // Buffer size must correspond to the same size allocated + // for the read buffer during websocket client creation. A + // difference can cause incomplete connection reads. + readBuffer := make([]byte, bufferSize) + for { + // NextReader() only returns data messages (BinaryMessage or Text + // Message). Even though this call will never return control frames + // such as ping, pong, or close, this call is necessary for these + // message types to be processed. There can only be one reader + // at a time, so this reader loop must *not* be run concurrently; + // there is no lock for reading. Calling "NextReader()" before the + // current reader has been processed will close the current reader. + // If the heartbeat read deadline times out, this "NextReader()" will + // return an i/o error, and error handling will clean up. + messageType, r, err := c.conn.NextReader() + if err != nil { + websocketErr, ok := err.(*gwebsocket.CloseError) + if ok && websocketErr.Code == gwebsocket.CloseNormalClosure { + err = nil // readers will get io.EOF as it's a normal closure + } else { + err = fmt.Errorf("next reader: %w", err) + } + c.closeAllStreamReaders(err) + return + } + // All remote command protocols send/receive only binary data messages. + if messageType != gwebsocket.BinaryMessage { + c.closeAllStreamReaders(fmt.Errorf("unexpected message type: %d", messageType)) + return + } + // It's ok to read just a single byte because the underlying library wraps the actual + // connection with a buffered reader anyway. + _, err = io.ReadFull(r, readBuffer[:1]) + if err != nil { + c.closeAllStreamReaders(fmt.Errorf("read stream id: %w", err)) + return + } + streamID := readBuffer[0] + s := c.getStream(streamID) + if s == nil { + c.logger.Error(nil, "Unknown stream, discarding message", "id", streamID) + continue + } + for { + nr, errRead := r.Read(readBuffer) + if nr > 0 { + // Write the data to the stream's pipe. This can block. + _, errWrite := s.writePipe.Write(readBuffer[:nr]) + if errWrite != nil { + // Pipe must have been closed by the stream user. + // Nothing to do, discard the message. + break + } + } + if errRead != nil { + if errRead == io.EOF { + break + } + c.closeAllStreamReaders(fmt.Errorf("read message: %w", errRead)) + return + } + } + } +} + +// closeAllStreamReaders closes readers in all streams. +// This unblocks all stream.Read() calls, and keeps any future streams from being created. +func (c *wsStreamCreator) closeAllStreamReaders(err error) { + c.streamsMu.Lock() + defer c.streamsMu.Unlock() + for _, s := range c.streams { + // Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes. + _ = s.writePipe.CloseWithError(err) + } + // ensure callers to setStreams receive an error after this point + if err != nil { + c.setStreamErr = err + } else { + c.setStreamErr = fmt.Errorf("closed all streams") + } +} + +type stream struct { + logger klog.Logger + headers http.Header + readPipe *io.PipeReader + writePipe *io.PipeWriter + // conn is used for writing directly into the connection. + // Is nil after Close() / Reset() to prevent future writes. + conn *gwebsocket.Conn + // connWriteLock protects conn against concurrent write operations. There must be a single writer and a single reader only. + // The mutex is shared across all streams because the underlying connection is shared. + connWriteLock *sync.Mutex + id byte +} + +func (s *stream) Read(p []byte) (n int, err error) { + return s.readPipe.Read(p) +} + +// Write writes directly to the underlying WebSocket connection. +func (s *stream) Write(p []byte) (n int, err error) { + s.logger.V(8).Info("Write() on stream") + defer s.logger.V(8).Info("Write() done on stream") + s.connWriteLock.Lock() + defer s.connWriteLock.Unlock() + if s.conn == nil { + return 0, fmt.Errorf("write on closed stream %d", s.id) + } + err = s.conn.SetWriteDeadline(time.Now().Add(writeDeadline)) + if err != nil { + s.logger.V(4).Info("Websocket setting write deadline failed", "err", err) + return 0, err + } + // Message writer buffers the message data, so we don't need to do that ourselves. + // Just write id and the data as two separate writes to avoid allocating an intermediate buffer. + w, err := s.conn.NextWriter(gwebsocket.BinaryMessage) + if err != nil { + return 0, err + } + defer func() { + if w != nil { + w.Close() + } + }() + _, err = w.Write([]byte{s.id}) + if err != nil { + return 0, err + } + n, err = w.Write(p) + if err != nil { + return n, err + } + err = w.Close() + w = nil + return n, err +} + +// Close half-closes the stream, indicating this side is finished with the stream. +func (s *stream) Close() error { + s.logger.V(6).Info("Close() on stream") + defer s.logger.V(6).Info("Close() done on stream") + s.connWriteLock.Lock() + defer s.connWriteLock.Unlock() + if s.conn == nil { + return fmt.Errorf("Close() on already closed stream %d", s.id) + } + // Communicate the CLOSE stream signal to the other websocket endpoint. + err := s.conn.WriteMessage(gwebsocket.BinaryMessage, []byte{remotecommand.StreamClose, s.id}) + s.conn = nil + return err +} + +func (s *stream) Reset() error { + s.logger.V(4).Info("Reset() on stream") + defer s.logger.V(4).Info("Reset() done on stream") + s.Close() + return s.writePipe.Close() +} + +func (s *stream) Headers() http.Header { + return s.headers +} + +func (s *stream) Identifier() uint32 { + return uint32(s.id) +} + +// heartbeat encasulates data necessary for the websocket ping/pong heartbeat. This +// heartbeat works by setting a read deadline on the websocket connection, then +// pushing this deadline into the future for every successful heartbeat. If the +// heartbeat "pong" fails to respond within the deadline, then the "NextReader()" call +// inside the "readDemuxLoop" will return an i/o error prompting a connection close +// and cleanup. +type heartbeat struct { + logger klog.Logger + conn *gwebsocket.Conn + // period defines how often a "ping" heartbeat message is sent to the other endpoint + period time.Duration + // closing the "closer" channel will clean up the heartbeat timers + closer chan struct{} + // optional data to send with "ping" message + message []byte + // optionally received data message with "pong" message, same as sent with ping + pongMessage []byte +} + +// newHeartbeat creates heartbeat structure encapsulating fields necessary to +// run the websocket connection ping/pong mechanism and sets up handlers on +// the websocket connection. +func newHeartbeat(logger klog.Logger, conn *gwebsocket.Conn, period time.Duration, deadline time.Duration) *heartbeat { + h := &heartbeat{ + logger: logger, + conn: conn, + period: period, + closer: make(chan struct{}), + } + // Set up handler for receiving returned "pong" message from other endpoint + // by pushing the read deadline into the future. The "msg" received could + // be empty. + h.conn.SetPongHandler(func(msg string) error { + // Push the read deadline into the future. + logger.V(6).Info("Pong message received -- resetting read deadline", "message", msg) + err := h.conn.SetReadDeadline(time.Now().Add(deadline)) + if err != nil { + logger.Error(err, "Websocket setting read deadline failed") + return err + } + if len(msg) > 0 { + h.pongMessage = []byte(msg) + } + return nil + }) + // Set up handler to cleanup timers when this endpoint receives "Close" message. + closeHandler := h.conn.CloseHandler() + h.conn.SetCloseHandler(func(code int, text string) error { + close(h.closer) + return closeHandler(code, text) + }) + return h +} + +// setMessage is optional data sent with "ping" heartbeat. According to the websocket RFC +// this data sent with "ping" message should be returned in "pong" message. +func (h *heartbeat) setMessage(msg string) { + h.message = []byte(msg) +} + +// start the heartbeat by setting up necesssary handlers and looping by sending "ping" +// message every "period" until the "closer" channel is closed. +func (h *heartbeat) start() { + // Loop to continually send "ping" message through websocket connection every "period". + t := time.NewTicker(h.period) + defer t.Stop() + for { + select { + case <-h.closer: + h.logger.V(5).Info("Closed channel -- returning") + return + case <-t.C: + // "WriteControl" does not need to be protected by a mutex. According to + // gorilla/websockets library docs: "The Close and WriteControl methods can + // be called concurrently with all other methods." + if err := h.conn.WriteControl(gwebsocket.PingMessage, h.message, time.Now().Add(pingReadDeadline)); err == nil { + h.logger.V(6).Info("Websocket Ping succeeeded") + } else { + h.logger.Error(err, "Websocket Ping failed") + if errors.Is(err, gwebsocket.ErrCloseSent) { + // we continue because c.conn.CloseChan will manage closing the connection already + continue + } else if e, ok := err.(net.Error); ok && e.Timeout() { + // Continue, in case this is a transient failure. + // c.conn.CloseChan above will tell us when the connection is + // actually closed. + // If Temporary function hadn't been deprecated, we would have used it. + // But most of temporary errors are timeout errors anyway. + continue + } + return + } + } + } +} diff --git a/vendor/k8s.io/client-go/transport/spdy/spdy.go b/vendor/k8s.io/client-go/transport/spdy/spdy.go new file mode 100644 index 000000000..3bb04b77f --- /dev/null +++ b/vendor/k8s.io/client-go/transport/spdy/spdy.go @@ -0,0 +1,317 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spdy + +import ( + "fmt" + "net/http" + "net/url" + "time" + + "k8s.io/apimachinery/pkg/util/httpstream" + httpstreamspdy "k8s.io/apimachinery/pkg/util/httpstream/spdy" + restclient "k8s.io/client-go/rest" + "k8s.io/klog/v2" + streamhttp "k8s.io/streaming/pkg/httpstream" +) + +// Upgrader validates a response from the server after a SPDY upgrade. +type Upgrader interface { + // NewConnection validates the response and creates a new Connection. + NewConnection(resp *http.Response) (httpstream.Connection, error) +} + +// RoundTripperFor returns a round tripper and upgrader to use with SPDY. +func RoundTripperFor(config *restclient.Config) (http.RoundTripper, Upgrader, error) { + tlsConfig, err := restclient.TLSConfigFor(config) + if err != nil { + return nil, nil, err + } + proxy := http.ProxyFromEnvironment + if config.Proxy != nil { + proxy = config.Proxy + } + upgradeRoundTripper, err := httpstreamspdy.NewRoundTripperWithConfig(httpstreamspdy.RoundTripperConfig{ + TLS: tlsConfig, + Proxier: proxy, + PingPeriod: time.Second * 5, + UpgradeTransport: nil, + }) + if err != nil { + return nil, nil, err + } + wrapper, err := restclient.HTTPWrappersForConfig(config, upgradeRoundTripper) + if err != nil { + return nil, nil, err + } + return wrapper, upgradeRoundTripper, nil +} + +// dialer implements the httpstream.Dialer interface. +type dialer struct { + client *http.Client + upgrader Upgrader + method string + url *url.URL +} + +var _ httpstream.Dialer = &dialer{} + +// NewDialer will create a dialer that connects to the provided URL and upgrades the connection to SPDY. +func NewDialer(upgrader Upgrader, client *http.Client, method string, url *url.URL) httpstream.Dialer { + return &dialer{ + client: client, + upgrader: upgrader, + method: method, + url: url, + } +} + +// NewDialerForStreaming creates a SPDY dialer for in-tree callers that use +// k8s.io/streaming/pkg/httpstream types. +func NewDialerForStreaming(upgrader Upgrader, client *http.Client, method string, url *url.URL) streamhttp.Dialer { + return &streamingDialerAdapter{delegate: NewDialer(upgrader, client, method, url)} +} + +// NewUpgraderForStreaming adapts a streaming upgrader for callers that need +// the compatibility Upgrader interface. +func NewUpgraderForStreaming(upgrader streamhttp.UpgradeRoundTripper) Upgrader { + return &compatUpgraderAdapter{delegate: upgrader} +} + +func (d *dialer) Dial(protocols ...string) (httpstream.Connection, string, error) { + req, err := http.NewRequest(d.method, d.url.String(), nil) + if err != nil { + return nil, "", fmt.Errorf("error creating request: %v", err) + } + return Negotiate(d.upgrader, d.client, req, protocols...) +} + +// Negotiate opens a connection to a remote server and attempts to negotiate +// a SPDY connection. Upon success, it returns the connection and the protocol selected by +// the server. The client transport must use the upgradeRoundTripper - see RoundTripperFor. +func Negotiate(upgrader Upgrader, client *http.Client, req *http.Request, protocols ...string) (httpstream.Connection, string, error) { + for i := range protocols { + req.Header.Add(httpstream.HeaderProtocolVersion, protocols[i]) + } + resp, err := client.Do(req) + if err != nil { + return nil, "", fmt.Errorf("error sending request: %v", err) + } + defer resp.Body.Close() + conn, err := upgrader.NewConnection(resp) + if err != nil { + return nil, "", err + } + return conn, resp.Header.Get(httpstream.HeaderProtocolVersion), nil +} + +// NegotiateStreaming is for in-tree callers that still operate on +// k8s.io/streaming/pkg/httpstream types. +func NegotiateStreaming(upgrader Upgrader, client *http.Client, req *http.Request, protocols ...string) (streamhttp.Connection, string, error) { + conn, protocol, err := Negotiate(upgrader, client, req, protocols...) + if err != nil { + return nil, "", err + } + return wrapStreamingConnection(conn), protocol, nil +} + +type streamingDialerAdapter struct { + delegate httpstream.Dialer +} + +func (d *streamingDialerAdapter) Dial(protocols ...string) (streamhttp.Connection, string, error) { + conn, protocol, err := d.delegate.Dial(protocols...) + if err != nil { + return nil, "", err + } + return wrapStreamingConnection(conn), protocol, nil +} + +type compatUpgraderAdapter struct { + delegate streamhttp.UpgradeRoundTripper +} + +func (u *compatUpgraderAdapter) NewConnection(resp *http.Response) (httpstream.Connection, error) { + conn, err := u.delegate.NewConnection(resp) + if err != nil { + return nil, err + } + return wrapCompatConnection(conn), nil +} + +type streamingStreamAdapter struct { + delegate httpstream.Stream +} + +func (s *streamingStreamAdapter) Read(p []byte) (int, error) { + return s.delegate.Read(p) +} + +func (s *streamingStreamAdapter) Write(p []byte) (int, error) { + return s.delegate.Write(p) +} + +func (s *streamingStreamAdapter) Close() error { + return s.delegate.Close() +} + +func (s *streamingStreamAdapter) Reset() error { + return s.delegate.Reset() +} + +func (s *streamingStreamAdapter) Headers() http.Header { + return s.delegate.Headers() +} + +func (s *streamingStreamAdapter) Identifier() uint32 { + return s.delegate.Identifier() +} + +type streamingConnectionAdapter struct { + delegate httpstream.Connection +} + +func (c *streamingConnectionAdapter) CreateStream(headers http.Header) (streamhttp.Stream, error) { + stream, err := c.delegate.CreateStream(headers) + if err != nil { + return nil, err + } + return &streamingStreamAdapter{delegate: stream}, nil +} + +func (c *streamingConnectionAdapter) Close() error { + return c.delegate.Close() +} + +func (c *streamingConnectionAdapter) CloseChan() <-chan bool { + return c.delegate.CloseChan() +} + +func (c *streamingConnectionAdapter) SetIdleTimeout(timeout time.Duration) { + c.delegate.SetIdleTimeout(timeout) +} + +func (c *streamingConnectionAdapter) RemoveStreams(streams ...streamhttp.Stream) { + compatStreams := make([]httpstream.Stream, 0, len(streams)) + for _, stream := range streams { + if stream == nil { + continue + } + if s, ok := stream.(*streamingStreamAdapter); ok { + compatStreams = append(compatStreams, s.delegate) + continue + } + if s, ok := stream.(httpstream.Stream); ok { + compatStreams = append(compatStreams, s) + continue + } + klog.V(5).Infof("dropping unadaptable streaming stream %T in RemoveStreams", stream) + } + c.delegate.RemoveStreams(compatStreams...) +} + +func wrapStreamingConnection(conn httpstream.Connection) streamhttp.Connection { + if conn == nil { + return nil + } + if wrapped, ok := conn.(*compatConnectionAdapter); ok { + return wrapped.delegate + } + return &streamingConnectionAdapter{delegate: conn} +} + +type compatStreamAdapter struct { + delegate streamhttp.Stream +} + +func (s *compatStreamAdapter) Read(p []byte) (int, error) { + return s.delegate.Read(p) +} + +func (s *compatStreamAdapter) Write(p []byte) (int, error) { + return s.delegate.Write(p) +} + +func (s *compatStreamAdapter) Close() error { + return s.delegate.Close() +} + +func (s *compatStreamAdapter) Reset() error { + return s.delegate.Reset() +} + +func (s *compatStreamAdapter) Headers() http.Header { + return s.delegate.Headers() +} + +func (s *compatStreamAdapter) Identifier() uint32 { + return s.delegate.Identifier() +} + +type compatConnectionAdapter struct { + delegate streamhttp.Connection +} + +func (c *compatConnectionAdapter) CreateStream(headers http.Header) (httpstream.Stream, error) { + stream, err := c.delegate.CreateStream(headers) + if err != nil { + return nil, err + } + return &compatStreamAdapter{delegate: stream}, nil +} + +func (c *compatConnectionAdapter) Close() error { + return c.delegate.Close() +} + +func (c *compatConnectionAdapter) CloseChan() <-chan bool { + return c.delegate.CloseChan() +} + +func (c *compatConnectionAdapter) SetIdleTimeout(timeout time.Duration) { + c.delegate.SetIdleTimeout(timeout) +} + +func (c *compatConnectionAdapter) RemoveStreams(streams ...httpstream.Stream) { + streamingStreams := make([]streamhttp.Stream, 0, len(streams)) + for _, stream := range streams { + if stream == nil { + continue + } + if s, ok := stream.(*compatStreamAdapter); ok { + streamingStreams = append(streamingStreams, s.delegate) + continue + } + if s, ok := stream.(streamhttp.Stream); ok { + streamingStreams = append(streamingStreams, s) + continue + } + klog.V(5).Infof("dropping unadaptable compat stream %T in RemoveStreams", stream) + } + c.delegate.RemoveStreams(streamingStreams...) +} + +func wrapCompatConnection(conn streamhttp.Connection) httpstream.Connection { + if conn == nil { + return nil + } + if wrapped, ok := conn.(*streamingConnectionAdapter); ok { + return wrapped.delegate + } + return &compatConnectionAdapter{delegate: conn} +} diff --git a/vendor/k8s.io/client-go/transport/websocket/roundtripper.go b/vendor/k8s.io/client-go/transport/websocket/roundtripper.go new file mode 100644 index 000000000..5285d6b14 --- /dev/null +++ b/vendor/k8s.io/client-go/transport/websocket/roundtripper.go @@ -0,0 +1,224 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package websocket + +import ( + "crypto/tls" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + gwebsocket "github.com/gorilla/websocket" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/serializer" + utilnet "k8s.io/apimachinery/pkg/util/net" + restclient "k8s.io/client-go/rest" + "k8s.io/client-go/transport" + "k8s.io/streaming/pkg/httpstream" + "k8s.io/streaming/pkg/httpstream/wsstream" +) + +var ( + _ utilnet.TLSClientConfigHolder = &RoundTripper{} + _ http.RoundTripper = &RoundTripper{} +) + +var ( + statusScheme = runtime.NewScheme() + statusCodecs = serializer.NewCodecFactory(statusScheme) +) + +func init() { + statusScheme.AddUnversionedTypes(metav1.SchemeGroupVersion, + &metav1.Status{}, + ) +} + +// ConnectionHolder defines functions for structure providing +// access to the websocket connection. +type ConnectionHolder interface { + DataBufferSize() int + Connection() *gwebsocket.Conn +} + +// RoundTripper knows how to establish a connection to a remote WebSocket endpoint and make it available for use. +// RoundTripper must not be reused. +type RoundTripper struct { + // TLSConfig holds the TLS configuration settings to use when connecting + // to the remote server. + TLSConfig *tls.Config + + // Proxier specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxier func(req *http.Request) (*url.URL, error) + + // Conn holds the WebSocket connection after a round trip. + Conn *gwebsocket.Conn +} + +// Connection returns the stored websocket connection. +func (rt *RoundTripper) Connection() *gwebsocket.Conn { + return rt.Conn +} + +// DataBufferSize returns the size of buffers for the +// websocket connection. +func (rt *RoundTripper) DataBufferSize() int { + return 32 * 1024 +} + +// TLSClientConfig implements pkg/util/net.TLSClientConfigHolder. +func (rt *RoundTripper) TLSClientConfig() *tls.Config { + return rt.TLSConfig +} + +// RoundTrip connects to the remote websocket using the headers in the request and the TLS +// configuration from the config +func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response, retErr error) { + defer func() { + if request.Body != nil { + err := request.Body.Close() + if retErr == nil { + retErr = err + } + } + }() + + // set the protocol version directly on the dialer from the header + protocolVersions := request.Header[wsstream.WebSocketProtocolHeader] + delete(request.Header, wsstream.WebSocketProtocolHeader) + + dialer := gwebsocket.Dialer{ + Proxy: rt.Proxier, + TLSClientConfig: rt.TLSConfig, + Subprotocols: protocolVersions, + ReadBufferSize: rt.DataBufferSize() + 1024, // add space for the protocol byte indicating which channel the data is for + WriteBufferSize: rt.DataBufferSize() + 1024, // add space for the protocol byte indicating which channel the data is for + } + switch request.URL.Scheme { + case "https": + request.URL.Scheme = "wss" + case "http": + request.URL.Scheme = "ws" + default: + return nil, fmt.Errorf("unknown url scheme: %s", request.URL.Scheme) + } + wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header) + if err != nil { + // BadHandshake error becomes an "UpgradeFailureError" (used for streaming fallback). + if errors.Is(err, gwebsocket.ErrBadHandshake) { + cause := err + // Enhance the error message with the error response if possible. + if resp != nil && len(resp.Status) > 0 { + defer resp.Body.Close() //nolint:errcheck + cause = fmt.Errorf("%w (%s)", err, resp.Status) // Always add the response status + responseError := "" + responseErrorBytes, readErr := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + if readErr != nil { + cause = fmt.Errorf("%w: unable to read error from server response", cause) + } else { + // If returned error can be decoded as "metav1.Status", return a "StatusError". + responseError = strings.TrimSpace(string(responseErrorBytes)) + if len(responseError) > 0 { + if obj, _, decodeErr := statusCodecs.UniversalDecoder().Decode(responseErrorBytes, nil, &metav1.Status{}); decodeErr == nil { + if status, ok := obj.(*metav1.Status); ok { + cause = &apierrors.StatusError{ErrStatus: *status} + } + } else { + // Otherwise, append the responseError string. + cause = fmt.Errorf("%w: %s", cause, responseError) + } + } + } + } + return nil, &httpstream.UpgradeFailureError{Cause: cause} + } + return nil, err + } + + // Ensure we got back a protocol we understand + foundProtocol := false + for _, protocolVersion := range protocolVersions { + if protocolVersion == wsConn.Subprotocol() { + foundProtocol = true + break + } + } + if !foundProtocol { + wsConn.Close() // nolint:errcheck + return nil, &httpstream.UpgradeFailureError{Cause: fmt.Errorf("invalid protocol, expected one of %q, got %q", protocolVersions, wsConn.Subprotocol())} + } + + rt.Conn = wsConn + + return resp, nil +} + +// RoundTripperFor transforms the passed rest config into a wrapped roundtripper, as well +// as a pointer to the websocket RoundTripper. The websocket RoundTripper contains the +// websocket connection after RoundTrip() on the wrapper. Returns an error if there is +// a problem creating the round trippers. +func RoundTripperFor(config *restclient.Config) (http.RoundTripper, ConnectionHolder, error) { + transportCfg, err := config.TransportConfig() + if err != nil { + return nil, nil, err + } + tlsConfig, err := transport.TLSConfigFor(transportCfg) + if err != nil { + return nil, nil, err + } + proxy := config.Proxy + if proxy == nil { + proxy = utilnet.NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment) + } + + upgradeRoundTripper := &RoundTripper{ + TLSConfig: tlsConfig, + Proxier: proxy, + } + wrapper, err := transport.HTTPWrappersForConfig(transportCfg, upgradeRoundTripper) + if err != nil { + return nil, nil, err + } + return wrapper, upgradeRoundTripper, nil +} + +// Negotiate opens a connection to a remote server and attempts to negotiate +// a WebSocket connection. Upon success, it returns the negotiated connection. +// The round tripper rt must use the WebSocket round tripper wsRt - see RoundTripperFor. +func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http.Request, protocols ...string) (*gwebsocket.Conn, error) { + // Plumb protocols to RoundTripper#RoundTrip + req.Header[wsstream.WebSocketProtocolHeader] = protocols + resp, err := rt.RoundTrip(req) + if err != nil { + return nil, err + } + err = resp.Body.Close() + if err != nil { + connectionInfo.Connection().Close() + return nil, fmt.Errorf("error closing response body: %v", err) + } + return connectionInfo.Connection(), nil +} diff --git a/vendor/k8s.io/client-go/util/exec/exec.go b/vendor/k8s.io/client-go/util/exec/exec.go new file mode 100644 index 000000000..d170badb6 --- /dev/null +++ b/vendor/k8s.io/client-go/util/exec/exec.go @@ -0,0 +1,52 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package exec + +// ExitError is an interface that presents an API similar to os.ProcessState, which is +// what ExitError from os/exec is. This is designed to make testing a bit easier and +// probably loses some of the cross-platform properties of the underlying library. +type ExitError interface { + String() string + Error() string + Exited() bool + ExitStatus() int +} + +// CodeExitError is an implementation of ExitError consisting of an error object +// and an exit code (the upper bits of os.exec.ExitStatus). +type CodeExitError struct { + Err error + Code int +} + +var _ ExitError = CodeExitError{} + +func (e CodeExitError) Error() string { + return e.Err.Error() +} + +func (e CodeExitError) String() string { + return e.Err.Error() +} + +func (e CodeExitError) Exited() bool { + return true +} + +func (e CodeExitError) ExitStatus() int { + return e.Code +} diff --git a/vendor/k8s.io/streaming/LICENSE b/vendor/k8s.io/streaming/LICENSE new file mode 100644 index 000000000..d64569567 --- /dev/null +++ b/vendor/k8s.io/streaming/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/k8s.io/streaming/pkg/httpstream/doc.go b/vendor/k8s.io/streaming/pkg/httpstream/doc.go new file mode 100644 index 000000000..1da83f14b --- /dev/null +++ b/vendor/k8s.io/streaming/pkg/httpstream/doc.go @@ -0,0 +1,19 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package httpstream adds multiplexed streaming support to HTTP requests and +// responses via connection upgrades. +package httpstream diff --git a/vendor/k8s.io/streaming/pkg/httpstream/httpstream.go b/vendor/k8s.io/streaming/pkg/httpstream/httpstream.go new file mode 100644 index 000000000..a7c8d897d --- /dev/null +++ b/vendor/k8s.io/streaming/pkg/httpstream/httpstream.go @@ -0,0 +1,201 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package httpstream + +import ( + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +const ( + HeaderConnection = "Connection" + HeaderUpgrade = "Upgrade" + HeaderProtocolVersion = "X-Stream-Protocol-Version" + HeaderAcceptedProtocolVersions = "X-Accepted-Stream-Protocol-Versions" +) + +// NewStreamHandler defines a function that is called when a new Stream is +// received. If no error is returned, the Stream is accepted; otherwise, +// the stream is rejected. After the reply frame has been sent, replySent is closed. +type NewStreamHandler func(stream Stream, replySent <-chan struct{}) error + +// NoOpNewStreamHandler is a stream handler that accepts a new stream and +// performs no other logic. +func NoOpNewStreamHandler(stream Stream, replySent <-chan struct{}) error { return nil } + +// Dialer knows how to open a streaming connection to a server. +type Dialer interface { + + // Dial opens a streaming connection to a server using one of the protocols + // specified (in order of most preferred to least preferred). + Dial(protocols ...string) (Connection, string, error) +} + +// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade +// HTTP requests to support multiplexed bidirectional streams. After RoundTrip() +// is invoked, if the upgrade is successful, clients may retrieve the upgraded +// connection by calling UpgradeRoundTripper.Connection(). +type UpgradeRoundTripper interface { + http.RoundTripper + // NewConnection validates the response and creates a new Connection. + NewConnection(resp *http.Response) (Connection, error) +} + +// ResponseUpgrader knows how to upgrade HTTP requests and responses to +// add streaming support to them. +type ResponseUpgrader interface { + // UpgradeResponse upgrades an HTTP response to one that supports multiplexed + // streams. newStreamHandler will be called asynchronously whenever the + // other end of the upgraded connection creates a new stream. + UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection +} + +// Connection represents an upgraded HTTP connection. +type Connection interface { + // CreateStream creates a new Stream with the supplied headers. + CreateStream(headers http.Header) (Stream, error) + // Close resets all streams and closes the connection. + Close() error + // CloseChan returns a channel that is closed when the underlying connection is closed. + CloseChan() <-chan bool + // SetIdleTimeout sets the amount of time the connection may remain idle before + // it is automatically closed. + SetIdleTimeout(timeout time.Duration) + // RemoveStreams can be used to remove a set of streams from the Connection. + RemoveStreams(streams ...Stream) +} + +// Stream represents a bidirectional communications channel that is part of an +// upgraded connection. +type Stream interface { + io.ReadWriteCloser + // Reset closes both directions of the stream, indicating that neither client + // or server can use it any more. + Reset() error + // Headers returns the headers used to create the stream. + Headers() http.Header + // Identifier returns the stream's ID. + Identifier() uint32 +} + +// UpgradeFailureError encapsulates the cause for why the streaming +// upgrade request failed. Implements error interface. +type UpgradeFailureError struct { + Cause error +} + +func (u *UpgradeFailureError) Error() string { + return fmt.Sprintf("unable to upgrade streaming request: %s", u.Cause) +} + +// IsUpgradeFailure returns true if the passed error is (or wrapped error contains) +// the UpgradeFailureError. +func IsUpgradeFailure(err error) bool { + if err == nil { + return false + } + var upgradeErr *UpgradeFailureError + return errors.As(err, &upgradeErr) +} + +// isHTTPSProxyError returns true if error is Gorilla/Websockets HTTPS Proxy dial error; +// false otherwise (see https://github.com/kubernetes/kubernetes/issues/126134). +func IsHTTPSProxyError(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "proxy: unknown scheme: https") +} + +// IsUpgradeRequest returns true if the given request is a connection upgrade request +func IsUpgradeRequest(req *http.Request) bool { + for _, h := range req.Header[http.CanonicalHeaderKey(HeaderConnection)] { + if strings.Contains(strings.ToLower(h), strings.ToLower(HeaderUpgrade)) { + return true + } + } + return false +} + +func negotiateProtocol(clientProtocols, serverProtocols []string) string { + for i := range clientProtocols { + for j := range serverProtocols { + if clientProtocols[i] == serverProtocols[j] { + return clientProtocols[i] + } + } + } + return "" +} + +func commaSeparatedHeaderValues(header []string) []string { + var parsedClientProtocols []string + for i := range header { + for _, clientProtocol := range strings.Split(header[i], ",") { + if proto := strings.Trim(clientProtocol, " "); len(proto) > 0 { + parsedClientProtocols = append(parsedClientProtocols, proto) + } + } + } + return parsedClientProtocols +} + +// Handshake performs a subprotocol negotiation. If the client did request a +// subprotocol, Handshake will select the first common value found in +// serverProtocols, otherwise it will return an error and write an HTTP BadRequest to the response. +// If a match is found, Handshake adds a response header indicating the chosen subprotocol. +// If no match is found, HTTP forbidden is returned, along with a response header containing +// the list of protocols the server can accept. +func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string) (string, error) { + if len(serverProtocols) == 0 { + panic(fmt.Errorf("unable to upgrade: serverProtocols is required")) + } + values, ok := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)] + if !ok { + err := fmt.Errorf("unable to upgrade: header %s does not exist in request with %d headers", HeaderProtocolVersion, len(req.Header)) + http.Error(w, err.Error(), http.StatusBadRequest) + return "", err + } + if len(values) == 0 { + err := fmt.Errorf("unable to upgrade: header %s is empty", HeaderProtocolVersion) + http.Error(w, err.Error(), http.StatusBadRequest) + return "", err + } + clientProtocols := commaSeparatedHeaderValues(values) + if len(clientProtocols) == 0 { + err := fmt.Errorf("unable to upgrade: header %s contains %s, but no valid protocols", HeaderProtocolVersion, values) + http.Error(w, err.Error(), http.StatusBadRequest) + return "", err + } + + negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols) + if len(negotiatedProtocol) == 0 { + for i := range serverProtocols { + w.Header().Add(HeaderAcceptedProtocolVersions, serverProtocols[i]) + } + err := fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols) + http.Error(w, err.Error(), http.StatusForbidden) + return "", err + } + + w.Header().Add(HeaderProtocolVersion, negotiatedProtocol) + return negotiatedProtocol, nil +} diff --git a/vendor/k8s.io/streaming/pkg/httpstream/spdy/connection.go b/vendor/k8s.io/streaming/pkg/httpstream/spdy/connection.go new file mode 100644 index 000000000..4a4003b62 --- /dev/null +++ b/vendor/k8s.io/streaming/pkg/httpstream/spdy/connection.go @@ -0,0 +1,206 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spdy + +import ( + "net" + "net/http" + "sync" + "time" + + "github.com/moby/spdystream" + "k8s.io/klog/v2" + "k8s.io/streaming/pkg/httpstream" +) + +// connection maintains state about a spdystream.Connection and its associated +// streams. +type connection struct { + conn *spdystream.Connection + streams map[uint32]httpstream.Stream + streamLock sync.Mutex + newStreamHandler httpstream.NewStreamHandler + ping func() (time.Duration, error) +} + +// NewClientConnection creates a new SPDY client connection. +func NewClientConnection(conn net.Conn) (httpstream.Connection, error) { + return NewClientConnectionWithPings(conn, 0) +} + +// NewClientConnectionWithPings creates a new SPDY client connection. +// +// If pingPeriod is non-zero, a background goroutine will send periodic Ping +// frames to the server. Use this to keep idle connections through certain load +// balancers alive longer. +func NewClientConnectionWithPings(conn net.Conn, pingPeriod time.Duration) (httpstream.Connection, error) { + spdyConn, err := spdystream.NewConnection(conn, false) + if err != nil { + defer conn.Close() + return nil, err + } + + return newConnection(spdyConn, httpstream.NoOpNewStreamHandler, pingPeriod, spdyConn.Ping), nil +} + +// NewServerConnection creates a new SPDY server connection. newStreamHandler +// will be invoked when the server receives a newly created stream from the +// client. +func NewServerConnection(conn net.Conn, newStreamHandler httpstream.NewStreamHandler) (httpstream.Connection, error) { + return NewServerConnectionWithPings(conn, newStreamHandler, 0) +} + +// NewServerConnectionWithPings creates a new SPDY server connection. +// newStreamHandler will be invoked when the server receives a newly created +// stream from the client. +// +// If pingPeriod is non-zero, a background goroutine will send periodic Ping +// frames to the server. Use this to keep idle connections through certain load +// balancers alive longer. +func NewServerConnectionWithPings(conn net.Conn, newStreamHandler httpstream.NewStreamHandler, pingPeriod time.Duration) (httpstream.Connection, error) { + spdyConn, err := spdystream.NewConnection(conn, true) + if err != nil { + defer conn.Close() + return nil, err + } + + return newConnection(spdyConn, newStreamHandler, pingPeriod, spdyConn.Ping), nil +} + +// newConnection returns a new connection wrapping conn. newStreamHandler +// will be invoked when the server receives a newly created stream from the +// client. +func newConnection(conn *spdystream.Connection, newStreamHandler httpstream.NewStreamHandler, pingPeriod time.Duration, pingFn func() (time.Duration, error)) httpstream.Connection { + c := &connection{ + conn: conn, + newStreamHandler: newStreamHandler, + ping: pingFn, + streams: make(map[uint32]httpstream.Stream), + } + go conn.Serve(c.newSpdyStream) + if pingPeriod > 0 && pingFn != nil { + go c.sendPings(pingPeriod) + } + return c +} + +// createStreamResponseTimeout indicates how long to wait for the other side to +// acknowledge the new stream before timing out. +const createStreamResponseTimeout = 30 * time.Second + +// Close first sends a reset for all of the connection's streams, and then +// closes the underlying spdystream.Connection. +func (c *connection) Close() error { + c.streamLock.Lock() + for _, s := range c.streams { + // calling Reset instead of Close ensures that all streams are fully torn down + s.Reset() + } + c.streams = make(map[uint32]httpstream.Stream, 0) + c.streamLock.Unlock() + + // now that all streams are fully torn down, it's safe to call close on the underlying connection, + // which should be able to terminate immediately at this point, instead of waiting for any + // remaining graceful stream termination. + return c.conn.Close() +} + +// RemoveStreams can be used to removes a set of streams from the Connection. +func (c *connection) RemoveStreams(streams ...httpstream.Stream) { + c.streamLock.Lock() + for _, stream := range streams { + // It may be possible that the provided stream is nil if timed out. + if stream != nil { + delete(c.streams, stream.Identifier()) + } + } + c.streamLock.Unlock() +} + +// CreateStream creates a new stream with the specified headers and registers +// it with the connection. +func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error) { + stream, err := c.conn.CreateStream(headers, nil, false) + if err != nil { + return nil, err + } + if err = stream.WaitTimeout(createStreamResponseTimeout); err != nil { + return nil, err + } + + c.registerStream(stream) + return stream, nil +} + +// registerStream adds the stream s to the connection's list of streams that +// it owns. +func (c *connection) registerStream(s httpstream.Stream) { + c.streamLock.Lock() + c.streams[s.Identifier()] = s + c.streamLock.Unlock() +} + +// CloseChan returns a channel that, when closed, indicates that the underlying +// spdystream.Connection has been closed. +func (c *connection) CloseChan() <-chan bool { + return c.conn.CloseChan() +} + +// newSpdyStream is the internal new stream handler used by spdystream.Connection.Serve. +// It calls connection's newStreamHandler, giving it the opportunity to accept or reject +// the stream. If newStreamHandler returns an error, the stream is rejected. If not, the +// stream is accepted and registered with the connection. +func (c *connection) newSpdyStream(stream *spdystream.Stream) { + replySent := make(chan struct{}) + err := c.newStreamHandler(stream, replySent) + rejectStream := (err != nil) + if rejectStream { + //nolint:logcheck // Hopefully this never gets triggered. + klog.Warningf("Stream rejected: %v", err) + stream.Reset() + return + } + + c.registerStream(stream) + stream.SendReply(http.Header{}, rejectStream) + close(replySent) +} + +// SetIdleTimeout sets the amount of time the connection may remain idle before +// it is automatically closed. +func (c *connection) SetIdleTimeout(timeout time.Duration) { + c.conn.SetIdleTimeout(timeout) +} + +func (c *connection) sendPings(period time.Duration) { + t := time.NewTicker(period) + defer t.Stop() + for { + select { + case <-c.conn.CloseChan(): + return + case <-t.C: + } + if _, err := c.ping(); err != nil { + //nolint:logcheck // Hopefully this never gets triggered. + klog.V(3).Infof("SPDY Ping failed: %v", err) + // Continue, in case this is a transient failure. + // c.conn.CloseChan above will tell us when the connection is + // actually closed. + } + } +} diff --git a/vendor/k8s.io/streaming/pkg/httpstream/spdy/roundtripper.go b/vendor/k8s.io/streaming/pkg/httpstream/spdy/roundtripper.go new file mode 100644 index 000000000..41cd65f64 --- /dev/null +++ b/vendor/k8s.io/streaming/pkg/httpstream/spdy/roundtripper.go @@ -0,0 +1,572 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spdy + +import ( + "bufio" + "context" + "crypto/tls" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "os" + "strings" + "time" + + "golang.org/x/net/proxy" + "k8s.io/streaming/pkg/httpstream" + utilnet "k8s.io/utils/net" +) + +// SpdyRoundTripper knows how to upgrade an HTTP request to one that supports +// multiplexed streams. After RoundTrip() is invoked, Conn will be set +// and usable. SpdyRoundTripper implements the UpgradeRoundTripper interface. +type SpdyRoundTripper struct { + //tlsConfig holds the TLS configuration settings to use when connecting + //to the remote server. + tlsConfig *tls.Config + + /* TODO according to http://golang.org/pkg/net/http/#RoundTripper, a RoundTripper + must be safe for use by multiple concurrent goroutines. If this is absolutely + necessary, we could keep a map from http.Request to net.Conn. In practice, + a client will create an http.Client, set the transport to a new insteace of + SpdyRoundTripper, and use it a single time, so this hopefully won't be an issue. + */ + // conn is the underlying network connection to the remote server. + conn net.Conn + + // Dialer is the dialer used to connect. Used if non-nil. + Dialer *net.Dialer + + // proxier knows which proxy to use given a request, defaults to a proxier that + // preserves NO_PROXY CIDR behavior while delegating to http.ProxyFromEnvironment. + // Used primarily for mocking the proxy discovery in tests. + proxier func(req *http.Request) (*url.URL, error) + + // pingPeriod is a period for sending Ping frames over established + // connections. + pingPeriod time.Duration + + // upgradeTransport is an optional substitute for dialing if present. This field is + // mutually exclusive with the "tlsConfig", "Dialer", and "proxier". + upgradeTransport http.RoundTripper +} + +type tlsClientConfigHolder interface { + TLSClientConfig() *tls.Config +} + +type roundTripperWrapper interface { + http.RoundTripper + WrappedRoundTripper() http.RoundTripper +} + +type dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + +var _ tlsClientConfigHolder = &SpdyRoundTripper{} +var _ httpstream.UpgradeRoundTripper = &SpdyRoundTripper{} + +// NewRoundTripper creates a new SpdyRoundTripper that will use the specified +// tlsConfig. +func NewRoundTripper(tlsConfig *tls.Config) (*SpdyRoundTripper, error) { + return NewRoundTripperWithConfig(RoundTripperConfig{ + TLS: tlsConfig, + UpgradeTransport: nil, + }) +} + +// NewRoundTripperWithProxy creates a new SpdyRoundTripper that will use the +// specified tlsConfig and proxy func. +func NewRoundTripperWithProxy(tlsConfig *tls.Config, proxier func(*http.Request) (*url.URL, error)) (*SpdyRoundTripper, error) { + return NewRoundTripperWithConfig(RoundTripperConfig{ + TLS: tlsConfig, + Proxier: proxier, + UpgradeTransport: nil, + }) +} + +// NewRoundTripperWithConfig creates a new SpdyRoundTripper with the specified +// configuration. Returns an error if the SpdyRoundTripper is misconfigured. +func NewRoundTripperWithConfig(cfg RoundTripperConfig) (*SpdyRoundTripper, error) { + // Process UpgradeTransport, which is mutually exclusive to TLSConfig and Proxier. + if cfg.UpgradeTransport != nil { + if cfg.TLS != nil || cfg.Proxier != nil { + return nil, fmt.Errorf("SpdyRoundTripper: UpgradeTransport is mutually exclusive to TLSConfig or Proxier") + } + tlsConfig, err := tlsConfigForTransport(cfg.UpgradeTransport) + if err != nil { + return nil, fmt.Errorf("SpdyRoundTripper: unable to retrieve TLS config from UpgradeTransport: %w", err) + } + cfg.TLS = tlsConfig + } + if cfg.Proxier == nil { + cfg.Proxier = newProxierWithNoProxyCIDR(http.ProxyFromEnvironment) + } + return &SpdyRoundTripper{ + tlsConfig: cfg.TLS, + proxier: cfg.Proxier, + pingPeriod: cfg.PingPeriod, + upgradeTransport: cfg.UpgradeTransport, + }, nil +} + +// newProxierWithNoProxyCIDR preserves CIDR matching in NO_PROXY/no_proxy while +// delegating all other behavior to the supplied proxy function. +func newProxierWithNoProxyCIDR(delegate func(req *http.Request) (*url.URL, error)) func(req *http.Request) (*url.URL, error) { + noProxyEnv := os.Getenv("NO_PROXY") + if noProxyEnv == "" { + noProxyEnv = os.Getenv("no_proxy") + } + noProxyRules := strings.Split(noProxyEnv, ",") + + cidrs := make([]*net.IPNet, 0, len(noProxyRules)) + for _, noProxyRule := range noProxyRules { + noProxyRule = strings.TrimSpace(noProxyRule) + if noProxyRule == "" { + continue + } + _, cidr, err := utilnet.ParseCIDRSloppy(noProxyRule) + if err == nil { + cidrs = append(cidrs, cidr) + } + } + + if len(cidrs) == 0 { + return delegate + } + + return func(req *http.Request) (*url.URL, error) { + ip := utilnet.ParseIPSloppy(req.URL.Hostname()) + if ip == nil { + return delegate(req) + } + + for _, cidr := range cidrs { + if cidr.Contains(ip) { + return nil, nil + } + } + + return delegate(req) + } +} + +// RoundTripperConfig is a set of options for an SpdyRoundTripper. +type RoundTripperConfig struct { + // TLS configuration used by the round tripper if UpgradeTransport not present. + TLS *tls.Config + // Proxier is a proxy function invoked on each request. Optional. + Proxier func(*http.Request) (*url.URL, error) + // PingPeriod is a period for sending SPDY Pings on the connection. + // Optional. + PingPeriod time.Duration + // UpgradeTransport is a subtitute transport used for dialing. If set, + // this field will be used instead of "TLS" and "Proxier" for connection creation. + // Optional. + UpgradeTransport http.RoundTripper +} + +// TLSClientConfig implements pkg/util/net.TLSClientConfigHolder for proper TLS checking during +// proxying with a spdy roundtripper. +func (s *SpdyRoundTripper) TLSClientConfig() *tls.Config { + return s.tlsConfig +} + +// Dial opens a network connection for an upgrade request. +func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) { + var conn net.Conn + var err error + if s.upgradeTransport != nil { + conn, err = dialURLWithTransport(req.Context(), req.URL, s.upgradeTransport) + } else { + conn, err = s.dial(req) + } + if err != nil { + return nil, err + } + + if err := req.Write(conn); err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} + +// dial dials the host specified by req, using TLS if appropriate, optionally +// using a proxy server if one is configured via environment variables. +func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) { + proxyURL, err := s.proxier(req) + if err != nil { + return nil, err + } + + if proxyURL == nil { + return s.dialWithoutProxy(req.Context(), req.URL) + } + + switch proxyURL.Scheme { + case "socks5": + return s.dialWithSocks5Proxy(req, proxyURL) + case "https", "http", "": + return s.dialWithHttpProxy(req, proxyURL) + } + + return nil, fmt.Errorf("proxy URL scheme not supported: %s", proxyURL.Scheme) +} + +// dialWithHttpProxy dials the host specified by url through an http or an https proxy. +func (s *SpdyRoundTripper) dialWithHttpProxy(req *http.Request, proxyURL *url.URL) (net.Conn, error) { + // ensure we use a canonical host with proxyReq + targetHost := canonicalAddr(req.URL) + + // proxying logic adapted from http://blog.h6t.eu/post/74098062923/golang-websocket-with-http-proxy-support + proxyReq := http.Request{ + Method: http.MethodConnect, + URL: &url.URL{}, + Host: targetHost, + } + + proxyReq = *proxyReq.WithContext(req.Context()) + + if pa := s.proxyAuth(proxyURL); pa != "" { + proxyReq.Header = http.Header{} + proxyReq.Header.Set("Proxy-Authorization", pa) + } + + proxyDialConn, err := s.dialWithoutProxy(proxyReq.Context(), proxyURL) + if err != nil { + return nil, err + } + + //nolint:staticcheck // SA1019 ignore deprecated httputil.NewProxyClientConn + proxyClientConn := httputil.NewProxyClientConn(proxyDialConn, nil) + response, err := proxyClientConn.Do(&proxyReq) + //nolint:staticcheck // SA1019 ignore deprecated httputil.ErrPersistEOF: it might be + // returned from the invocation of proxyClientConn.Do + if err != nil && err != httputil.ErrPersistEOF { + return nil, err + } + if response != nil && response.StatusCode >= 300 || response.StatusCode < 200 { + return nil, fmt.Errorf("CONNECT request to %s returned response: %s", proxyURL.Redacted(), response.Status) + } + + rwc, _ := proxyClientConn.Hijack() + + if req.URL.Scheme == "https" { + return s.tlsConn(proxyReq.Context(), rwc, targetHost) + } + return rwc, nil +} + +// dialWithSocks5Proxy dials the host specified by url through a socks5 proxy. +func (s *SpdyRoundTripper) dialWithSocks5Proxy(req *http.Request, proxyURL *url.URL) (net.Conn, error) { + // ensure we use a canonical host with proxyReq + targetHost := canonicalAddr(req.URL) + proxyDialAddr := canonicalAddr(proxyURL) + + var auth *proxy.Auth + if proxyURL.User != nil { + pass, _ := proxyURL.User.Password() + auth = &proxy.Auth{ + User: proxyURL.User.Username(), + Password: pass, + } + } + + dialer := s.Dialer + if dialer == nil { + dialer = &net.Dialer{ + Timeout: 30 * time.Second, + } + } + + proxyDialer, err := proxy.SOCKS5("tcp", proxyDialAddr, auth, dialer) + if err != nil { + return nil, err + } + + // According to the implementation of proxy.SOCKS5, the type assertion will always succeed + contextDialer, ok := proxyDialer.(proxy.ContextDialer) + if !ok { + return nil, errors.New("SOCKS5 Dialer must implement ContextDialer") + } + + proxyDialConn, err := contextDialer.DialContext(req.Context(), "tcp", targetHost) + if err != nil { + return nil, err + } + + if req.URL.Scheme == "https" { + return s.tlsConn(req.Context(), proxyDialConn, targetHost) + } + return proxyDialConn, nil +} + +// tlsConn returns a TLS client side connection using rwc as the underlying transport. +func (s *SpdyRoundTripper) tlsConn(ctx context.Context, rwc net.Conn, targetHost string) (net.Conn, error) { + + host, _, err := net.SplitHostPort(targetHost) + if err != nil { + return nil, err + } + + tlsConfig := s.tlsConfig + switch { + case tlsConfig == nil: + tlsConfig = &tls.Config{ServerName: host} + case len(tlsConfig.ServerName) == 0: + tlsConfig = tlsConfig.Clone() + tlsConfig.ServerName = host + } + + tlsConn := tls.Client(rwc, tlsConfig) + + if err := tlsConn.HandshakeContext(ctx); err != nil { + tlsConn.Close() + return nil, err + } + + return tlsConn, nil +} + +// dialWithoutProxy dials the host specified by url, using TLS if appropriate. +func (s *SpdyRoundTripper) dialWithoutProxy(ctx context.Context, url *url.URL) (net.Conn, error) { + dialAddr := canonicalAddr(url) + dialer := s.Dialer + if dialer == nil { + dialer = &net.Dialer{} + } + + if url.Scheme == "http" { + return dialer.DialContext(ctx, "tcp", dialAddr) + } + + tlsDialer := tls.Dialer{ + NetDialer: dialer, + Config: s.tlsConfig, + } + return tlsDialer.DialContext(ctx, "tcp", dialAddr) +} + +// proxyAuth returns, for a given proxy URL, the value to be used for the Proxy-Authorization header +func (s *SpdyRoundTripper) proxyAuth(proxyURL *url.URL) string { + if proxyURL == nil || proxyURL.User == nil { + return "" + } + username := proxyURL.User.Username() + password, _ := proxyURL.User.Password() + auth := username + ":" + password + return "Basic " + base64.StdEncoding.EncodeToString([]byte(auth)) +} + +// RoundTrip executes the Request and upgrades it. After a successful upgrade, +// clients may call SpdyRoundTripper.Connection() to retrieve the upgraded +// connection. +func (s *SpdyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header = req.Header.Clone() + req.Header.Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade) + req.Header.Add(httpstream.HeaderUpgrade, HeaderSpdy31) + + conn, err := s.Dial(req) + if err != nil { + return nil, err + } + + responseReader := bufio.NewReader(conn) + + resp, err := http.ReadResponse(responseReader, nil) + if err != nil { + conn.Close() + return nil, err + } + + s.conn = conn + + return resp, nil +} + +// NewConnection validates the upgrade response, creating and returning a new +// httpstream.Connection if there were no errors. +func (s *SpdyRoundTripper) NewConnection(resp *http.Response) (httpstream.Connection, error) { + connectionHeader := strings.ToLower(resp.Header.Get(httpstream.HeaderConnection)) + upgradeHeader := strings.ToLower(resp.Header.Get(httpstream.HeaderUpgrade)) + if (resp.StatusCode != http.StatusSwitchingProtocols) || !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) { + defer resp.Body.Close() + responseErrorBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("unable to upgrade connection: unable to read error from server response") + } + return nil, fmt.Errorf("unable to upgrade connection: %s", upgradeErrorMessage(responseErrorBytes)) + } + + return NewClientConnectionWithPings(s.conn, s.pingPeriod) +} + +func tlsConfigForTransport(transport http.RoundTripper) (*tls.Config, error) { + if transport == nil { + return nil, nil + } + switch transport := transport.(type) { + case *http.Transport: + return transport.TLSClientConfig, nil + case tlsClientConfigHolder: + return transport.TLSClientConfig(), nil + case roundTripperWrapper: + return tlsConfigForTransport(transport.WrappedRoundTripper()) + default: + return nil, fmt.Errorf("transport %T does not expose TLS client config", transport) + } +} + +func canonicalAddr(url *url.URL) string { + host := url.Hostname() + port := url.Port() + if len(port) == 0 { + switch strings.ToLower(url.Scheme) { + case "http", "ws": + port = "80" + case "https", "wss": + port = "443" + } + } + return net.JoinHostPort(host, port) +} + +func upgradeErrorMessage(responseErrorBytes []byte) string { + type statusLike struct { + Message string `json:"message"` + Error string `json:"error"` + } + + var status statusLike + if err := json.Unmarshal(responseErrorBytes, &status); err == nil { + if msg := strings.TrimSpace(status.Message); msg != "" { + return msg + } + if msg := strings.TrimSpace(status.Error); msg != "" { + return msg + } + } + + msg := strings.TrimSpace(string(responseErrorBytes)) + if msg == "" { + return "empty server response" + } + return msg +} + +func dialURLWithTransport(ctx context.Context, url *url.URL, transport http.RoundTripper) (net.Conn, error) { + dialAddr := canonicalAddr(url) + + dialer, err := dialerFor(transport) + if err != nil { + dialer = nil + } + + switch url.Scheme { + case "http": + if dialer != nil { + return dialer(ctx, "tcp", dialAddr) + } + var d net.Dialer + return d.DialContext(ctx, "tcp", dialAddr) + case "https": + tlsConfig, err := tlsConfigForTransport(transport) + if err != nil { + tlsConfig = nil + } + + if dialer != nil { + netConn, err := dialer(ctx, "tcp", dialAddr) + if err != nil { + return nil, err + } + + if tlsConfig == nil { + tlsConfig = &tls.Config{InsecureSkipVerify: true} + } else if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { + inferredHost := dialAddr + if host, _, err := net.SplitHostPort(dialAddr); err == nil { + inferredHost = host + } + tlsConfigCopy := tlsConfig.Clone() + tlsConfigCopy.ServerName = inferredHost + tlsConfig = tlsConfigCopy + } + + if supportsHTTP11(tlsConfig.NextProtos) { + tlsConfig = tlsConfig.Clone() + tlsConfig.NextProtos = []string{"http/1.1"} + } + + tlsConn := tls.Client(netConn, tlsConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { + netConn.Close() + return nil, err + } + return tlsConn, nil + } + + tlsDialer := tls.Dialer{Config: tlsConfig} + return tlsDialer.DialContext(ctx, "tcp", dialAddr) + default: + return nil, fmt.Errorf("unknown scheme: %s", url.Scheme) + } +} + +func dialerFor(transport http.RoundTripper) (dialFunc, error) { + if transport == nil { + return nil, nil + } + + switch transport := transport.(type) { + case *http.Transport: + if transport.DialContext != nil { + return transport.DialContext, nil + } + if transport.Dial != nil { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + return transport.Dial(network, addr) + }, nil + } + return nil, nil + case roundTripperWrapper: + return dialerFor(transport.WrappedRoundTripper()) + default: + return nil, fmt.Errorf("unknown transport type: %T", transport) + } +} + +func supportsHTTP11(nextProtos []string) bool { + if len(nextProtos) == 0 { + return true + } + for _, proto := range nextProtos { + if proto == "http/1.1" { + return true + } + } + return false +} diff --git a/vendor/k8s.io/streaming/pkg/httpstream/spdy/upgrade.go b/vendor/k8s.io/streaming/pkg/httpstream/spdy/upgrade.go new file mode 100644 index 000000000..df47029f4 --- /dev/null +++ b/vendor/k8s.io/streaming/pkg/httpstream/spdy/upgrade.go @@ -0,0 +1,120 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spdy + +import ( + "bufio" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync/atomic" + "time" + + "k8s.io/streaming/pkg/httpstream" + "k8s.io/streaming/pkg/runtime" +) + +const HeaderSpdy31 = "SPDY/3.1" + +// responseUpgrader knows how to upgrade HTTP responses. It +// implements the httpstream.ResponseUpgrader interface. +type responseUpgrader struct { + pingPeriod time.Duration +} + +// connWrapper is used to wrap a hijacked connection and its bufio.Reader. All +// calls will be handled directly by the underlying net.Conn with the exception +// of Read and Close calls, which will consider data in the bufio.Reader. This +// ensures that data already inside the used bufio.Reader instance is also +// read. +type connWrapper struct { + net.Conn + closed int32 + bufReader *bufio.Reader +} + +func (w *connWrapper) Read(b []byte) (n int, err error) { + if atomic.LoadInt32(&w.closed) == 1 { + return 0, io.EOF + } + return w.bufReader.Read(b) +} + +func (w *connWrapper) Close() error { + err := w.Conn.Close() + atomic.StoreInt32(&w.closed, 1) + return err +} + +// NewResponseUpgrader returns a new httpstream.ResponseUpgrader that is +// capable of upgrading HTTP responses using SPDY/3.1 via the +// spdystream package. +func NewResponseUpgrader() httpstream.ResponseUpgrader { + return NewResponseUpgraderWithPings(0) +} + +// NewResponseUpgraderWithPings returns a new httpstream.ResponseUpgrader that +// is capable of upgrading HTTP responses using SPDY/3.1 via the spdystream +// package. +// +// If pingPeriod is non-zero, for each incoming connection a background +// goroutine will send periodic Ping frames to the server. Use this to keep +// idle connections through certain load balancers alive longer. +func NewResponseUpgraderWithPings(pingPeriod time.Duration) httpstream.ResponseUpgrader { + return responseUpgrader{pingPeriod: pingPeriod} +} + +// UpgradeResponse upgrades an HTTP response to one that supports multiplexed +// streams. newStreamHandler will be called synchronously whenever the +// other end of the upgraded connection creates a new stream. +func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection { + connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection)) + upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade)) + if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) { + errorMsg := fmt.Sprintf("unable to upgrade: missing upgrade headers in request: %#v", req.Header) + http.Error(w, errorMsg, http.StatusBadRequest) + return nil + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + errorMsg := "unable to upgrade: unable to hijack response" + http.Error(w, errorMsg, http.StatusInternalServerError) + return nil + } + + w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade) + w.Header().Add(httpstream.HeaderUpgrade, HeaderSpdy31) + w.WriteHeader(http.StatusSwitchingProtocols) + + conn, bufrw, err := hijacker.Hijack() + if err != nil { + runtime.HandleErrorWithContext(req.Context(), err, "Unable to upgrade: error hijacking response") + return nil + } + + connWithBuf := &connWrapper{Conn: conn, bufReader: bufrw.Reader} + spdyConn, err := NewServerConnectionWithPings(connWithBuf, newStreamHandler, u.pingPeriod) + if err != nil { + runtime.HandleErrorWithContext(req.Context(), err, "Unable to upgrade: error creating SPDY server connection") + return nil + } + + return spdyConn +} diff --git a/vendor/k8s.io/streaming/pkg/httpstream/wsstream/conn.go b/vendor/k8s.io/streaming/pkg/httpstream/wsstream/conn.go new file mode 100644 index 000000000..eae80ab4d --- /dev/null +++ b/vendor/k8s.io/streaming/pkg/httpstream/wsstream/conn.go @@ -0,0 +1,466 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wsstream + +import ( + "encoding/base64" + "fmt" + "io" + "net/http" + "strings" + "time" + + "golang.org/x/net/websocket" + + "k8s.io/klog/v2" + "k8s.io/streaming/pkg/httpstream" + "k8s.io/streaming/pkg/runtime" +) + +const WebSocketProtocolHeader = "Sec-Websocket-Protocol" + +// The Websocket subprotocol "channel.k8s.io" prepends each binary message with a byte indicating +// the channel number (zero indexed) the message was sent on. Messages in both directions should +// prefix their messages with this channel byte. When used for remote execution, the channel numbers +// are by convention defined to match the POSIX file-descriptors assigned to STDIN, STDOUT, and STDERR +// (0, 1, and 2). No other conversion is performed on the raw subprotocol - writes are sent as they +// are received by the server. +// +// Example client session: +// +// CONNECT http://server.com with subprotocol "channel.k8s.io" +// WRITE []byte{0, 102, 111, 111, 10} # send "foo\n" on channel 0 (STDIN) +// READ []byte{1, 10} # receive "\n" on channel 1 (STDOUT) +// CLOSE +const ChannelWebSocketProtocol = "channel.k8s.io" + +// The Websocket subprotocol "base64.channel.k8s.io" base64 encodes each message with a character +// indicating the channel number (zero indexed) the message was sent on. Messages in both directions +// should prefix their messages with this channel char. When used for remote execution, the channel +// numbers are by convention defined to match the POSIX file-descriptors assigned to STDIN, STDOUT, +// and STDERR ('0', '1', and '2'). The data received on the server is base64 decoded (and must be +// be valid) and data written by the server to the client is base64 encoded. +// +// Example client session: +// +// CONNECT http://server.com with subprotocol "base64.channel.k8s.io" +// WRITE []byte{48, 90, 109, 57, 118, 67, 103, 111, 61} # send "foo\n" (base64: "Zm9vCgo=") on channel '0' (STDIN) +// READ []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT) +// CLOSE +const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io" + +const streamCloseSignal = 255 + +type codecType int + +const ( + rawCodec codecType = iota + base64Codec +) + +type ChannelType int + +const ( + IgnoreChannel ChannelType = iota + ReadChannel + WriteChannel + ReadWriteChannel +) + +// IsWebSocketRequest returns true if the incoming request contains connection upgrade headers +// for WebSockets. +func IsWebSocketRequest(req *http.Request) bool { + if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") { + return false + } + return httpstream.IsUpgradeRequest(req) +} + +// IsWebSocketRequestWithStreamCloseProtocol returns true if the request contains headers +// identifying that it is requesting a websocket upgrade with a remotecommand protocol +// version that supports the "CLOSE" signal; false otherwise. +func IsWebSocketRequestWithStreamCloseProtocol(req *http.Request) bool { + if !IsWebSocketRequest(req) { + return false + } + requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader)) + for _, requestedProtocol := range strings.Split(requestedProtocols, ",") { + if protocolSupportsStreamClose(strings.TrimSpace(requestedProtocol)) { + return true + } + } + + return false +} + +// IsWebSocketRequestWithTunnelingProtocol returns true if the request contains headers +// identifying that it is requesting a websocket upgrade with a tunneling protocol; +// false otherwise. +func IsWebSocketRequestWithTunnelingProtocol(req *http.Request) bool { + if !IsWebSocketRequest(req) { + return false + } + requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader)) + for _, requestedProtocol := range strings.Split(requestedProtocols, ",") { + if protocolSupportsWebsocketTunneling(strings.TrimSpace(requestedProtocol)) { + return true + } + } + + return false +} + +// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the +// read and write deadlines are pushed every time a new message is received. +// +// Contextual logging: IgnoreReceivesWithLogger should be used instead of IgnoreReceives in code which uses contextual logging. +func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) { + IgnoreReceivesWithLogger(klog.Background(), ws, timeout) +} + +// IgnoreReceivesWithLogger reads from a WebSocket until it is closed, then returns. If timeout is set, the +// read and write deadlines are pushed every time a new message is received. +func IgnoreReceivesWithLogger(logger klog.Logger, ws *websocket.Conn, timeout time.Duration) { + defer runtime.HandleCrashWithLogger(logger) + var data []byte + for { + resetTimeout(ws, timeout) + if err := websocket.Message.Receive(ws, &data); err != nil { + return + } + } +} + +// handshake ensures the provided user protocol matches one of the allowed protocols. It returns +// no error if no protocol is specified. +func handshake(config *websocket.Config, req *http.Request, allowed []string) error { + protocols := config.Protocol + if len(protocols) == 0 { + protocols = []string{""} + } + + for _, protocol := range protocols { + for _, allow := range allowed { + if allow == protocol { + config.Protocol = []string{protocol} + return nil + } + } + } + + return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed) +} + +// ChannelProtocolConfig describes a websocket subprotocol with channels. +type ChannelProtocolConfig struct { + Binary bool + Channels []ChannelType +} + +// NewDefaultChannelProtocols returns a channel protocol map with the +// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given +// channels. +func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig { + return map[string]ChannelProtocolConfig{ + "": {Binary: true, Channels: channels}, + ChannelWebSocketProtocol: {Binary: true, Channels: channels}, + Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels}, + } +} + +// Conn supports sending multiple binary channels over a websocket connection. +type Conn struct { + protocols map[string]ChannelProtocolConfig + selectedProtocol string + channels []*websocketChannel + codec codecType + ready chan struct{} + ws *websocket.Conn + timeout time.Duration +} + +// NewConn creates a WebSocket connection that supports a set of channels. Channels begin each +// web socket message with a single byte indicating the channel number (0-N). 255 is reserved for +// future use. The channel types for each channel are passed as an array, supporting the different +// duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer. +// +// The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol +// name is used if websocket.Config.Protocol is empty. +func NewConn(protocols map[string]ChannelProtocolConfig) *Conn { + return &Conn{ + ready: make(chan struct{}), + protocols: protocols, + } +} + +// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified, +// there is no timeout on the connection. +func (conn *Conn) SetIdleTimeout(duration time.Duration) { + conn.timeout = duration +} + +// SetWriteDeadline sets a timeout on writing to the websocket connection. The +// passed "duration" identifies how far into the future the write must complete +// by before the timeout fires. +func (conn *Conn) SetWriteDeadline(duration time.Duration) { + conn.ws.SetWriteDeadline(time.Now().Add(duration)) //nolint:errcheck +} + +// Open the connection and create channels for reading and writing. It returns +// the selected subprotocol, a slice of channels and an error. +func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) { + // serveHTTPComplete is channel that is closed/selected when "websocket#ServeHTTP" finishes. + serveHTTPComplete := make(chan struct{}) + // Ensure panic in spawned goroutine is propagated into the parent goroutine. + panicChan := make(chan any, 1) + go func() { + // If websocket server returns, propagate panic if necessary. Otherwise, + // signal HTTPServe finished by closing "serveHTTPComplete". + defer func() { + if p := recover(); p != nil { + panicChan <- p + } else { + close(serveHTTPComplete) + } + }() + websocket.Server{Handshake: conn.handshake, Handler: conn.handle}.ServeHTTP(w, req) + }() + + // In normal circumstances, "websocket.Server#ServeHTTP" calls "initialize" which closes + // "conn.ready" and then blocks until serving is complete. + select { + case <-conn.ready: + klog.FromContext(req.Context()).V(8).Info("websocket server initialized--serving") + case <-serveHTTPComplete: + // websocket server returned before completing initialization; cleanup and return error. + conn.closeNonThreadSafe() //nolint:errcheck + return "", nil, fmt.Errorf("websocket server finished before becoming ready") + case p := <-panicChan: + panic(p) + } + + rwc := make([]io.ReadWriteCloser, len(conn.channels)) + for i := range conn.channels { + rwc[i] = conn.channels[i] + } + return conn.selectedProtocol, rwc, nil +} + +func (conn *Conn) initialize(ws *websocket.Conn) { + negotiated := ws.Config().Protocol + conn.selectedProtocol = negotiated[0] + p := conn.protocols[conn.selectedProtocol] + if p.Binary { + conn.codec = rawCodec + } else { + conn.codec = base64Codec + } + conn.ws = ws + conn.channels = make([]*websocketChannel, len(p.Channels)) + for i, t := range p.Channels { + switch t { + case ReadChannel: + conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false) + case WriteChannel: + conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true) + case ReadWriteChannel: + conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true) + case IgnoreChannel: + conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false) + } + } + + close(conn.ready) +} + +func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error { + supportedProtocols := make([]string, 0, len(conn.protocols)) + for p := range conn.protocols { + supportedProtocols = append(supportedProtocols, p) + } + return handshake(config, req, supportedProtocols) +} + +func (conn *Conn) resetTimeout() { + if conn.timeout > 0 { + conn.ws.SetDeadline(time.Now().Add(conn.timeout)) + } +} + +// closeNonThreadSafe cleans up by closing streams and the websocket +// connection *without* waiting for the "ready" channel. +func (conn *Conn) closeNonThreadSafe() error { + for _, s := range conn.channels { + s.Close() + } + var err error + if conn.ws != nil { + err = conn.ws.Close() + } + return err +} + +// Close is only valid after Open has been called +func (conn *Conn) Close() error { + <-conn.ready + return conn.closeNonThreadSafe() +} + +// protocolSupportsStreamClose returns true if the passed protocol +// supports the stream close signal (currently only V5 remotecommand); +// false otherwise. +func protocolSupportsStreamClose(protocol string) bool { + return protocol == "v5.channel.k8s.io" +} + +// protocolSupportsWebsocketTunneling returns true if the passed protocol +// is a tunneled Kubernetes spdy protocol; false otherwise. +func protocolSupportsWebsocketTunneling(protocol string) bool { + return strings.HasPrefix(protocol, "SPDY/3.1+") && strings.HasSuffix(protocol, ".k8s.io") +} + +// handle implements a websocket handler. +func (conn *Conn) handle(ws *websocket.Conn) { + conn.initialize(ws) + defer conn.Close() + supportsStreamClose := protocolSupportsStreamClose(conn.selectedProtocol) + // conn.handle is typically used on the server-side and thus we have a request, + // but don't assume that and use klog.Background as fallback. + logger := klog.Background() + if req := ws.Request(); req != nil { + logger = klog.FromContext(req.Context()) + } + + for { + conn.resetTimeout() + var data []byte + if err := websocket.Message.Receive(ws, &data); err != nil { + if err != io.EOF { + logger.Error(err, "Error on socket receive") + } + break + } + if len(data) == 0 { + continue + } + if supportsStreamClose && data[0] == streamCloseSignal { + if len(data) != 2 { + logger.Error(nil, "Single channel byte should follow stream close signal", "receivedLength", len(data)-1) + break + } else { + channel := data[1] + if int(channel) >= len(conn.channels) { + logger.Error(nil, "Close is targeted for a channel that is not valid, possible protocol error", "channel", channel) + break + } + logger.V(4).Info("Received half-close signal from client, close stream", "channel", channel) + conn.channels[channel].Close() // After first Close, other closes are noop. + } + continue + } + channel := data[0] + if conn.codec == base64Codec { + channel = channel - '0' + } + data = data[1:] + if int(channel) >= len(conn.channels) { + logger.V(6).Info("Frame is targeted for a reader that is not valid, possible protocol error", "channel", channel) + continue + } + if _, err := conn.channels[channel].DataFromSocket(data); err != nil { + logger.Error(err, "Unable to write frame", "sendLength", len(data), "channel", channel, "err", err) + continue + } + } +} + +// write multiplexes the specified channel onto the websocket +func (conn *Conn) write(num byte, data []byte) (int, error) { + conn.resetTimeout() + switch conn.codec { + case rawCodec: + frame := make([]byte, len(data)+1) + frame[0] = num + copy(frame[1:], data) + if err := websocket.Message.Send(conn.ws, frame); err != nil { + return 0, err + } + case base64Codec: + frame := string('0'+num) + base64.StdEncoding.EncodeToString(data) + if err := websocket.Message.Send(conn.ws, frame); err != nil { + return 0, err + } + } + return len(data), nil +} + +// websocketChannel represents a channel in a connection +type websocketChannel struct { + conn *Conn + num byte + r io.Reader + w io.WriteCloser + + read, write bool +} + +// newWebsocketChannel creates a pipe for writing to a websocket. Do not write to this pipe +// prior to the connection being opened. It may be no, half, or full duplex depending on +// read and write. +func newWebsocketChannel(conn *Conn, num byte, read, write bool) *websocketChannel { + r, w := io.Pipe() + return &websocketChannel{conn, num, r, w, read, write} +} + +func (p *websocketChannel) Write(data []byte) (int, error) { + if !p.write { + return len(data), nil + } + return p.conn.write(p.num, data) +} + +// DataFromSocket is invoked by the connection receiver to move data from the connection +// into a specific channel. +func (p *websocketChannel) DataFromSocket(data []byte) (int, error) { + if !p.read { + return len(data), nil + } + + switch p.conn.codec { + case rawCodec: + return p.w.Write(data) + case base64Codec: + dst := make([]byte, len(data)) + n, err := base64.StdEncoding.Decode(dst, data) + if err != nil { + return 0, err + } + return p.w.Write(dst[:n]) + } + return 0, nil +} + +func (p *websocketChannel) Read(data []byte) (int, error) { + if !p.read { + return 0, io.EOF + } + return p.r.Read(data) +} + +func (p *websocketChannel) Close() error { + return p.w.Close() +} diff --git a/vendor/k8s.io/streaming/pkg/httpstream/wsstream/doc.go b/vendor/k8s.io/streaming/pkg/httpstream/wsstream/doc.go new file mode 100644 index 000000000..a57e8df60 --- /dev/null +++ b/vendor/k8s.io/streaming/pkg/httpstream/wsstream/doc.go @@ -0,0 +1,69 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package wsstream contains utilities for streaming content over WebSockets. +// The Conn type allows callers to multiplex multiple read/write channels over +// a single websocket. +// +// "channel.k8s.io" +// +// The Websocket RemoteCommand subprotocol "channel.k8s.io" prepends each binary message with a +// byte indicating the channel number (zero indexed) the message was sent on. Messages in both +// directions should prefix their messages with this channel byte. Used for remote execution, +// the channel numbers are by convention defined to match the POSIX file-descriptors assigned +// to STDIN, STDOUT, and STDERR (0, 1, and 2). No other conversion is performed on the raw +// subprotocol - writes are sent as they are received by the server. +// +// Example client session: +// +// CONNECT http://server.com with subprotocol "channel.k8s.io" +// WRITE []byte{0, 102, 111, 111, 10} # send "foo\n" on channel 0 (STDIN) +// READ []byte{1, 10} # receive "\n" on channel 1 (STDOUT) +// CLOSE +// +// "v2.channel.k8s.io" +// +// The second Websocket subprotocol version "v2.channel.k8s.io" is the same as version 1, +// but it is the first "versioned" subprotocol. +// +// "v3.channel.k8s.io" +// +// The third version of the Websocket RemoteCommand subprotocol adds another channel +// for terminal resizing events. This channel is prepended with the byte '3', and it +// transmits two window sizes (encoding TerminalSize struct) with integers in the range +// (0,65536]. +// +// "v4.channel.k8s.io" +// +// The fourth version of the Websocket RemoteCommand subprotocol adds a channel for +// errors. This channel returns structured errors containing process exit codes. The +// error is "apierrors.StatusError{}". +// +// "v5.channel.k8s.io" +// +// The fifth version of the Websocket RemoteCommand subprotocol adds a CLOSE signal, +// which is sent as the first byte of the message. The second byte is the channel +// id. This CLOSE signal is handled by the websocket server by closing the stream, +// allowing the other streams to complete transmission if necessary, and gracefully +// shutdown the connection. +// +// Example client session: +// +// CONNECT http://server.com with subprotocol "v5.channel.k8s.io" +// WRITE []byte{0, 102, 111, 111, 10} # send "foo\n" on channel 0 (STDIN) +// WRITE []byte{255, 0} # send CLOSE signal (STDIN) +// CLOSE +package wsstream diff --git a/vendor/k8s.io/streaming/pkg/httpstream/wsstream/stream.go b/vendor/k8s.io/streaming/pkg/httpstream/wsstream/stream.go new file mode 100644 index 000000000..38cc41a23 --- /dev/null +++ b/vendor/k8s.io/streaming/pkg/httpstream/wsstream/stream.go @@ -0,0 +1,193 @@ +/* +Copyright 2015 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wsstream + +import ( + "context" + "encoding/base64" + "io" + "net/http" + "sync" + "time" + + "golang.org/x/net/websocket" + + "k8s.io/klog/v2" + "k8s.io/streaming/pkg/runtime" +) + +// The WebSocket subprotocol "binary.k8s.io" will only send messages to the +// client and ignore messages sent to the server. The received messages are +// the exact bytes written to the stream. Zero byte messages are possible. +const binaryWebSocketProtocol = "binary.k8s.io" + +// The WebSocket subprotocol "base64.binary.k8s.io" will only send messages to the +// client and ignore messages sent to the server. The received messages are +// a base64 version of the bytes written to the stream. Zero byte messages are +// possible. +const base64BinaryWebSocketProtocol = "base64.binary.k8s.io" + +// ReaderProtocolConfig describes a websocket subprotocol with one stream. +type ReaderProtocolConfig struct { + Binary bool +} + +// NewDefaultReaderProtocols returns a stream protocol map with the +// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io". +func NewDefaultReaderProtocols() map[string]ReaderProtocolConfig { + return map[string]ReaderProtocolConfig{ + "": {Binary: true}, + binaryWebSocketProtocol: {Binary: true}, + base64BinaryWebSocketProtocol: {Binary: false}, + } +} + +// Reader supports returning an arbitrary byte stream over a websocket channel. +type Reader struct { + logger klog.Logger + err chan error + r io.Reader + ping bool + timeout time.Duration + protocols map[string]ReaderProtocolConfig + selectedProtocol string + + handleCrash func(ctx context.Context, additionalHandlers ...func(context.Context, interface{})) // overridable for testing +} + +// NewReader creates a WebSocket pipe that will copy the contents of r to a provided +// WebSocket connection. If ping is true, a zero length message will be sent to the client +// before the stream begins reading. +// +// The protocols parameter maps subprotocol names to StreamProtocols. The empty string +// subprotocol name is used if websocket.Config.Protocol is empty. +// +//logcheck:context // NewReaderWithLogger should be used instead of NewReader in code which supports contextual logging. +func NewReader(r io.Reader, ping bool, protocols map[string]ReaderProtocolConfig) *Reader { + return NewReaderWithLogger(klog.Background(), r, ping, protocols) +} + +// NewReaderWithLogger creates a WebSocket pipe that will copy the contents of r to a provided +// WebSocket connection. If ping is true, a zero length message will be sent to the client +// before the stream begins reading. +// +// The protocols parameter maps subprotocol names to StreamProtocols. The empty string +// subprotocol name is used if websocket.Config.Protocol is empty. +func NewReaderWithLogger(logger klog.Logger, r io.Reader, ping bool, protocols map[string]ReaderProtocolConfig) *Reader { + return &Reader{ + logger: logger, + r: r, + err: make(chan error), + ping: ping, + protocols: protocols, + handleCrash: runtime.HandleCrashWithContext, + } +} + +// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified, +// there is no timeout on the reader. +func (r *Reader) SetIdleTimeout(duration time.Duration) { + r.timeout = duration +} + +func (r *Reader) handshake(config *websocket.Config, req *http.Request) error { + supportedProtocols := make([]string, 0, len(r.protocols)) + for p := range r.protocols { + supportedProtocols = append(supportedProtocols, p) + } + return handshake(config, req, supportedProtocols) +} + +// Copy the reader to the response. The created WebSocket is closed after this +// method completes. +func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error { + go func() { + defer r.handleCrash(req.Context()) + websocket.Server{Handshake: r.handshake, Handler: r.handle}.ServeHTTP(w, req) + }() + return <-r.err +} + +// handle implements a WebSocket handler. +func (r *Reader) handle(ws *websocket.Conn) { + // Close the connection when the client requests it, or when we finish streaming, whichever happens first + closeConnOnce := &sync.Once{} + closeConn := func() { + closeConnOnce.Do(func() { + ws.Close() + }) + } + + negotiated := ws.Config().Protocol + r.selectedProtocol = negotiated[0] + defer close(r.err) + defer closeConn() + + go func() { + defer runtime.HandleCrashWithLogger(r.logger) + // This blocks until the connection is closed. + // Client should not send anything. + IgnoreReceivesWithLogger(r.logger, ws, r.timeout) + // Once the client closes, we should also close + closeConn() + }() + + r.err <- messageCopy(ws, r.r, !r.protocols[r.selectedProtocol].Binary, r.ping, r.timeout) +} + +func resetTimeout(ws *websocket.Conn, timeout time.Duration) { + if timeout > 0 { + ws.SetDeadline(time.Now().Add(timeout)) + } +} + +func messageCopy(ws *websocket.Conn, r io.Reader, base64Encode, ping bool, timeout time.Duration) error { + buf := make([]byte, 2048) + if ping { + resetTimeout(ws, timeout) + if base64Encode { + if err := websocket.Message.Send(ws, ""); err != nil { + return err + } + } else { + if err := websocket.Message.Send(ws, []byte{}); err != nil { + return err + } + } + } + for { + resetTimeout(ws, timeout) + n, err := r.Read(buf) + if err != nil { + if err == io.EOF { + return nil + } + return err + } + if n > 0 { + if base64Encode { + if err := websocket.Message.Send(ws, base64.StdEncoding.EncodeToString(buf[:n])); err != nil { + return err + } + } else { + if err := websocket.Message.Send(ws, buf[:n]); err != nil { + return err + } + } + } + } +} diff --git a/vendor/k8s.io/streaming/pkg/runtime/runtime.go b/vendor/k8s.io/streaming/pkg/runtime/runtime.go new file mode 100644 index 000000000..25a70b735 --- /dev/null +++ b/vendor/k8s.io/streaming/pkg/runtime/runtime.go @@ -0,0 +1,62 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package runtime + +import ( + "context" + "fmt" + + "k8s.io/klog/v2" +) + +// HandleError logs an asynchronous error. +func HandleError(err error) { + if err == nil { + return + } + klog.Background().Error(err, "Unhandled Error") +} + +// HandleErrorWithContext logs an asynchronous error with contextual logging when available. +func HandleErrorWithContext(ctx context.Context, err error, msg string, keysAndValues ...interface{}) { + if err == nil { + return + } + klog.FromContext(ctx).Error(err, msg, keysAndValues...) +} + +// HandleCrash recovers from panic and logs it. +func HandleCrash() { + HandleCrashWithLogger(klog.Background()) +} + +// HandleCrashWithContext recovers from panic and logs it with the context logger. +func HandleCrashWithContext(ctx context.Context, additionalHandlers ...func(context.Context, interface{})) { + if r := recover(); r != nil { + for _, fn := range additionalHandlers { + fn(ctx, r) + } + klog.FromContext(ctx).Error(fmt.Errorf("%v", r), "Observed a panic") + } +} + +// HandleCrashWithLogger recovers from panic and logs it using the provided logger. +func HandleCrashWithLogger(logger klog.Logger) { + if r := recover(); r != nil { + logger.Error(fmt.Errorf("%v", r), "Observed a panic") + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index ee579282b..9929929b5 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -507,6 +507,9 @@ github.com/googleapis/gax-go/v2/internallog github.com/googleapis/gax-go/v2/internallog/grpclog github.com/googleapis/gax-go/v2/internallog/internal github.com/googleapis/gax-go/v2/iterator +# github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 +## explicit; go 1.20 +github.com/gorilla/websocket # github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0 ## explicit; go 1.25.0 github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule @@ -597,6 +600,10 @@ github.com/mitchellh/copystructure # github.com/mitchellh/reflectwalk v1.0.2 ## explicit github.com/mitchellh/reflectwalk +# github.com/moby/spdystream v0.5.1 +## explicit; go 1.13 +github.com/moby/spdystream +github.com/moby/spdystream/spdy # github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd ## explicit github.com/modern-go/concurrent @@ -902,8 +909,11 @@ golang.org/x/net/http2/hpack golang.org/x/net/idna golang.org/x/net/internal/httpcommon golang.org/x/net/internal/httpsfv +golang.org/x/net/internal/socks golang.org/x/net/internal/timeseries +golang.org/x/net/proxy golang.org/x/net/trace +golang.org/x/net/websocket # golang.org/x/oauth2 v0.36.0 ## explicit; go 1.25.0 golang.org/x/oauth2 @@ -1226,6 +1236,8 @@ k8s.io/apimachinery/pkg/util/cache k8s.io/apimachinery/pkg/util/diff k8s.io/apimachinery/pkg/util/errors k8s.io/apimachinery/pkg/util/framer +k8s.io/apimachinery/pkg/util/httpstream +k8s.io/apimachinery/pkg/util/httpstream/spdy k8s.io/apimachinery/pkg/util/intstr k8s.io/apimachinery/pkg/util/json k8s.io/apimachinery/pkg/util/managedfields @@ -1233,6 +1245,7 @@ k8s.io/apimachinery/pkg/util/managedfields/internal k8s.io/apimachinery/pkg/util/mergepatch k8s.io/apimachinery/pkg/util/naming k8s.io/apimachinery/pkg/util/net +k8s.io/apimachinery/pkg/util/remotecommand k8s.io/apimachinery/pkg/util/runtime k8s.io/apimachinery/pkg/util/sets k8s.io/apimachinery/pkg/util/strategicpatch @@ -1570,11 +1583,15 @@ k8s.io/client-go/tools/leaderelection/resourcelock k8s.io/client-go/tools/metrics k8s.io/client-go/tools/pager k8s.io/client-go/tools/reference +k8s.io/client-go/tools/remotecommand k8s.io/client-go/transport +k8s.io/client-go/transport/spdy +k8s.io/client-go/transport/websocket k8s.io/client-go/util/apply k8s.io/client-go/util/cert k8s.io/client-go/util/connrotation k8s.io/client-go/util/consistencydetector +k8s.io/client-go/util/exec k8s.io/client-go/util/flowcontrol k8s.io/client-go/util/homedir k8s.io/client-go/util/keyutil @@ -1628,6 +1645,12 @@ k8s.io/kube-openapi/pkg/validation/strfmt/bson ## explicit; go 1.26.0 k8s.io/kubernetes/pkg/apis/core k8s.io/kubernetes/pkg/apis/core/helper +# k8s.io/streaming v0.36.2 +## explicit; go 1.26.0 +k8s.io/streaming/pkg/httpstream +k8s.io/streaming/pkg/httpstream/spdy +k8s.io/streaming/pkg/httpstream/wsstream +k8s.io/streaming/pkg/runtime # k8s.io/utils v0.0.0-20260507154919-ff6756f316d2 ## explicit; go 1.23 k8s.io/utils/buffer From 380c582b098ff09f39ec520798e4ffb7296aa0b2 Mon Sep 17 00:00:00 2001 From: Kayne Tu Date: Wed, 17 Jun 2026 11:32:32 -0700 Subject: [PATCH 2/4] chore: Address coderabbit nit comments --- demos/cuj1-slinky-slurm.md | 10 +++++----- .../conformance/slinky_slurm_health_check.go | 12 ++++++++++-- .../slinky_slurm_health_check_test.go | 17 ++++++++++++++++- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/demos/cuj1-slinky-slurm.md b/demos/cuj1-slinky-slurm.md index 919d12f5a..d1b20683a 100644 --- a/demos/cuj1-slinky-slurm.md +++ b/demos/cuj1-slinky-slurm.md @@ -8,7 +8,7 @@ Slurm leaves are built from criteria flags (`--service`, `--platform slurm`, … - `kubectl` is configured for the target cluster. - GPU leaves assume H100 nodes with drivers (or Kind for the CPU-only path). -- Node pools use a `**nodeGroup`** label (adjust if your cluster uses different keys). +- Node pools use a `nodeGroup` label (adjust if your cluster uses different keys). - Inspect taints before bundling: `kubectl get nodes -o custom-columns=NAME:.metadata.name,GROUP:.metadata.labels.nodeGroup,TAINTS:.spec.taints` ## Workflow @@ -48,7 +48,7 @@ AICR injects placement from bundle flags using each component's registry paths: | Flag | Typical targets | | --------------------------------------------------------------- | --------------------------------------------------- | | `--system-node-selector` / `--system-node-toleration` | cert-manager, **slurm-operator**, prometheus, … | -| `--accelerated-node-selector` / `--accelerated-node-toleration` | `**nodesets.slinky`** (slurmd workers) | +| `--accelerated-node-selector` / `--accelerated-node-toleration` | `nodesets.slinky` (slurmd workers) | | `--set-json slinkyslurm:…` | Per-leaf overrides on the cluster chart (see below) | @@ -154,7 +154,7 @@ Use **deployment** and **conformance**. Performance validation is **not supporte | Phase | What it checks | | ------------- | ---------------------------------------------------------------------------------------------------------------------- | | `deployment` | Component Chainsaw health (CRs, Deployments, DaemonSets ready), including `slinky-slurm` readiness (long retry budget) | -| `conformance` | `slinky-slurm-health`: `scontrol ping`, idle/mix node gate, bounded `srun --immediate=5 --time=0:01 hostname` | +| `conformance` | `slinky-slurm-health`: `scontrol ping`, idle/mix node gate, bounded `srun --immediate=5 --time=0:03 hostname` | | `performance` | **Not supported yet** on slurm leaves | | `all` | Runs deployment → conformance → performance in sequence; the performance step has nothing to run on slurm leaves | @@ -195,7 +195,7 @@ aicr validate \ ### Scheduling flags on validate -When validate captures cluster state inline (no `-s`), pass `**--node-selector**` and `**--toleration**` so the snapshot agent Job can schedule on tainted nodes. Match your **system** pool (not the GPU pool) unless you intend to run the agent on GPU nodes. +When validate captures cluster state inline (no `-s`), pass `--node-selector` and `--toleration` so the snapshot agent Job can schedule on tainted nodes. Match your **system** pool (not the GPU pool) unless you intend to run the agent on GPU nodes. **EKS example** (agent on system nodes): @@ -233,7 +233,7 @@ SSH is disabled by default on the login chart; use `kubectl exec`. ```shell kubectl exec -n slurm deploy/slinky-slurm-login-slinky -- sinfo kubectl exec -n slurm deploy/slinky-slurm-login-slinky -- \ - srun --immediate=5 --time=0:01 hostname + srun --immediate=5 --time=0:03 hostname ``` Multi-node (when `replicas >= 2`): diff --git a/validators/conformance/slinky_slurm_health_check.go b/validators/conformance/slinky_slurm_health_check.go index cca78ac83..528ab7433 100644 --- a/validators/conformance/slinky_slurm_health_check.go +++ b/validators/conformance/slinky_slurm_health_check.go @@ -284,8 +284,16 @@ func listSlinkySetsForController( if item.GetAPIVersion() != "slinky.slurm.net/v1beta1" || item.GetKind() != kind { continue } - controllerName, _, _ := unstructured.NestedString(item.Object, "spec", "controllerRef", "name") - controllerNamespace, _, _ := unstructured.NestedString(item.Object, "spec", "controllerRef", "namespace") + controllerName, _, controllerNameErr := unstructured.NestedString(item.Object, "spec", "controllerRef", "name") + if controllerNameErr != nil { + return nil, errors.Wrap(errors.ErrCodeInternal, + fmt.Sprintf("failed to read controllerRef.name from %s/%s", kind, item.GetName()), controllerNameErr) + } + controllerNamespace, _, controllerNamespaceErr := unstructured.NestedString(item.Object, "spec", "controllerRef", "namespace") + if controllerNamespaceErr != nil { + return nil, errors.Wrap(errors.ErrCodeInternal, + fmt.Sprintf("failed to read controllerRef.namespace from %s/%s", kind, item.GetName()), controllerNamespaceErr) + } if controllerName != slinkySlurmComponent { continue } diff --git a/validators/conformance/slinky_slurm_health_check_test.go b/validators/conformance/slinky_slurm_health_check_test.go index c8f8ade3b..c7bb6541b 100644 --- a/validators/conformance/slinky_slurm_health_check_test.go +++ b/validators/conformance/slinky_slurm_health_check_test.go @@ -177,7 +177,7 @@ func TestCheckSlinkySlurmHealthRunsAllHealthCommands(t *testing.T) { wantCommands := []string{ "scontrol ping", "/bin/sh -c " + slinkySlurmSinfoIdleMixShell, - "srun --immediate=5 --time=0:01 hostname", + "srun --immediate=5 --time=0:03 hostname", } if strings.Join(gotCommands, ",") != strings.Join(wantCommands, ",") { t.Fatalf("commands = %v, want %v", gotCommands, wantCommands) @@ -202,6 +202,21 @@ func TestCheckSlinkySlurmHealthDiscoversPodsFromSlinkyCRSelectors(t *testing.T) } } +func TestCheckSlinkySlurmHealthFailsOnMalformedControllerRefName(t *testing.T) { + ctx := slurmReadyTestContext(t, false) + loginSet := defaultLoginSet() + if err := unstructured.SetNestedField(loginSet.Object, map[string]any{"bad": "shape"}, + "spec", "controllerRef", "name"); err != nil { + t.Fatalf("set malformed controllerRef.name: %v", err) + } + ctx.DynamicClient = newSlinkyDynamicClient(t, loginSet, defaultNodeSet()) + + err := CheckSlinkySlurmHealth(ctx) + if err == nil || !strings.Contains(err.Error(), "failed to read controllerRef.name") { + t.Fatalf("error = %v, want malformed controllerRef.name read failure", err) + } +} + func TestCheckSlinkySlurmHealthCollectsAllCommandFailures(t *testing.T) { restore := replaceSlinkyExecForTest(func(_ context.Context, _ *validators.Context, _, _ string, command []string) (podExecResult, error) { joined := strings.Join(command, " ") From cdd1684ff4e0ac4dadac4dce9cb44e1eedd2e7b7 Mon Sep 17 00:00:00 2001 From: Kayne Tu Date: Mon, 22 Jun 2026 13:58:22 -0700 Subject: [PATCH 3/4] chore: address PR comments --- pkg/recipe/metadata_store_test.go | 62 +++++++++ .../h100-eks-ubuntu-training-slurm.yaml | 3 +- .../overlays/h100-gke-cos-training-slurm.yaml | 5 +- validators/conformance/pod_exec.go | 43 ++++++- validators/conformance/pod_exec_test.go | 115 ++++++++++++++++- .../conformance/robust_controller_check.go | 5 +- .../robust_controller_check_test.go | 27 ++++ .../conformance/slinky_slurm_health_check.go | 87 +++++++------ .../slinky_slurm_health_check_test.go | 121 +++++++++++++++++- 9 files changed, 411 insertions(+), 57 deletions(-) diff --git a/pkg/recipe/metadata_store_test.go b/pkg/recipe/metadata_store_test.go index ca197bee4..aea6c6483 100644 --- a/pkg/recipe/metadata_store_test.go +++ b/pkg/recipe/metadata_store_test.go @@ -655,6 +655,68 @@ func TestSlurmLeavesClearInheritedPerformancePhase(t *testing.T) { } } +func TestSlurmLeavesAppendConformanceHealthCheck(t *testing.T) { + ctx := context.Background() + store, err := loadMetadataStore(ctx) + if err != nil { + t.Fatalf("failed to load metadata store: %v", err) + } + + conformanceChecks := []string{ + "platform-health", + "gpu-operator-health", + "dra-support", + "accelerator-metrics", + "ai-service-metrics", + "gang-scheduling", + "pod-autoscaling", + "cluster-autoscaling", + "robust-controller", + "secure-accelerator-access", + "slinky-slurm-health", + } + kindConformanceChecks := []string{ + "platform-health", + "gpu-operator-health", + "dra-support", + "accelerator-metrics", + "ai-service-metrics", + "gang-scheduling", + "secure-accelerator-access", + "pod-autoscaling", + "cluster-autoscaling", + "slinky-slurm-health", + } + + tests := []struct { + name string + want []string + }{ + {name: "h100-eks-ubuntu-training-slurm", want: conformanceChecks}, + {name: "h100-gke-cos-training-slurm", want: conformanceChecks}, + {name: "h100-kind-training-slurm", want: kindConformanceChecks}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + leaf, ok := store.GetRecipeByName(tt.name) + if !ok { + t.Fatalf("overlay %q not found in store", tt.name) + } + result, err := store.BuildRecipeResult(ctx, leaf.Spec.Criteria) + if err != nil { + t.Fatalf("BuildRecipeResult failed: %v", err) + } + if result.Validation == nil || result.Validation.Conformance == nil { + t.Fatalf("conformance phase missing from resolved recipe") + } + if got := result.Validation.Conformance.Checks; !slices.Equal(got, tt.want) { + t.Errorf("conformance.checks = %v, want %v", got, tt.want) + } + }) + } +} + // TestEvaluatorFailingLeafExcludesCandidate verifies that when a leaf overlay's // constraints fail evaluation, no ancestor overlay is used as a fallback // candidate. With maximal leaf selection, ancestors are not independent diff --git a/recipes/overlays/h100-eks-ubuntu-training-slurm.yaml b/recipes/overlays/h100-eks-ubuntu-training-slurm.yaml index 885cbe4a3..e35bc745a 100644 --- a/recipes/overlays/h100-eks-ubuntu-training-slurm.yaml +++ b/recipes/overlays/h100-eks-ubuntu-training-slurm.yaml @@ -93,7 +93,8 @@ spec: # wrong path. The equivalent signal here is a slurm-launched # `srun nccl-tests/all_reduce_perf` that goes through slurmd + the # EFA libfabric stack already present on the parent EKS leaf. - # Deployment and conformance checks are inherited unchanged. + # Deployment checks are inherited unchanged. Conformance keeps the + # inherited parent checks and appends the Slurm-specific health check below. validation: conformance: checks: diff --git a/recipes/overlays/h100-gke-cos-training-slurm.yaml b/recipes/overlays/h100-gke-cos-training-slurm.yaml index bd59326c6..4093b7a3e 100644 --- a/recipes/overlays/h100-gke-cos-training-slurm.yaml +++ b/recipes/overlays/h100-gke-cos-training-slurm.yaml @@ -95,8 +95,9 @@ spec: # wrong path. The equivalent signal here is a slurm-launched # `srun nccl-tests/all_reduce_perf` that goes through slurmd + the # GPUDirect TCPXO plugin already deployed by the parent leaf via - # gke-nccl-tcpxo. Deployment and conformance checks are inherited - # unchanged. + # gke-nccl-tcpxo. Deployment checks are inherited unchanged. Conformance + # keeps the inherited parent checks and appends the Slurm-specific health + # check below. validation: conformance: checks: diff --git a/validators/conformance/pod_exec.go b/validators/conformance/pod_exec.go index 6347289a1..bc7127123 100644 --- a/validators/conformance/pod_exec.go +++ b/validators/conformance/pod_exec.go @@ -32,13 +32,42 @@ import ( k8sexec "k8s.io/client-go/util/exec" ) +type podExecOptions struct { + DefaultContainerAnnotation string + PreferredContainerName string +} + type podExecResult struct { Stdout string Stderr string ExitCode int } -type podExecFunc func(context.Context, *validators.Context, string, string, []string) (podExecResult, error) +// selectExecContainer chooses which container to exec into using caller-provided +// preferences. It first honors a configured default-container annotation when it +// names an existing container, then a configured preferred container name, then +// falls back to the pod's first container. Callers must ensure the pod has at +// least one container. +func selectExecContainer(pod *corev1.Pod, opts podExecOptions) string { + containers := pod.Spec.Containers + if annotated := pod.Annotations[opts.DefaultContainerAnnotation]; opts.DefaultContainerAnnotation != "" && annotated != "" { + for i := range containers { + if containers[i].Name == annotated { + return annotated + } + } + } + if opts.PreferredContainerName != "" { + for i := range containers { + if containers[i].Name == opts.PreferredContainerName { + return opts.PreferredContainerName + } + } + } + return containers[0].Name +} + +type podExecFunc func(context.Context, *validators.Context, string, string, []string, podExecOptions) (podExecResult, error) type podExecExecutorFactory func(*rest.Config, string, string) (remotecommand.Executor, error) @@ -50,7 +79,15 @@ var newPodExecExecutor podExecExecutorFactory = func(config *rest.Config, method return remotecommand.NewSPDYExecutor(config, method, parsedURL) } -func execPodCommand(streamCtx context.Context, ctx *validators.Context, namespace, podName string, command []string) (podExecResult, error) { +func execPodCommand( + streamCtx context.Context, + ctx *validators.Context, + namespace string, + podName string, + command []string, + opts podExecOptions, +) (podExecResult, error) { + pod, err := ctx.Clientset.CoreV1().Pods(namespace).Get(streamCtx, podName, metav1.GetOptions{}) if err != nil { return podExecResult{}, errors.Wrap(errors.ErrCodeInternal, @@ -67,7 +104,7 @@ func execPodCommand(streamCtx context.Context, ctx *validators.Context, namespac Namespace(namespace). SubResource("exec"). VersionedParams(&corev1.PodExecOptions{ - Container: pod.Spec.Containers[0].Name, + Container: selectExecContainer(pod, opts), Command: command, Stdout: true, Stderr: true, diff --git a/validators/conformance/pod_exec_test.go b/validators/conformance/pod_exec_test.go index b35cdc573..4ff396a0a 100644 --- a/validators/conformance/pod_exec_test.go +++ b/validators/conformance/pod_exec_test.go @@ -63,7 +63,7 @@ func TestExecPodCommandBuildsExecRequestAndStreamsOutput(t *testing.T) { }) defer restore() - result, err := execPodCommand(context.Background(), ctx, "slurm", "login-0", []string{"srun", "hostname"}) + result, err := execPodCommand(context.Background(), ctx, "slurm", "login-0", []string{"srun", "hostname"}, podExecOptions{}) if err != nil { t.Fatalf("execPodCommand() error = %v", err) } @@ -93,6 +93,113 @@ func TestExecPodCommandBuildsExecRequestAndStreamsOutput(t *testing.T) { } } +func TestSelectExecContainer(t *testing.T) { + const defaultContainerAnnotationForTest = "example.com/default-container" + + tests := []struct { + name string + annotations map[string]string + options podExecOptions + containers []string + want string + }{ + { + name: "configured default-container annotation wins", + annotations: map[string]string{defaultContainerAnnotationForTest: "login"}, + options: podExecOptions{DefaultContainerAnnotation: defaultContainerAnnotationForTest}, + containers: []string{"sidecar", "login"}, + want: "login", + }, + { + name: "annotation ignored when not a real container", + annotations: map[string]string{defaultContainerAnnotationForTest: "ghost"}, + options: podExecOptions{ + DefaultContainerAnnotation: defaultContainerAnnotationForTest, + PreferredContainerName: "login", + }, + containers: []string{"sidecar", "login"}, + want: "login", + }, + { + name: "preferred container matched when no annotation", + options: podExecOptions{PreferredContainerName: "login"}, + containers: []string{"munge", "login"}, + want: "login", + }, + { + name: "fallback to first container when nothing matches", + options: podExecOptions{PreferredContainerName: "login"}, + containers: []string{"sidecar-a", "sidecar-b"}, + want: "sidecar-a", + }, + { + name: "empty options use first container", + want: "sssd", + containers: []string{"sssd", "login"}, + }, + { + name: "annotation is ignored unless configured", + annotations: map[string]string{defaultContainerAnnotationForTest: "login"}, + containers: []string{"sidecar", "login"}, + want: "sidecar", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + containers := make([]corev1.Container, 0, len(tt.containers)) + for _, name := range tt.containers { + containers = append(containers, corev1.Container{Name: name}) + } + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "login-0", Namespace: "slurm"}, + Spec: corev1.PodSpec{Containers: containers}, + } + if len(tt.annotations) > 0 { + pod.Annotations = tt.annotations + } + if got := selectExecContainer(pod, tt.options); got != tt.want { + t.Fatalf("selectExecContainer() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestExecPodCommandHonorsConfiguredDefaultContainerAnnotation(t *testing.T) { + const defaultContainerAnnotationForTest = "example.com/default-container" + + ctx := podExecHTTPContext(t, corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "login-0", + Namespace: "slurm", + Annotations: map[string]string{defaultContainerAnnotationForTest: "login"}, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + {Name: "sidecar"}, + {Name: "login"}, + }, + }, + }) + + var gotURL string + restore := replacePodExecExecutorForTest(func(_ *rest.Config, _ string, url string) (remotecommand.Executor, error) { + gotURL = url + return fakePodExecutor{ + stream: func(context.Context, remotecommand.StreamOptions) error { return nil }, + }, nil + }) + defer restore() + + opts := podExecOptions{DefaultContainerAnnotation: defaultContainerAnnotationForTest} + if _, err := execPodCommand(context.Background(), ctx, "slurm", "login-0", []string{"hostname"}, opts); err != nil { + t.Fatalf("execPodCommand() error = %v", err) + } + if !strings.Contains(gotURL, "container=login") { + t.Fatalf("exec URL = %s, want container=login (not first container)", gotURL) + } +} + func TestExecPodCommandReturnsPreStreamErrors(t *testing.T) { tests := []struct { name string @@ -123,7 +230,7 @@ func TestExecPodCommandReturnsPreStreamErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := execPodCommand(context.Background(), tt.ctx, "slurm", "missing", []string{"hostname"}) + _, err := execPodCommand(context.Background(), tt.ctx, "slurm", "missing", []string{"hostname"}, podExecOptions{}) if err == nil || !strings.Contains(err.Error(), tt.wantErr) { t.Fatalf("error = %v, want containing %q", err, tt.wantErr) } @@ -143,7 +250,7 @@ func TestExecPodCommandReturnsExecutorFactoryError(t *testing.T) { }) defer restore() - _, err := execPodCommand(context.Background(), ctx, "slurm", "login-0", []string{"hostname"}) + _, err := execPodCommand(context.Background(), ctx, "slurm", "login-0", []string{"hostname"}, podExecOptions{}) if err == nil || !strings.Contains(err.Error(), "failed to create pod exec executor") || !strings.Contains(err.Error(), "factory failed") { t.Fatalf("error = %v, want wrapped factory failure", err) } @@ -165,7 +272,7 @@ func TestExecPodCommandReturnsStreamError(t *testing.T) { }) defer restore() - _, err := execPodCommand(context.Background(), ctx, "slurm", "login-0", []string{"hostname"}) + _, err := execPodCommand(context.Background(), ctx, "slurm", "login-0", []string{"hostname"}, podExecOptions{}) if err == nil || !strings.Contains(err.Error(), "stream failed") { t.Fatalf("error = %v, want stream failure", err) } diff --git a/validators/conformance/robust_controller_check.go b/validators/conformance/robust_controller_check.go index e156cf8f1..cc860d5ca 100644 --- a/validators/conformance/robust_controller_check.go +++ b/validators/conformance/robust_controller_check.go @@ -52,13 +52,14 @@ type webhookRejectionReport struct { Message string } -// recipeHasComponent checks if a named component exists in the validation's componentRefs. +// recipeHasComponent checks if a named component exists and is enabled in the +// validation's componentRefs. func recipeHasComponent(ctx *validators.Context, name string) bool { if ctx.ValidationInput == nil { return false } for _, ref := range ctx.ValidationInput.ComponentRefs { - if ref.Name == name { + if ref.Name == name && ref.IsEnabled() { return true } } diff --git a/validators/conformance/robust_controller_check_test.go b/validators/conformance/robust_controller_check_test.go index 6617c8d49..63e12b7d8 100644 --- a/validators/conformance/robust_controller_check_test.go +++ b/validators/conformance/robust_controller_check_test.go @@ -56,6 +56,19 @@ func TestRecipeHasComponent(t *testing.T) { component: "kubeflow-trainer", want: true, }, + { + name: "component disabled", + recipe: &recipe.RecipeResult{ + ComponentRefs: []recipe.ComponentRef{ + { + Name: "kubeflow-trainer", + Overrides: map[string]any{"enabled": false}, + }, + }, + }, + component: "kubeflow-trainer", + want: false, + }, { name: "component not present", recipe: &recipe.RecipeResult{ @@ -133,6 +146,20 @@ func TestCheckRobustControllerRouting(t *testing.T) { // Will fail because fake clientset has no deployments, but proves routing works expectContains: "Kubeflow Trainer controller not found", }, + { + name: "recipe with disabled kubeflow-trainer skips", + recipe: &recipe.RecipeResult{ + ComponentRefs: []recipe.ComponentRef{ + { + Name: "kubeflow-trainer", + Overrides: map[string]any{"enabled": false}, + }, + {Name: "gpu-operator"}, + }, + }, + expectSkip: true, + expectContains: "no supported AI operator", + }, { name: "recipe with dynamo-platform routes to dynamo check", recipe: &recipe.RecipeResult{ diff --git a/validators/conformance/slinky_slurm_health_check.go b/validators/conformance/slinky_slurm_health_check.go index 528ab7433..fb7d029f8 100644 --- a/validators/conformance/slinky_slurm_health_check.go +++ b/validators/conformance/slinky_slurm_health_check.go @@ -29,9 +29,11 @@ import ( ) const ( - slinkySlurmComponent = "slinky-slurm" - slinkySlurmNamespace = "slurm" - kwokNodeAnnotation = "kwok.x-k8s.io/node" + slinkySlurmComponent = "slinky-slurm" + slinkySlurmNamespace = "slurm" + kwokNodeAnnotation = "kwok.x-k8s.io/node" + defaultContainerAnnotation = "kubectl.kubernetes.io/default-container" + slinkyLoginPodContainerName = "login" ) var ( @@ -77,6 +79,11 @@ var slinkySlurmHealthCommands = []slinkySlurmHealthCommand{ var slinkyExecCommand podExecFunc = execPodCommand +var slinkyLoginPodExecOptions = podExecOptions{ + DefaultContainerAnnotation: defaultContainerAnnotation, + PreferredContainerName: slinkyLoginPodContainerName, +} + // CheckSlinkySlurmHealth validates that a Slinky-managed Slurm cluster is // reachable from the login pod, has idle or mixed worker nodes, and can // schedule a minimal job without queueing indefinitely. @@ -90,24 +97,25 @@ func CheckSlinkySlurmHealth(ctx *validators.Context) error { if ctx.ValidationInput == nil { return errors.New(errors.ErrCodeInvalidRequest, "validation is not available") } - if !recipeHasEnabledComponent(ctx, slinkySlurmComponent) { + if !recipeHasComponent(ctx, slinkySlurmComponent) { return validators.Skip("slinky-slurm component not present in recipe") } + namespace := resolveSlinkySlurmNamespace(ctx) if err := discoverSlinkySetAPIs(ctx); err != nil { return err } - if err := skipIfAllNodeSetPodsAreKWOK(ctx); err != nil { + if err := skipIfAllNodeSetPodsAreKWOK(ctx, namespace); err != nil { return err } - loginPod, err := findReadySlinkyLoginPod(ctx) + loginPod, err := findReadySlinkyLoginPod(ctx, namespace) if err != nil { return err } - recordSlinkyInventories(ctx, loginPod) + recordSlinkyInventories(ctx, namespace, loginPod) - failures := runSlinkySlurmHealthCommands(ctx, loginPod.Name) + failures := runSlinkySlurmHealthCommands(ctx, namespace, loginPod.Name) if len(failures) > 0 { return errors.New(errors.ErrCodeInternal, "Slinky Slurm health commands failed:\n"+strings.Join(failures, "\n")) @@ -116,23 +124,24 @@ func CheckSlinkySlurmHealth(ctx *validators.Context) error { return nil } -func recipeHasEnabledComponent(ctx *validators.Context, name string) bool { +func resolveSlinkySlurmNamespace(ctx *validators.Context) string { if ctx.ValidationInput == nil { - return false + return slinkySlurmNamespace } for _, ref := range ctx.ValidationInput.ComponentRefs { - if ref.Name == name && ref.IsEnabled() { - return true + if ref.Name == slinkySlurmComponent && ref.IsEnabled() && strings.TrimSpace(ref.Namespace) != "" { + return ref.Namespace } } - return false + return slinkySlurmNamespace } -func runSlinkySlurmHealthCommands(ctx *validators.Context, loginPodName string) []string { +func runSlinkySlurmHealthCommands(ctx *validators.Context, namespace, loginPodName string) []string { var failures []string for _, check := range slinkySlurmHealthCommands { - result, execErr := slinkyExecCommand(ctx.Ctx, ctx, slinkySlurmNamespace, loginPodName, check.command) - recordSlinkyExecResult(ctx, loginPodName, check, result, execErr) + result, execErr := slinkyExecCommand( + ctx.Ctx, ctx, namespace, loginPodName, check.command, slinkyLoginPodExecOptions) + recordSlinkyExecResult(ctx, namespace, loginPodName, check, result, execErr) if execErr != nil { failures = append(failures, fmt.Sprintf("%s: exec failed: %v", check.label, execErr)) continue @@ -171,8 +180,8 @@ func discoverSlinkySetAPIs(ctx *validators.Context) error { return nil } -func skipIfAllNodeSetPodsAreKWOK(ctx *validators.Context) error { - pods, err := listSlinkyNodeSetPods(ctx) +func skipIfAllNodeSetPodsAreKWOK(ctx *validators.Context, namespace string) error { + pods, err := listSlinkyNodeSetPods(ctx, namespace) if err != nil { return err } @@ -201,17 +210,18 @@ func skipIfAllNodeSetPodsAreKWOK(ctx *validators.Context) error { return nil } -func listSlinkyNodeSetPods(ctx *validators.Context) ([]corev1.Pod, error) { - return listPodsForSlinkySetSelectors(ctx, slinkyNodeSetGVR, "NodeSet") +func listSlinkyNodeSetPods(ctx *validators.Context, namespace string) ([]corev1.Pod, error) { + return listPodsForSlinkySetSelectors(ctx, namespace, slinkyNodeSetGVR, "NodeSet") } func listPodsForSlinkySetSelectors( ctx *validators.Context, + namespace string, gvr schema.GroupVersionResource, kind string, ) ([]corev1.Pod, error) { - sets, err := listSlinkySetsForController(ctx, gvr, kind) + sets, err := listSlinkySetsForController(ctx, namespace, gvr, kind) if err != nil { return nil, err } @@ -221,9 +231,9 @@ func listPodsForSlinkySetSelectors( if _, parseErr := labels.Parse(set.selector); parseErr != nil { return nil, errors.Wrap(errors.ErrCodeInternal, fmt.Sprintf("invalid %s selector for %s/%s: %q", - kind, slinkySlurmNamespace, set.name, set.selector), parseErr) + kind, namespace, set.name, set.selector), parseErr) } - podList, listErr := ctx.Clientset.CoreV1().Pods(slinkySlurmNamespace).List(ctx.Ctx, metav1.ListOptions{ + podList, listErr := ctx.Clientset.CoreV1().Pods(namespace).List(ctx.Ctx, metav1.ListOptions{ LabelSelector: set.selector, }) if listErr != nil { @@ -235,8 +245,8 @@ func listPodsForSlinkySetSelectors( return pods, nil } -func findReadySlinkyLoginPod(ctx *validators.Context) (*corev1.Pod, error) { - pods, err := listPodsForSlinkySetSelectors(ctx, slinkyLoginSetGVR, "LoginSet") +func findReadySlinkyLoginPod(ctx *validators.Context, namespace string) (*corev1.Pod, error) { + pods, err := listPodsForSlinkySetSelectors(ctx, namespace, slinkyLoginSetGVR, "LoginSet") if err != nil { return nil, err } @@ -251,7 +261,7 @@ func findReadySlinkyLoginPod(ctx *validators.Context) (*corev1.Pod, error) { } return nil, errors.New(errors.ErrCodeNotFound, fmt.Sprintf("no ready login pod found for Slinky LoginSet selectors in %s:\n%s", - slinkySlurmNamespace, strings.TrimSpace(summary.String()))) + namespace, strings.TrimSpace(summary.String()))) } type slinkySetSelection struct { @@ -262,6 +272,7 @@ type slinkySetSelection struct { func listSlinkySetsForController( ctx *validators.Context, + namespace string, gvr schema.GroupVersionResource, kind string, ) ([]slinkySetSelection, error) { @@ -270,7 +281,7 @@ func listSlinkySetsForController( if err != nil { return nil, err } - list, err := dynClient.Resource(gvr).Namespace(slinkySlurmNamespace).List(ctx.Ctx, metav1.ListOptions{}) + list, err := dynClient.Resource(gvr).Namespace(namespace).List(ctx.Ctx, metav1.ListOptions{}) if err != nil { if apierrors.IsNotFound(err) { return nil, validators.Skip(fmt.Sprintf("Slinky Slurm %s API not available", kind)) @@ -297,7 +308,7 @@ func listSlinkySetsForController( if controllerName != slinkySlurmComponent { continue } - if controllerNamespace != "" && controllerNamespace != slinkySlurmNamespace { + if controllerNamespace != "" && controllerNamespace != namespace { continue } selector, found, selectorErr := unstructured.NestedString(item.Object, "status", "selector") @@ -332,10 +343,10 @@ func podIsReady(pod *corev1.Pod) bool { return false } -func recordSlinkyInventories(ctx *validators.Context, loginPod *corev1.Pod) { - slurmPods, slurmPodsErr := ctx.Clientset.CoreV1().Pods(slinkySlurmNamespace).List(ctx.Ctx, metav1.ListOptions{}) +func recordSlinkyInventories(ctx *validators.Context, namespace string, loginPod *corev1.Pod) { + slurmPods, slurmPodsErr := ctx.Clientset.CoreV1().Pods(namespace).List(ctx.Ctx, metav1.ListOptions{}) if slurmPodsErr != nil { - recordRawTextArtifact(ctx, "Slinky Slurm pods", "kubectl get pods -n slurm -o wide", + recordRawTextArtifact(ctx, "Slinky Slurm pods", fmt.Sprintf("kubectl get pods -n %s -o wide", namespace), fmt.Sprintf("failed to list pods: %v", slurmPodsErr)) } else { var podSummary strings.Builder @@ -343,12 +354,12 @@ func recordSlinkyInventories(ctx *validators.Context, loginPod *corev1.Pod) { fmt.Fprintf(&podSummary, "%-48s ready=%s phase=%s node=%s\n", pod.Name, podReadyCount(pod), pod.Status.Phase, valueOrUnknown(pod.Spec.NodeName)) } - recordRawTextArtifact(ctx, "Slinky Slurm pods", "kubectl get pods -n slurm -o wide", podSummary.String()) + recordRawTextArtifact(ctx, "Slinky Slurm pods", fmt.Sprintf("kubectl get pods -n %s -o wide", namespace), podSummary.String()) } - nodeSetPods, nodeSetErr := listSlinkyNodeSetPods(ctx) + nodeSetPods, nodeSetErr := listSlinkyNodeSetPods(ctx, namespace) if nodeSetErr != nil { - recordRawTextArtifact(ctx, "Slinky Slurm NodeSet pods", "kubectl get pods -n slurm", + recordRawTextArtifact(ctx, "Slinky Slurm NodeSet pods", fmt.Sprintf("kubectl get pods -n %s", namespace), fmt.Sprintf("failed to list NodeSet pods: %v", nodeSetErr)) } else { var nodeSetSummary strings.Builder @@ -357,7 +368,7 @@ func recordSlinkyInventories(ctx *validators.Context, loginPod *corev1.Pod) { pod.Name, podReadyCount(pod), pod.Status.Phase, valueOrUnknown(pod.Spec.NodeName)) } recordRawTextArtifact(ctx, "Slinky Slurm NodeSet pods", - "kubectl -n slurm get nodesets -o json | jq -r '.items[] | select(.apiVersion == \"slinky.slurm.net/v1beta1\") | .status.selector'", + fmt.Sprintf("kubectl -n %s get nodesets -o json | jq -r '.items[] | select(.apiVersion == \"slinky.slurm.net/v1beta1\") | .status.selector'", namespace), nodeSetSummary.String()) } @@ -366,9 +377,9 @@ func recordSlinkyInventories(ctx *validators.Context, loginPod *corev1.Pod) { loginPod.Namespace, loginPod.Name, podIsReady(loginPod), valueOrUnknown(loginPod.Spec.NodeName))) } -func recordSlinkyExecResult(ctx *validators.Context, podName string, check slinkySlurmHealthCommand, result podExecResult, execErr error) { +func recordSlinkyExecResult(ctx *validators.Context, namespace, podName string, check slinkySlurmHealthCommand, result podExecResult, execErr error) { var body strings.Builder - fmt.Fprintf(&body, "Pod: %s/%s\n", slinkySlurmNamespace, podName) + fmt.Fprintf(&body, "Pod: %s/%s\n", namespace, podName) fmt.Fprintf(&body, "Command: %s\n", strings.Join(check.command, " ")) fmt.Fprintf(&body, "ExitCode: %d\n", result.ExitCode) if execErr != nil { @@ -378,6 +389,6 @@ func recordSlinkyExecResult(ctx *validators.Context, podName string, check slink fmt.Fprintf(&body, "\nstderr:\n%s\n", result.Stderr) recordRawTextArtifact(ctx, fmt.Sprintf("Slinky Slurm %s result", check.label), - fmt.Sprintf("kubectl exec -n slurm %s -- %s", podName, strings.Join(check.command, " ")), + fmt.Sprintf("kubectl exec -n %s %s -- %s", namespace, podName, strings.Join(check.command, " ")), body.String()) } diff --git a/validators/conformance/slinky_slurm_health_check_test.go b/validators/conformance/slinky_slurm_health_check_test.go index c7bb6541b..bd3f6887b 100644 --- a/validators/conformance/slinky_slurm_health_check_test.go +++ b/validators/conformance/slinky_slurm_health_check_test.go @@ -145,7 +145,15 @@ func TestCheckSlinkySlurmHealthExecOutcomes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - restore := replaceSlinkyExecForTest(func(context.Context, *validators.Context, string, string, []string) (podExecResult, error) { + restore := replaceSlinkyExecForTest(func( + context.Context, + *validators.Context, + string, + string, + []string, + podExecOptions, + ) (podExecResult, error) { + return tt.result, tt.err }) defer restore() @@ -163,8 +171,17 @@ func TestCheckSlinkySlurmHealthExecOutcomes(t *testing.T) { func TestCheckSlinkySlurmHealthRunsAllHealthCommands(t *testing.T) { var gotCommands []string - restore := replaceSlinkyExecForTest(func(_ context.Context, _ *validators.Context, _, _ string, command []string) (podExecResult, error) { + var gotOptions []podExecOptions + restore := replaceSlinkyExecForTest(func( + _ context.Context, + _ *validators.Context, + _, _ string, + command []string, + opts podExecOptions, + ) (podExecResult, error) { + gotCommands = append(gotCommands, strings.Join(command, " ")) + gotOptions = append(gotOptions, opts) return podExecResult{Stdout: strings.Join(command, " ") + "\n"}, nil }) defer restore() @@ -182,11 +199,23 @@ func TestCheckSlinkySlurmHealthRunsAllHealthCommands(t *testing.T) { if strings.Join(gotCommands, ",") != strings.Join(wantCommands, ",") { t.Fatalf("commands = %v, want %v", gotCommands, wantCommands) } + for _, got := range gotOptions { + if got.DefaultContainerAnnotation != defaultContainerAnnotation || got.PreferredContainerName != slinkyLoginPodContainerName { + t.Fatalf("pod exec options = %+v, want Slinky login pod options", got) + } + } } func TestCheckSlinkySlurmHealthDiscoversPodsFromSlinkyCRSelectors(t *testing.T) { var gotPodName string - restore := replaceSlinkyExecForTest(func(_ context.Context, _ *validators.Context, _, podName string, _ []string) (podExecResult, error) { + restore := replaceSlinkyExecForTest(func( + _ context.Context, + _ *validators.Context, + _, podName string, + _ []string, + _ podExecOptions, + ) (podExecResult, error) { + gotPodName = podName return podExecResult{Stdout: "ok\n"}, nil }) @@ -202,6 +231,57 @@ func TestCheckSlinkySlurmHealthDiscoversPodsFromSlinkyCRSelectors(t *testing.T) } } +func TestCheckSlinkySlurmHealthUsesComponentRefNamespace(t *testing.T) { + const customNamespace = "custom-slurm" + + loginPod := readyLoginPod() + loginPod.Namespace = customNamespace + nodeSetPod := readyNodeSetPod() + nodeSetPod.Namespace = customNamespace + node := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "worker-node-0"}} + + clientset := k8sfake.NewSimpleClientset( + &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: customNamespace}}, + node, + loginPod, + nodeSetPod, + ) + addSlinkyDiscovery(t, clientset) + + var gotNamespace string + restore := replaceSlinkyExecForTest(func( + _ context.Context, + _ *validators.Context, + namespace string, + _ string, + _ []string, + _ podExecOptions, + ) (podExecResult, error) { + + gotNamespace = namespace + return podExecResult{Stdout: "ok\n"}, nil + }) + defer restore() + + ctx := &validators.Context{ + Ctx: context.Background(), + Clientset: clientset, + DynamicClient: newSlinkyDynamicClient(t, defaultLoginSetInNamespace(customNamespace), defaultNodeSetInNamespace(customNamespace)), + RESTConfig: &rest.Config{Host: "https://example.test"}, + ValidationInput: &v1.ValidationInput{ + ComponentRefs: []recipe.ComponentRef{{Name: slinkySlurmComponent, Namespace: customNamespace}}, + }, + } + + err := CheckSlinkySlurmHealth(ctx) + if err != nil { + t.Fatalf("error = %v, want nil", err) + } + if gotNamespace != customNamespace { + t.Fatalf("exec namespace = %q, want %q", gotNamespace, customNamespace) + } +} + func TestCheckSlinkySlurmHealthFailsOnMalformedControllerRefName(t *testing.T) { ctx := slurmReadyTestContext(t, false) loginSet := defaultLoginSet() @@ -218,7 +298,14 @@ func TestCheckSlinkySlurmHealthFailsOnMalformedControllerRefName(t *testing.T) { } func TestCheckSlinkySlurmHealthCollectsAllCommandFailures(t *testing.T) { - restore := replaceSlinkyExecForTest(func(_ context.Context, _ *validators.Context, _, _ string, command []string) (podExecResult, error) { + restore := replaceSlinkyExecForTest(func( + _ context.Context, + _ *validators.Context, + _, _ string, + command []string, + _ podExecOptions, + ) (podExecResult, error) { + joined := strings.Join(command, " ") if strings.Contains(joined, "sinfo -h -Ne -t idle,mix") { return podExecResult{Stderr: "down", ExitCode: 1}, nil @@ -270,7 +357,15 @@ func slurmCustomCRSelectorContext(t *testing.T, kwok bool) *validators.Context { } func TestCheckSlinkySlurmHealthSkipsWhenAllNodeSetPodsAreOnKWOKNodes(t *testing.T) { - restore := replaceSlinkyExecForTest(func(context.Context, *validators.Context, string, string, []string) (podExecResult, error) { + restore := replaceSlinkyExecForTest(func( + context.Context, + *validators.Context, + string, + string, + []string, + podExecOptions, + ) (podExecResult, error) { + t.Fatal("exec should not run when all NodeSet pods are on KWOK nodes") return podExecResult{}, nil }) @@ -400,6 +495,14 @@ func defaultNodeSet() *unstructured.Unstructured { return slinkySetObject("NodeSet", "slinky-slurm-worker-slinky", "app.kubernetes.io/name=slurm-nodeset") } +func defaultLoginSetInNamespace(namespace string) *unstructured.Unstructured { + return slinkySetObjectInNamespace("LoginSet", "slinky-slurm-login-slinky", "app.kubernetes.io/name=slurm-login", namespace) +} + +func defaultNodeSetInNamespace(namespace string) *unstructured.Unstructured { + return slinkySetObjectInNamespace("NodeSet", "slinky-slurm-worker-slinky", "app.kubernetes.io/name=slurm-nodeset", namespace) +} + func customLoginSet() *unstructured.Unstructured { return slinkySetObject("LoginSet", "custom-login", "app.kubernetes.io/instance=custom-login,app.kubernetes.io/name=login") } @@ -409,18 +512,22 @@ func customNodeSet() *unstructured.Unstructured { } func slinkySetObject(kind, name, selector string) *unstructured.Unstructured { + return slinkySetObjectInNamespace(kind, name, selector, slinkySlurmNamespace) +} + +func slinkySetObjectInNamespace(kind, name, selector, namespace string) *unstructured.Unstructured { obj := &unstructured.Unstructured{ Object: map[string]any{ "apiVersion": "slinky.slurm.net/v1beta1", "kind": kind, "metadata": map[string]any{ "name": name, - "namespace": slinkySlurmNamespace, + "namespace": namespace, }, "spec": map[string]any{ "controllerRef": map[string]any{ "name": slinkySlurmComponent, - "namespace": slinkySlurmNamespace, + "namespace": namespace, }, }, "status": map[string]any{ From fa8e3781d83214026731bf85a8f88282630925c0 Mon Sep 17 00:00:00 2001 From: Kayne Tu Date: Mon, 22 Jun 2026 15:22:51 -0700 Subject: [PATCH 4/4] chore: address coderabbit comments --- validators/conformance/pod_exec.go | 3 +- validators/conformance/pod_exec_test.go | 3 + .../conformance/slinky_slurm_health_check.go | 31 +++- .../slinky_slurm_health_check_test.go | 164 ++++++++++++++++++ 4 files changed, 193 insertions(+), 8 deletions(-) diff --git a/validators/conformance/pod_exec.go b/validators/conformance/pod_exec.go index bc7127123..ab0be864f 100644 --- a/validators/conformance/pod_exec.go +++ b/validators/conformance/pod_exec.go @@ -133,5 +133,6 @@ func execPodCommand( result.ExitCode = exitErr.ExitStatus() return result, nil } - return result, streamErr + return result, errors.Wrap(errors.ErrCodeInternal, + fmt.Sprintf("pod exec stream failed for %s/%s", namespace, podName), streamErr) } diff --git a/validators/conformance/pod_exec_test.go b/validators/conformance/pod_exec_test.go index 4ff396a0a..184107ca7 100644 --- a/validators/conformance/pod_exec_test.go +++ b/validators/conformance/pod_exec_test.go @@ -276,6 +276,9 @@ func TestExecPodCommandReturnsStreamError(t *testing.T) { if err == nil || !strings.Contains(err.Error(), "stream failed") { t.Fatalf("error = %v, want stream failure", err) } + if !strings.Contains(err.Error(), "pod exec stream failed for slurm/login-0") { + t.Fatalf("error = %v, want pod context", err) + } } func podExecHTTPContext(t *testing.T, pod corev1.Pod) *validators.Context { diff --git a/validators/conformance/slinky_slurm_health_check.go b/validators/conformance/slinky_slurm_health_check.go index fb7d029f8..6e19f55ef 100644 --- a/validators/conformance/slinky_slurm_health_check.go +++ b/validators/conformance/slinky_slurm_health_check.go @@ -139,6 +139,12 @@ func resolveSlinkySlurmNamespace(ctx *validators.Context) string { func runSlinkySlurmHealthCommands(ctx *validators.Context, namespace, loginPodName string) []string { var failures []string for _, check := range slinkySlurmHealthCommands { + select { + case <-ctx.Ctx.Done(): + failures = append(failures, fmt.Sprintf("context canceled: %v", ctx.Ctx.Err())) + return failures + default: + } result, execErr := slinkyExecCommand( ctx.Ctx, ctx, namespace, loginPodName, check.command, slinkyLoginPodExecOptions) recordSlinkyExecResult(ctx, namespace, loginPodName, check, result, execErr) @@ -204,7 +210,7 @@ func skipIfAllNodeSetPodsAreKWOK(ctx *validators.Context, namespace string) erro kwok++ } } - if resolved > 0 && kwok == resolved { + if resolved == len(pods) && kwok == resolved { return validators.Skip("Slinky NodeSet pods are on KWOK nodes; skipping Slurm health validation") } return nil @@ -252,12 +258,22 @@ func findReadySlinkyLoginPod(ctx *validators.Context, namespace string) (*corev1 } var summary strings.Builder - for _, pod := range pods { + var selected *corev1.Pod + for i := range pods { + pod := &pods[i] fmt.Fprintf(&summary, "%s phase=%s ready=%t node=%s\n", - pod.Name, pod.Status.Phase, podIsReady(&pod), valueOrUnknown(pod.Spec.NodeName)) - if pod.Status.Phase == corev1.PodRunning && podIsReady(&pod) { - return &pod, nil + pod.Name, pod.Status.Phase, podIsReady(pod), valueOrUnknown(pod.Spec.NodeName)) + if pod.DeletionTimestamp != nil || pod.Status.Phase == corev1.PodFailed { + continue } + if pod.Status.Phase == corev1.PodRunning && podIsReady(pod) && + (selected == nil || pod.CreationTimestamp.After(selected.CreationTimestamp.Time)) { + + selected = pod + } + } + if selected != nil { + return selected, nil } return nil, errors.New(errors.ErrCodeNotFound, fmt.Sprintf("no ready login pod found for Slinky LoginSet selectors in %s:\n%s", @@ -283,10 +299,11 @@ func listSlinkySetsForController( } list, err := dynClient.Resource(gvr).Namespace(namespace).List(ctx.Ctx, metav1.ListOptions{}) if err != nil { + code := errors.ErrCodeInternal if apierrors.IsNotFound(err) { - return nil, validators.Skip(fmt.Sprintf("Slinky Slurm %s API not available", kind)) + code = errors.ErrCodeNotFound } - return nil, errors.Wrap(errors.ErrCodeInternal, fmt.Sprintf("failed to list Slinky Slurm %ss", kind), err) + return nil, errors.Wrap(code, fmt.Sprintf("failed to list Slinky Slurm %ss in namespace %s", kind, namespace), err) } selected := make([]slinkySetSelection, 0, len(list.Items)) diff --git a/validators/conformance/slinky_slurm_health_check_test.go b/validators/conformance/slinky_slurm_health_check_test.go index bd3f6887b..643233517 100644 --- a/validators/conformance/slinky_slurm_health_check_test.go +++ b/validators/conformance/slinky_slurm_health_check_test.go @@ -24,6 +24,7 @@ import ( v1 "github.com/NVIDIA/aicr/pkg/validator/v1" "github.com/NVIDIA/aicr/validators" corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" @@ -33,6 +34,7 @@ import ( "k8s.io/client-go/kubernetes" k8sfake "k8s.io/client-go/kubernetes/fake" "k8s.io/client-go/rest" + k8stesting "k8s.io/client-go/testing" ) var ( @@ -129,6 +131,23 @@ func TestCheckSlinkySlurmHealthSkipsWhenSlinkyAPIUnavailable(t *testing.T) { } } +func TestCheckSlinkySlurmHealthFailsWhenSlinkyNamespaceMissing(t *testing.T) { + ctx := slurmReadyTestContext(t, false) + dynClient := newSlinkyDynamicClient(t) + dynClient.PrependReactor("list", "*", func(k8stesting.Action) (bool, runtime.Object, error) { + return true, nil, apierrors.NewNotFound(schema.GroupResource{Resource: "namespaces"}, slinkySlurmNamespace) + }) + ctx.DynamicClient = dynClient + + err := CheckSlinkySlurmHealth(ctx) + if err == nil || !strings.Contains(err.Error(), "failed to list Slinky Slurm NodeSets in namespace slurm") { + t.Fatalf("error = %v, want namespace list failure", err) + } + if strings.Contains(strings.ToLower(err.Error()), "skip") { + t.Fatalf("error = %v, want real failure not skip", err) + } +} + func TestCheckSlinkySlurmHealthExecOutcomes(t *testing.T) { errBoom := errors.New(errors.ErrCodeInternal, "exec failed") tests := []struct { @@ -206,6 +225,36 @@ func TestCheckSlinkySlurmHealthRunsAllHealthCommands(t *testing.T) { } } +func TestCheckSlinkySlurmHealthStopsWhenContextCanceled(t *testing.T) { + ctx := slurmReadyTestContext(t, false) + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx.Ctx = runCtx + + var execCount int + restore := replaceSlinkyExecForTest(func( + _ context.Context, + _ *validators.Context, + _, _ string, + _ []string, + _ podExecOptions, + ) (podExecResult, error) { + + execCount++ + cancel() + return podExecResult{Stdout: "ok\n"}, nil + }) + defer restore() + + err := CheckSlinkySlurmHealth(ctx) + if err == nil || !strings.Contains(err.Error(), "context canceled") { + t.Fatalf("error = %v, want context canceled failure", err) + } + if execCount != 1 { + t.Fatalf("exec count = %d, want 1", execCount) + } +} + func TestCheckSlinkySlurmHealthDiscoversPodsFromSlinkyCRSelectors(t *testing.T) { var gotPodName string restore := replaceSlinkyExecForTest(func( @@ -282,6 +331,65 @@ func TestCheckSlinkySlurmHealthUsesComponentRefNamespace(t *testing.T) { } } +func TestCheckSlinkySlurmHealthSelectsNewestReadyLoginPod(t *testing.T) { + olderReady := readyLoginPod() + olderReady.Name = "slinky-login-old" + olderReady.CreationTimestamp = metav1.Unix(100, 0) + + terminatingReady := readyLoginPod() + terminatingReady.Name = "slinky-login-terminating" + terminatingReady.CreationTimestamp = metav1.Unix(300, 0) + deletionTime := metav1.Unix(400, 0) + terminatingReady.DeletionTimestamp = &deletionTime + + newerReady := readyLoginPod() + newerReady.Name = "slinky-login-new" + newerReady.CreationTimestamp = metav1.Unix(200, 0) + + clientset := k8sfake.NewSimpleClientset( + &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: slinkySlurmNamespace}}, + &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "worker-node-0"}}, + olderReady, + terminatingReady, + newerReady, + readyNodeSetPod(), + ) + addSlinkyDiscovery(t, clientset) + + var gotPodName string + restore := replaceSlinkyExecForTest(func( + _ context.Context, + _ *validators.Context, + _ string, + podName string, + _ []string, + _ podExecOptions, + ) (podExecResult, error) { + + gotPodName = podName + return podExecResult{Stdout: "ok\n"}, nil + }) + defer restore() + + ctx := &validators.Context{ + Ctx: context.Background(), + Clientset: clientset, + DynamicClient: newSlinkyDynamicClient(t, defaultLoginSet(), defaultNodeSet()), + RESTConfig: &rest.Config{Host: "https://example.test"}, + ValidationInput: &v1.ValidationInput{ + ComponentRefs: []recipe.ComponentRef{{Name: slinkySlurmComponent}}, + }, + } + + err := CheckSlinkySlurmHealth(ctx) + if err != nil { + t.Fatalf("error = %v, want nil", err) + } + if gotPodName != newerReady.Name { + t.Fatalf("exec pod = %q, want %q", gotPodName, newerReady.Name) + } +} + func TestCheckSlinkySlurmHealthFailsOnMalformedControllerRefName(t *testing.T) { ctx := slurmReadyTestContext(t, false) loginSet := defaultLoginSet() @@ -377,6 +485,62 @@ func TestCheckSlinkySlurmHealthSkipsWhenAllNodeSetPodsAreOnKWOKNodes(t *testing. } } +func TestCheckSlinkySlurmHealthDoesNotSkipWhenNodeSetPodIsUnbound(t *testing.T) { + kwokNode := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "kwok-node-0", + Annotations: map[string]string{kwokNodeAnnotation: "fake"}, + }, + } + kwokPod := readyNodeSetPod() + kwokPod.Name = "slinky-nodeset-kwok" + kwokPod.Spec.NodeName = kwokNode.Name + unboundPod := readyNodeSetPod() + unboundPod.Name = "slinky-nodeset-unbound" + unboundPod.Spec.NodeName = "" + + clientset := k8sfake.NewSimpleClientset( + &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: slinkySlurmNamespace}}, + kwokNode, + readyLoginPod(), + kwokPod, + unboundPod, + ) + addSlinkyDiscovery(t, clientset) + + var execRan bool + restore := replaceSlinkyExecForTest(func( + _ context.Context, + _ *validators.Context, + _, _ string, + _ []string, + _ podExecOptions, + ) (podExecResult, error) { + + execRan = true + return podExecResult{Stdout: "ok\n"}, nil + }) + defer restore() + + ctx := &validators.Context{ + Ctx: context.Background(), + Clientset: clientset, + DynamicClient: newSlinkyDynamicClient(t, defaultLoginSet(), defaultNodeSet()), + RESTConfig: &rest.Config{Host: "https://example.test"}, + ValidationInput: &v1.ValidationInput{ + ComponentRefs: []recipe.ComponentRef{{Name: slinkySlurmComponent}}, + }, + } + + err := CheckSlinkySlurmHealth(ctx) + if err != nil { + t.Fatalf("error = %v, want nil", err) + } + if !execRan { + t.Fatal("exec did not run; unbound NodeSet pod must prevent KWOK skip") + } +} + func TestCheckSlinkySlurmHealthFailsWithoutReadyLoginPod(t *testing.T) { ctx := slurmReadyTestContext(t, false) err := ctx.Clientset.CoreV1().Pods(slinkySlurmNamespace).Delete(ctx.Ctx, "slinky-login-0", metav1.DeleteOptions{})