diff --git a/app.go b/app.go index dbd512f..7020862 100644 --- a/app.go +++ b/app.go @@ -233,35 +233,46 @@ func (app *App) RegisterIntervals(intervals ...Interval) { } } +func (app *App) registerEntityListener(etl EntityListener) { + if etl.delay != 0 && etl.toState == "" { + slog.Error("EntityListener error: you have to use ToState() when using Duration()") + panic(ErrInvalidArgs) + } + + for _, entity := range etl.entityIds { + app.entityListeners[entity] = append(app.entityListeners[entity], &etl) + } +} + func (app *App) RegisterEntityListeners(etls ...EntityListener) { for _, etl := range etls { - etl := etl - if etl.delay != 0 && etl.toState == "" { - slog.Error("EntityListener error: you have to use ToState() when using Duration()") - panic(ErrInvalidArgs) - } + app.registerEntityListener(etl) + } +} - for _, entity := range etl.entityIds { - if elList, ok := app.entityListeners[entity]; ok { - app.entityListeners[entity] = append(elList, &etl) - } else { - app.entityListeners[entity] = []*EntityListener{&etl} - } +func (app *App) registerEventListener(evl EventListener) { + for _, eventType := range evl.eventTypes { + elList, ok := app.eventListeners[eventType] + if !ok { + // We're not listening to that event type yet. Ask HA to + // send them to us, and when they arrive, call any event + // listeners for that type (including any that are + // registered in the future). + eventType := eventType + app.conn.SubscribeToEventType( + eventType, + func(msg websocket.ChanMsg) { + go app.callEventListeners(eventType, msg) + }, + ) } + app.eventListeners[eventType] = append(elList, &evl) } } func (app *App) RegisterEventListeners(evls ...EventListener) { for _, evl := range evls { - evl := evl - for _, eventType := range evl.eventTypes { - if elList, ok := app.eventListeners[eventType]; ok { - app.eventListeners[eventType] = append(elList, &evl) - } else { - websocket.SubscribeToEventType(eventType, app.conn) - app.eventListeners[eventType] = []*EventListener{&evl} - } - } + app.registerEventListener(evl) } } @@ -316,7 +327,11 @@ func (app *App) Start() { go app.runScheduledActions(app.ctx) // subscribe to state_changed events - app.entitySubscription = websocket.SubscribeToStateChangedEvents(app.conn) + app.entitySubscription = app.conn.SubscribeToStateChangedEvents( + func(msg websocket.ChanMsg) { + go app.callEntityListeners(msg.Raw) + }, + ) // entity listeners runOnStartup for eid, etls := range app.entityListeners { @@ -342,20 +357,9 @@ func (app *App) Start() { } } - // entity listeners and event listeners - elChan := make(chan websocket.ChanMsg) - go app.conn.ListenWebsocket(elChan) - - for { - msg, ok := <-elChan - if !ok { - break - } - if app.entitySubscription.ID() == msg.Id { - go callEntityListeners(app, msg.Raw) - } else { - go callEventListeners(app, msg) - } + // Start listen on the connection for incoming messages: + if err := app.conn.Run(); err != nil { + slog.Error("Error reading from websocket", "err", err) } } diff --git a/entitylistener.go b/entitylistener.go index e727539..a6e45c3 100644 --- a/entitylistener.go +++ b/entitylistener.go @@ -49,16 +49,18 @@ type stateChangedMsg struct { ID int `json:"id"` Type string `json:"type"` Event struct { - Data struct { - EntityID string `json:"entity_id"` - NewState msgState `json:"new_state"` - OldState msgState `json:"old_state"` - } `json:"data"` - EventType string `json:"event_type"` - Origin string `json:"origin"` + Data stateData `json:"data"` + EventType string `json:"event_type"` + Origin string `json:"origin"` } `json:"event"` } +type stateData struct { + EntityID string `json:"entity_id"` + NewState msgState `json:"new_state"` + OldState msgState `json:"old_state"` +} + type msgState struct { EntityID string `json:"entity_id"` LastChanged time.Time `json:"last_changed"` @@ -191,8 +193,52 @@ func (b elBuilder3) Build() EntityListener { return b.entityListener } +func (l *EntityListener) maybeCall(app *App, entityData EntityData, data stateData) { + // Check conditions + if c := checkWithinTimeRange(l.betweenStart, l.betweenEnd); c.fail { + return + } + if c := checkStatesMatch(l.fromState, data.OldState.State); c.fail { + return + } + if c := checkStatesMatch(l.toState, data.NewState.State); c.fail { + if l.delayTimer != nil { + l.delayTimer.Stop() + } + return + } + if c := checkThrottle(l.throttle, l.lastRan); c.fail { + return + } + if c := checkExceptionDates(l.exceptionDates); c.fail { + return + } + if c := checkExceptionRanges(l.exceptionRanges); c.fail { + return + } + if c := checkEnabledEntity(app.state, l.enabledEntities); c.fail { + return + } + if c := checkDisabledEntity(app.state, l.disabledEntities); c.fail { + return + } + + if l.delay != 0 { + l := l + l.delayTimer = time.AfterFunc(l.delay, func() { + go l.callback(app.service, app.state, entityData) + l.lastRan = carbon.Now() + }) + return + } + + // run now if no delay set + go l.callback(app.service, app.state, entityData) + l.lastRan = carbon.Now() +} + /* Functions */ -func callEntityListeners(app *App, msgBytes []byte) { +func (app *App) callEntityListeners(msgBytes []byte) { msg := stateChangedMsg{} _ = json.Unmarshal(msgBytes, &msg) data := msg.Event.Data @@ -211,56 +257,16 @@ func callEntityListeners(app *App, msgBytes []byte) { return } - for _, l := range listeners { - // Check conditions - if c := checkWithinTimeRange(l.betweenStart, l.betweenEnd); c.fail { - continue - } - if c := checkStatesMatch(l.fromState, data.OldState.State); c.fail { - continue - } - if c := checkStatesMatch(l.toState, data.NewState.State); c.fail { - if l.delayTimer != nil { - l.delayTimer.Stop() - } - continue - } - if c := checkThrottle(l.throttle, l.lastRan); c.fail { - continue - } - if c := checkExceptionDates(l.exceptionDates); c.fail { - continue - } - if c := checkExceptionRanges(l.exceptionRanges); c.fail { - continue - } - if c := checkEnabledEntity(app.state, l.enabledEntities); c.fail { - continue - } - if c := checkDisabledEntity(app.state, l.disabledEntities); c.fail { - continue - } - - entityData := EntityData{ - TriggerEntityId: eid, - FromState: data.OldState.State, - FromAttributes: data.OldState.Attributes, - ToState: data.NewState.State, - ToAttributes: data.NewState.Attributes, - LastChanged: data.OldState.LastChanged, - } - - if l.delay != 0 { - l := l - l.delayTimer = time.AfterFunc(l.delay, func() { - go l.callback(app.service, app.state, entityData) - l.lastRan = carbon.Now() - }) - continue - } + entityData := EntityData{ + TriggerEntityId: eid, + FromState: data.OldState.State, + FromAttributes: data.OldState.Attributes, + ToState: data.NewState.State, + ToAttributes: data.NewState.Attributes, + LastChanged: data.OldState.LastChanged, + } - // run now if no delay set - go l.callback(app.service, app.state, entityData) - l.lastRan = carbon.Now() + for _, l := range listeners { + l.maybeCall(app, entityData, data) } } diff --git a/eventListener.go b/eventListener.go index a9f9e24..2f1b92e 100644 --- a/eventListener.go +++ b/eventListener.go @@ -1,7 +1,6 @@ package gomeassistant import ( - "encoding/json" "fmt" "time" @@ -133,48 +132,45 @@ func (b eventListenerBuilder3) Build() EventListener { return b.eventListener } -type BaseEventMsg struct { - Event struct { - EventType string `json:"event_type"` - } `json:"event"` +func (l *EventListener) maybeCall(app *App, eventData EventData) { + // Check conditions + if c := checkWithinTimeRange(l.betweenStart, l.betweenEnd); c.fail { + return + } + if c := checkThrottle(l.throttle, l.lastRan); c.fail { + return + } + if c := checkExceptionDates(l.exceptionDates); c.fail { + return + } + if c := checkExceptionRanges(l.exceptionRanges); c.fail { + return + } + if c := checkEnabledEntity(app.state, l.enabledEntities); c.fail { + return + } + if c := checkDisabledEntity(app.state, l.disabledEntities); c.fail { + return + } + + go l.callback(app.service, app.state, eventData) + l.lastRan = carbon.Now() } /* Functions */ -func callEventListeners(app *App, msg websocket.ChanMsg) { - baseEventMsg := BaseEventMsg{} - _ = json.Unmarshal(msg.Raw, &baseEventMsg) - listeners, ok := app.eventListeners[baseEventMsg.Event.EventType] +func (app *App) callEventListeners(eventType string, msg websocket.ChanMsg) { + listeners, ok := app.eventListeners[eventType] if !ok { // no listeners registered for this event type return } + eventData := EventData{ + Type: eventType, + RawEventJSON: msg.Raw, + } + for _, l := range listeners { - // Check conditions - if c := checkWithinTimeRange(l.betweenStart, l.betweenEnd); c.fail { - continue - } - if c := checkThrottle(l.throttle, l.lastRan); c.fail { - continue - } - if c := checkExceptionDates(l.exceptionDates); c.fail { - continue - } - if c := checkExceptionRanges(l.exceptionRanges); c.fail { - continue - } - if c := checkEnabledEntity(app.state, l.enabledEntities); c.fail { - continue - } - if c := checkDisabledEntity(app.state, l.disabledEntities); c.fail { - continue - } - - eventData := EventData{ - Type: baseEventMsg.Event.EventType, - RawEventJSON: msg.Raw, - } - go l.callback(app.service, app.state, eventData) - l.lastRan = carbon.Now() + l.maybeCall(app, eventData) } } diff --git a/internal/websocket/locked_conn.go b/internal/websocket/locked_conn.go index 8557b71..66da9fe 100644 --- a/internal/websocket/locked_conn.go +++ b/internal/websocket/locked_conn.go @@ -11,6 +11,26 @@ type LockedConn interface { // `LockedConn` is still active. NextMessageID() int64 + // Subscribe allocates a new message ID and subscribes + // `subscriber` to it, in the sense that the subscriber will be + // called for any incoming messages that have that ID. This + // doesn't actually interact with the server. Typically the next + // step would be to send a message with its message ID set to + // `Subscription.ID()`. + // + // The returned `Subscription` must eventually be passed at least + // once to `Unsubscribe()`, though `Unsubscribe()` can be called + // against a different `LockedConn` than the one that generated + // it. + Subscribe(subscriber Subscriber) Subscription + + // Unsubscribe terminates `subscription` at the websocket level; + // i.e., no more incoming messages will be forwarded to the + // corresponding `Subscriber`. Note that this does not interact + // with the server; it is the caller's responsibility to send it + // an "unsubscribe" command if necessary. + Unsubscribe(subscription Subscription) + // SendMessage sends the specified message over the websocket // connection. `msg` must be JSON-serializable and have the // correct format and a unique, monotonically-increasing ID, which @@ -30,6 +50,31 @@ func (lc lockedConn) NextMessageID() int64 { return lc.conn.lastMessageID } +// Subscribe implements [LockedConn.Subscribe]. +func (lc lockedConn) Subscribe(subscriber Subscriber) Subscription { + lc.conn.subscribersLock.Lock() + defer lc.conn.subscribersLock.Unlock() + + id := lc.NextMessageID() + lc.conn.subscribers[id] = subscriber + return Subscription{ + messageID: id, + } +} + +// Unsubscribe implements [LockedConn.Unsubscribe]. +func (lc lockedConn) Unsubscribe(subscription Subscription) { + if subscription.messageID == 0 { + return + } + + lc.conn.subscribersLock.Lock() + defer lc.conn.subscribersLock.Unlock() + + delete(lc.conn.subscribers, subscription.messageID) + subscription.messageID = 0 +} + // SendMessage implements [LockedConn.SendMessage]. func (lc lockedConn) SendMessage(msg any) error { if err := lc.conn.conn.WriteJSON(msg); err != nil { diff --git a/internal/websocket/reader.go b/internal/websocket/reader.go index 1dfc3ca..f1cce6f 100644 --- a/internal/websocket/reader.go +++ b/internal/websocket/reader.go @@ -18,16 +18,21 @@ type ChanMsg struct { Raw []byte } -// ListenWebsocket reads JSON-formatted messages from `conn`, partly -// deserializes them, and sends them to `c`. If there is an error, -// close `c` and return. -func (conn *Conn) ListenWebsocket(c chan<- ChanMsg) { +// Run processes incoming messages from `Conn`. It reads +// JSON-formatted messages from `conn`, partly deserializes them, and +// passes them to the subscriber that has subscribed to that message +// ID (if any). If there is an error, return the error and stop +// listening. +// +// Note that the subscribers are invoked synchronously, in the same +// order as the messages arrived, and only one is run at a time. If +// the subscriber wants processing to happen in the background, it +// must spawn a goroutine itself. +func (conn *Conn) Run() error { for { bytes, err := conn.readMessage() if err != nil { - slog.Error("Error reading from websocket", "err", err) - close(c) - return + return err } base := BaseMessage{ @@ -45,6 +50,10 @@ func (conn *Conn) ListenWebsocket(c chan<- ChanMsg) { Raw: bytes, } - c <- chanMsg + // If a subscriber has been registered for this message ID, + // then call it, too: + if subr, ok := conn.getSubscriber(base.Id); ok { + subr(chanMsg) + } } } diff --git a/internal/websocket/subscriptions.go b/internal/websocket/subscriptions.go new file mode 100644 index 0000000..2cad7fa --- /dev/null +++ b/internal/websocket/subscriptions.go @@ -0,0 +1,75 @@ +package websocket + +import ( + "fmt" + "log/slog" +) + +// Subscription represents a websocket-level subscription to a +// particular message ID. +type Subscription struct { + messageID int64 +} + +// MessageID returns the message ID that this subscription is +// subscribed to. +func (sub Subscription) MessageID() int64 { + return sub.messageID +} + +// Subscriber is called synchronously when a message is received that +// matches its subscription's message ID. +type Subscriber func(msg ChanMsg) + +// NoopSubscriber is a `Subscriber` that does nothing. +func NoopSubscriber(_ ChanMsg) {} + +// getSubscriber returns the subscriber, if any, that is subscribed to +// the specified message ID. +func (conn *Conn) getSubscriber(messageID int64) (Subscriber, bool) { + conn.subscribersLock.RLock() + defer conn.subscribersLock.RUnlock() + + subscriber, ok := conn.subscribers[messageID] + return subscriber, ok +} + +type SubEvent struct { + Id int64 `json:"id"` + Type string `json:"type"` + EventType string `json:"event_type"` +} + +func (conn *Conn) SubscribeToEventType(eventType string, subr Subscriber) Subscription { + var subn Subscription + err := conn.Send( + func(lc LockedConn) error { + subn = lc.Subscribe(subr) + e := SubEvent{ + Id: subn.messageID, + Type: "subscribe_events", + EventType: eventType, + } + + if err := lc.SendMessage(e); err != nil { + lc.Unsubscribe(subn) + return fmt.Errorf("error writing to websocket: %w", err) + } + // m, _ := ReadMessage(ctx, conn) + // log.Default().Println(string(m)) + + return nil + }, + ) + + if err != nil { + slog.Error(err.Error()) + panic(err) + } + + return subn +} + +func (conn *Conn) SubscribeToStateChangedEvents(subr Subscriber) Subscription { + return conn.SubscribeToEventType("state_changed", subr) +} diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go index b91445f..25aa117 100644 --- a/internal/websocket/websocket.go +++ b/internal/websocket/websocket.go @@ -8,7 +8,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "log/slog" "net/url" "sync" @@ -27,6 +26,13 @@ type Conn struct { conn *websocket.Conn writeLock sync.Mutex lastMessageID int64 + + // subscribersLock guards access to `subscribers`. + subscribersLock sync.RWMutex + + // subscribers is a map from message ID to the subscriber that is + // subscribed to messages with that ID. + subscribers map[int64]Subscriber } func (conn *Conn) readMessage() ([]byte, error) { @@ -59,7 +65,8 @@ func NewConn( } conn := Conn{ - conn: gConn, + conn: gConn, + subscribers: make(map[int64]Subscriber), } // Read auth_required message @@ -120,52 +127,3 @@ func (conn *Conn) verifyAuthResponse(ctx context.Context) error { return nil } - -type SubEvent struct { - Id int64 `json:"id"` - Type string `json:"type"` - EventType string `json:"event_type"` -} - -// Subscription represents a websocket-level subscription to a -// particular message ID. -type Subscription struct { - id int64 -} - -func (sub Subscription) ID() int64 { - return sub.id -} - -func SubscribeToStateChangedEvents(conn *Conn) Subscription { - return SubscribeToEventType("state_changed", conn) -} - -func SubscribeToEventType(eventType string, conn *Conn) Subscription { - var id int64 - err := conn.Send( - func(lc LockedConn) error { - id = lc.NextMessageID() - e := SubEvent{ - Id: id, - Type: "subscribe_events", - EventType: eventType, - } - - if err := lc.SendMessage(e); err != nil { - return fmt.Errorf("error writing to websocket: %w", err) - } - // m, _ := ReadMessage(ctx, conn) - // log.Default().Println(string(m)) - - return nil - }, - ) - - if err != nil { - slog.Error(err.Error()) - panic(err) - } - - return Subscription{id} -}