Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Dockerfile.alpine
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
FROM golang:1.21.4-alpine3.18@sha256:110b07af87238fbdc5f1df52b00927cf58ce3de358eeeb1854f10a8b5e5e1411 AS build
FROM golang:1.24-alpine AS build

WORKDIR /go/src/github.com/juanfont/headscale/

ARG BUILD_VERSION

COPY . .

RUN test -n "${BUILD_VERSION}" \
&& apk update \
RUN apk update \
&& apk upgrade -a \
&& apk add --no-cache ca-certificates curl gcc musl-dev \
&& update-ca-certificates \
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/juanfont/headscale

go 1.20
go 1.24

require (
github.com/AlecAivazis/survey/v2 v2.3.6
Expand Down
43 changes: 22 additions & 21 deletions hscontrol/acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/netip"
"os"
"slices"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -115,10 +116,7 @@ func (h *Headscale) LoadACLPolicyFromBytes(acl []byte) error {
}

func (h *Headscale) UpdateACLRules() error {
machines, err := h.ListMachines()
if err != nil {
return err
}
machines := h.GetPrefetchedMachines()

if h.aclPolicy == nil {
return errEmptyPolicy
Expand Down Expand Up @@ -216,16 +214,14 @@ func (pol *ACLPolicy) generateFilterRules(
}

func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
var err error
rules := []*tailcfg.SSHRule{}

if h.aclPolicy == nil {
return nil, errEmptyPolicy
}

machines, err := h.ListMachines()
if err != nil {
return nil, err
}
machines := h.GetPrefetchedMachines()

acceptAction := tailcfg.SSHAction{
Message: "",
Expand Down Expand Up @@ -562,7 +558,7 @@ func excludeCorrectlyTaggedNodes(
for tag := range aclPolicy.TagOwners {
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
ns := append(owners, user)
if contains(ns, user) {
if slices.Contains(ns, user) {
tags = append(tags, tag)
}
}
Expand All @@ -572,7 +568,7 @@ func excludeCorrectlyTaggedNodes(

found := false
for _, t := range hi.RequestTags {
if contains(tags, t) {
if slices.Contains(tags, t) {
found = true

break
Expand Down Expand Up @@ -639,15 +635,19 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err

func filterMachinesByUser(machines []Machine, user string) []Machine {
out := []Machine{}
for _, machine := range machines {
if machine.User.Name == user {
out = append(out, machine)
for index := 0; index < len(machines); index++ {
//for _, machine := range machines {
//if machine.User.Name == user {
if machines[index].User.Name == user {
out = append(out, machines[index])
//out = append(out, machine)
}
}

return out
}

var invalidTagErr = errors.New("invalid tag")

// getTagOwners will return a list of user. An owner can be either a user or a group
// a group cannot be composed of groups.
func getTagOwners(
Expand All @@ -658,11 +658,7 @@ func getTagOwners(
var owners []string
ows, ok := pol.TagOwners[tag]
if !ok {
return []string{}, fmt.Errorf(
"%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners",
errInvalidTag,
tag,
)
return []string{}, invalidTagErr
}
for _, owner := range ows {
if isGroup(owner) {
Expand Down Expand Up @@ -746,7 +742,8 @@ func (pol *ACLPolicy) getIPsFromTag(

// check for forced tags
for _, machine := range machines {
if contains(machine.ForcedTags, alias) {
//if contains(machine.ForcedTags, alias) {
if slices.Contains(machine.ForcedTags, alias) {
machine.IPAddresses.AppendToIPSet(&build)
}
}
Expand Down Expand Up @@ -775,7 +772,8 @@ func (pol *ACLPolicy) getIPsFromTag(
machines := filterMachinesByUser(machines, user)
for _, machine := range machines {
hi := machine.GetHostInfo()
if contains(hi.RequestTags, alias) {
//if contains(hi.RequestTags, alias) {
if slices.Contains(hi.RequestTags, alias) {
machine.IPAddresses.AppendToIPSet(&build)
}
}
Expand All @@ -792,6 +790,9 @@ func (pol *ACLPolicy) getIPsForUser(
build := netipx.IPSetBuilder{}

filteredMachines := filterMachinesByUser(machines, user)
if len(filteredMachines) == 0 {
return nil, nil //nolint
}
filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user, stripEmailDomain)

// shortcurcuit if we have no machines to get ips from.
Expand Down
4 changes: 4 additions & 0 deletions hscontrol/api_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ func (h *Headscale) CreateAPIKey(
return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
}

if err := h.LoadPrefetchMachinesFromDB(); err != nil {
return "", nil, fmt.Errorf("failed to load machines from database: %w", err)
}

return keyStr, &key, nil
}

Expand Down
14 changes: 10 additions & 4 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ type Headscale struct {
DERPMap *tailcfg.DERPMap
DERPServer *DERPServer

aclPolicy *ACLPolicy
aclRules []tailcfg.FilterRule
sshPolicy *tailcfg.SSHPolicy
aclPolicy *ACLPolicy
aclRules []tailcfg.FilterRule
sshPolicy *tailcfg.SSHPolicy
prefetchedMachines []Machine

lastStateChange *xsync.MapOf[string, time.Time]

Expand All @@ -97,7 +98,8 @@ type Headscale struct {

registrationCache *cache.Cache

ipAllocationMutex sync.Mutex
ipAllocationMutex sync.Mutex
prefetchMachineMutex sync.RWMutex

shutdownChan chan struct{}
pollNetMapStreamWG sync.WaitGroup
Expand Down Expand Up @@ -528,6 +530,10 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router {
// Serve launches a GIN server with the Headscale API.
func (h *Headscale) Serve() error {
var err error
if err = h.LoadPrefetchMachinesFromDB(); err != nil {
return fmt.Errorf("failed to load machines from db : %w", err)
}

if err = h.loadACLPolicy(); err != nil {
return fmt.Errorf("failed to load ACL policy: %w", err)
}
Expand Down
83 changes: 75 additions & 8 deletions hscontrol/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,7 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
// If ACLs rules are defined, filter visible host list with the ACLs
// else use the classic user scope
if h.aclPolicy != nil {
var machines []Machine
machines, err = h.ListMachines()
if err != nil {
log.Error().Err(err).Msg("Error retrieving list of machines")

return Machines{}, err
}
machines := h.GetPrefetchedMachines()
peers = h.filterMachinesByACL(machine, machines)
} else {
peers, err = h.ListPeers(machine)
Expand Down Expand Up @@ -424,6 +418,10 @@ func (h *Headscale) SetTags(machine *Machine, tags []string) error {
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
}

if err := h.LoadPrefetchMachinesFromDB(); err != nil {
return fmt.Errorf("failed to load machines from database: %w", err)
}

return nil
}

Expand All @@ -438,6 +436,10 @@ func (h *Headscale) ExpireMachine(machine *Machine) error {
return fmt.Errorf("failed to expire machine in the database: %w", err)
}

if err := h.LoadPrefetchMachinesFromDB(); err != nil {
return fmt.Errorf("failed to load machines from database: %w", err)
}

return nil
}

Expand Down Expand Up @@ -465,6 +467,10 @@ func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
return fmt.Errorf("failed to rename machine in the database: %w", err)
}

if err := h.LoadPrefetchMachinesFromDB(); err != nil {
return fmt.Errorf("failed to load machines from database: %w", err)
}

return nil
}

Expand All @@ -484,6 +490,10 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error {
)
}

if err := h.LoadPrefetchMachinesFromDB(); err != nil {
return fmt.Errorf("failed to load machines from database: %w", err)
}

return nil
}

Expand All @@ -498,15 +508,25 @@ func (h *Headscale) DeleteMachine(machine *Machine) error {
return err
}

if err := h.LoadPrefetchMachinesFromDB(); err != nil {
return fmt.Errorf("failed to load machines from database: %w", err)
}

return nil
}

func (h *Headscale) TouchMachine(machine *Machine) error {
return h.db.Updates(Machine{
err := h.db.Updates(Machine{
ID: machine.ID,
LastSeen: machine.LastSeen,
LastSuccessfulUpdate: machine.LastSuccessfulUpdate,
}).Error

if err != nil {
return err
}
h.UpdateMachineInCache(*machine)
return nil
}

// HardDeleteMachine hard deletes a Machine from the database.
Expand All @@ -520,6 +540,10 @@ func (h *Headscale) HardDeleteMachine(machine *Machine) error {
return err
}

if err := h.LoadPrefetchMachinesFromDB(); err != nil {
return fmt.Errorf("failed to load machines from database: %w", err)
}

return nil
}

Expand Down Expand Up @@ -863,6 +887,16 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
// Registration of expired machine with different user
if registrationMachine.ID != 0 &&
registrationMachine.UserID != user.ID {
log.Info().
Str("error registering nodeKey", nodeKey.ShortString()).
Str("request userName", userName).
Uint64("cache registration machine id:", registrationMachine.ID).
Uint("cache registration machine user id:", registrationMachine.UserID).
Uint("db user id: ", user.ID).
Int("registration cache item count ", h.registrationCache.ItemCount()).
Str("registration cache items ", fmt.Sprintf("%v", h.registrationCache.Items())).
Msg("Registration failure due to key already registered")

return nil, ErrDifferentRegisteredUser
}

Expand Down Expand Up @@ -908,6 +942,10 @@ func (h *Headscale) RegisterMachine(machine Machine,
return nil, fmt.Errorf("failed register existing machine in the database: %w", err)
}

if err := h.LoadPrefetchMachinesFromDB(); err != nil {
return nil, fmt.Errorf("failed to load machines from database: %w", err)
}

log.Trace().
Caller().
Str("machine", machine.Hostname).
Expand Down Expand Up @@ -939,6 +977,10 @@ func (h *Headscale) RegisterMachine(machine Machine,
return nil, fmt.Errorf("failed register(save) machine in the database: %w", err)
}

if err := h.LoadPrefetchMachinesFromDB(); err != nil {
return nil, fmt.Errorf("failed to load machines from database: %w", err)
}

log.Trace().
Caller().
Str("machine", machine.Hostname).
Expand Down Expand Up @@ -1203,6 +1245,31 @@ func (h *Headscale) GenerateGivenName(machineKey string, suppliedName string) (s
return givenName, nil
}

func (h *Headscale) GetPrefetchedMachines() []Machine {
h.prefetchMachineMutex.RLock()
defer h.prefetchMachineMutex.RUnlock()
machinesCopy := make([]Machine, len(h.prefetchedMachines))
copy(machinesCopy, h.prefetchedMachines)
return machinesCopy
}

func (h *Headscale) LoadPrefetchMachinesFromDB() (err error) {
h.prefetchMachineMutex.Lock()
defer h.prefetchMachineMutex.Unlock()
h.prefetchedMachines, err = h.ListMachines()
return err
}

func (h *Headscale) UpdateMachineInCache(machine Machine) {
h.prefetchMachineMutex.Lock()
defer h.prefetchMachineMutex.Unlock()
for idx, cacheMachine := range h.prefetchedMachines {
if cacheMachine.ID == machine.ID {
h.prefetchedMachines[idx] = machine
}
}
}

func (machines Machines) FilterByIP(ip netip.Addr) Machines {
found := make(Machines, 0)

Expand Down
8 changes: 8 additions & 0 deletions hscontrol/preauth_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ func (h *Headscale) CreatePreAuthKey(
return nil, err
}

if err := h.LoadPrefetchMachinesFromDB(); err != nil {
return nil, fmt.Errorf("failed to load machines from database: %w", err)
}

return &key, nil
}

Expand Down Expand Up @@ -171,6 +175,10 @@ func (h *Headscale) UsePreAuthKey(k *PreAuthKey) error {
return fmt.Errorf("failed to update key used status in the database: %w", err)
}

if err := h.LoadPrefetchMachinesFromDB(); err != nil {
return fmt.Errorf("failed to create key in the database: %w", err)
}

return nil
}

Expand Down
Loading
Loading