diff --git a/.gitignore b/.gitignore index 1b167ff..f0b6b21 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ /dist _release/attiny-firmware.hex _release/attiny-firmware.hex.sha256 - +/tc2-hat-controller diff --git a/.goreleaser.yml b/.goreleaser.yml index 8909a10..b88333a 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -77,6 +77,8 @@ nfpms: dst: /usr/bin/tc2-hat-rtc - src: _release/tc2-hat-temp dst: /usr/bin/tc2-hat-temp + - src: _release/tc2-hat-trap-cli + dst: /usr/bin/tc2-hat-trap-cli dependencies: #- python3-pip diff --git a/_release/tc2-hat-trap-cli b/_release/tc2-hat-trap-cli new file mode 100644 index 0000000..eb73af0 --- /dev/null +++ b/_release/tc2-hat-trap-cli @@ -0,0 +1,2 @@ +#!/bin/bash +exec /usr/bin/tc2-hat-controller trap-cli "$@" diff --git a/cmd/tc2-hat-controller/main.go b/cmd/tc2-hat-controller/main.go index f495ab6..aab802a 100644 --- a/cmd/tc2-hat-controller/main.go +++ b/cmd/tc2-hat-controller/main.go @@ -12,6 +12,7 @@ import ( rp2040 "github.com/TheCacophonyProject/tc2-hat-controller/internal/tc2-hat-rp2040" rtc "github.com/TheCacophonyProject/tc2-hat-controller/internal/tc2-hat-rtc" temp "github.com/TheCacophonyProject/tc2-hat-controller/internal/tc2-hat-temp" + trapcli "github.com/TheCacophonyProject/tc2-hat-controller/internal/tc2-hat-trap-cli" ) var ( @@ -60,6 +61,8 @@ func runMain() error { err = rtc.Run(args, version) case "temp": err = temp.Run(args, version) + case "trap-cli": + err = trapcli.Run(args, version) default: err = fmt.Errorf("unknown subcommand: %s", subcommand) } diff --git a/internal/tc2-hat-comms/bluetooth.go b/internal/tc2-hat-comms/bluetooth.go deleted file mode 100644 index b2d7f9d..0000000 --- a/internal/tc2-hat-comms/bluetooth.go +++ /dev/null @@ -1,10 +0,0 @@ -// This section deals with communication with peripherals over bluetooth. - -package comms - -/* -func processBluetooth() error { - // TODO - return nil -} -*/ diff --git a/internal/tc2-hat-comms/json-out.go b/internal/tc2-hat-comms/json-out.go new file mode 100644 index 0000000..7f0af79 --- /dev/null +++ b/internal/tc2-hat-comms/json-out.go @@ -0,0 +1,98 @@ +// Output mode: sends events out over serial in JSON format. + +package comms + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/TheCacophonyProject/tc2-hat-controller/serialhelper" + "github.com/TheCacophonyProject/tc2-hat-controller/tracks" +) + +type ClassificationData struct { + Species tracks.Species `json:"species"` + Confidence int32 `json:"confidence"` +} + +type jsonOut struct { + Type string `json:"type,omitempty"` + Data ClassificationData `json:"data"` +} + +func processJSONOut(config *CommsConfig, testClassification *TestClassification, trackingSignals chan event, port *serialhelper.SerialPort) error { + if testClassification != nil { + log.Println("Sending a test classification over UART") + + species := tracks.Species{ + testClassification.Animal: int32(testClassification.Confidence), + } + + classificationData := ClassificationData{ + Species: species, + Confidence: int32(testClassification.Confidence), + } + + message := jsonOut{ + Type: "classification", + Data: classificationData, + } + payload, err := json.Marshal(message) + if err != nil { + return err + } + + log.Printf("Sending payload: '%s'", payload) + return port.Write(append(payload, '\r', '\n')) + } + + for { + log.Debug("Waiting") + for e := range trackingSignals { + switch v := e.(type) { + case trackingEvent: + fmt.Println("Tracking event:", v.Species) + err := processTrackingEvent(v, port) + if err != nil { + log.Error("Error processing tracking event:", err) + } + default: + log.Debug("Not processing event:", v) + continue + } + } + } +} + +func processTrackingEvent(t trackingEvent, port *serialhelper.SerialPort) error { + log.Debugf("Found new track: %+v", t) + + species := tracks.Species{} + for k, v := range t.Species { + if v > 0 { + species[k] = v + } + } + + message := jsonOut{ + Type: "classification", + Data: ClassificationData{ + Species: species, + Confidence: t.Confidence, + }, + } + + payload, err := json.Marshal(message) + if err != nil { + return err + } + + log.Printf("Sending payload: '%s'", payload) + start := time.Now() + + err = port.Write(append(payload, '\r', '\n')) + + log.Printf("Sent payload in %s", time.Since(start)) + return err +} diff --git a/internal/tc2-hat-comms/main.go b/internal/tc2-hat-comms/main.go index 6d5a5f4..25f4dcf 100644 --- a/internal/tc2-hat-comms/main.go +++ b/internal/tc2-hat-comms/main.go @@ -9,10 +9,12 @@ import ( goconfig "github.com/TheCacophonyProject/go-config" "github.com/TheCacophonyProject/go-utils/logging" + "github.com/TheCacophonyProject/tc2-hat-controller/serialhelper" "github.com/TheCacophonyProject/tc2-hat-controller/tracks" "github.com/alexflint/go-arg" "github.com/google/go-cmp/cmp" "github.com/rjeczalik/notify" + "periph.io/x/conn/v3/gpio" ) var ( @@ -110,11 +112,6 @@ func Run(inputArgs []string, ver string) error { } } - if config.CommsOut == "uart" && config.Bluetooth { - log.Error("Can't have output set to UART and Bluetooth enabled at the same time.") - return fmt.Errorf("can't have output set to UART and Bluetooth enabled at the same time") - } - log.Info("Species to trap:\n", tracks.Species(config.TrapSpecies)) log.Info("Species to protect:\n", tracks.Species(config.ProtectSpecies)) @@ -127,8 +124,8 @@ func Run(inputArgs []string, ver string) error { } switch config.CommsOut { - case "uart": - log.Info("Running UART output.") + case "uart", "json-out": + log.Info("Running UART/json-out.") // uart comms channel listens for tracking events err = addTrackingEvents(eventsChan) @@ -136,7 +133,13 @@ func Run(inputArgs []string, ver string) error { return err } - if err := processUart(config, args.SendTestClassification, eventsChan); err != nil { + port, err := serialhelper.OpenSerial(gpio.High, gpio.Low, config.BaudRate) + if err != nil { + return fmt.Errorf("failed to open serial port: %v", err) + } + defer port.Close() + + if err := processJSONOut(config, args.SendTestClassification, eventsChan, port); err != nil { return err } case "simple": @@ -151,182 +154,47 @@ func Run(inputArgs []string, ver string) error { if err := processSimpleOutput(config, eventsChan); err != nil { return err } - case "at-esl": - log.Info("Running AT-ESL output.") - config.BaudRate = 4800 // Force AT-ESL baud rate to be 4800 + case "trap-control": + log.Info("Running trap-control output.") - // at-esl comms channel listens for tracking *reprocessed* events - err = addTrackingReprocessedEvents(eventsChan) - if err != nil { - return err - } + // TODO, check what speed we want for this + log.Info("Forcing baud rate to 9600, this will likely change in the future.") + config.BaudRate = 9600 - if err := processATESL(config, args.SendTestClassification, eventsChan); err != nil { + // Add tracking events to the channel. + // This is so we can activate and deactivate the trap depending on the track classification. + err = addTrackingEvents(eventsChan) + if err != nil { return err } - default: - return fmt.Errorf("unknown output type '%s'", config.CommsOut) - } - - return nil - /* - - trapActiveUntil := time.Time{} - trapActive := false - - // Initialize the periph host drivers - if _, err := host.Init(); err != nil { - log.Printf("Failed to initialize periph: %v\n", err) + // Add recording start/stop events to the channel. + // This is so we can see how long it takes from a recording started to getting a track classification. + err = addRecordingEvents(eventsChan) + if err != nil { return err } - log.Info("Get lock on serial port") - if config.CommsOut == "uart" || config.CommsOut == "simple" { - serialFile, err := serialhelper.GetSerial(3, gpio.High, gpio.Low, time.Second) - if err != nil { - return err - } - defer serialhelper.ReleaseSerial(serialFile) - } - log.Info("Done") - - protectDuration := time.Minute - trapDuration := time.Duration(args.TrapStayActiveDuration) * time.Second - - var newTrack *trackingEvent - lastProtectSpeciesSighting := time.Time{} - lastTrapSpeciesSighting := time.Time{} - - for { - - now := time.Now() - newTrapActive := - (lastProtectSpeciesSighting.Add(protectDuration).Before(now) && // Nothing to protect has been seen recently. - lastTrapSpeciesSighting.Add(trapDuration).After(now)) // And something to trap has been sighted recently. - - if trapActive != newTrapActive { - trapActive = newTrapActive - - if trapActive { - log.Println("Activating trap") - } else { - log.Println("Deactivating trap") - } - - switch args.OutputType { - case "uart": - log.Info("Outputting trap active state via UART") - if err := processUart(); err != nil { - return err - } - // TODO - - case "bluetooth": - log.Info("Outputting trap active state via bluetooth") - if err := processBluetooth(); err != nil { - return err - } - // TODO - - case "digital": - log.Info("Outputting trap active state via digital signals") - //if err := processDigital(); err != nil { - // return err - //} - - default: - return fmt.Errorf("unhandled output type: %s", args.OutputType) - } - } - - var delay = 10 * time.Second - if trapActive && time.Until(trapActiveUntil) < delay { - delay = time.Until(trapActiveUntil) - } - - newTrack = nil - log.Debug("Waiting ") - select { - case t := <-trackingSignals: - newTrack = &t - log.Debugf("Found new track: %+v", newTrack) - - if newTrack.species.MatchSpeciesWithConfidence(protectSpecies) { - log.Debug("Found an animal that needs to be protected, deactivating trap") - lastProtectSpeciesSighting = time.Now() - //trapActiveUntil = time.Time{} - } else if newTrack.species.MatchSpeciesWithConfidence(trapSpecies) { - log.Debug("Found an animal that needs to be trapped, activating trap") - lastTrapSpeciesSighting = time.Now() - - //trapActiveUntil = time.Now().Add(time.Duration(args.TrapStayActiveDuration) * time.Second) - } else { - log.Debug("No animals need to be protected or trapped, not changing trap state.") - } - - case <-time.After(delay): - log.Debug("Scheduled check") - } - } - */ - - /* - for t := range tracks { - log.Infof("Found track: %+v", t) - } - - // Start dbus to listen for classification messages. - - if err := beep(); err != nil { + // Run the trap control process + if err := processTrapControl(config, eventsChan); err != nil { return err } + case "at-esl": + log.Info("Running AT-ESL output.") + config.BaudRate = 4800 // Force AT-ESL baud rate to be 4800 - log.Println("Starting UART service") - if err := startService(); err != nil { + // at-esl comms channel listens for tracking *reprocessed* events + err = addTrackingReprocessedEvents(eventsChan) + if err != nil { return err } - trapActive = false - if err := sendTrapActiveState(trapActive); err != nil { + if err := processATESL(config, args.SendTestClassification, eventsChan); err != nil { return err } - - for { - waitUntil := time.Now().Add(5 * time.Second) - if trapActive { - waitUntil = activateTrapUntil - } - - select { - case <-activeTrapSig: - case <-time.After(time.Until(waitUntil)): - } - trapActive = time.Now().Before(activateTrapUntil) - - if err := sendTrapActiveState(trapActive); err != nil { - return err - } - } - */ -} - -/* -func checkClassification(data map[byte]byte) error { - for k, v := range data { - if k == 1 && v > 80 { - activateTrap() - } - if k == 7 && v > 80 { - activateTrap() - } + default: + return fmt.Errorf("unknown output type '%s'", config.CommsOut) } - return nil -} -func activateTrap() { - log.Println("Activating trap") - activateTrapUntil = time.Now().Add(time.Minute) - activeTrapSig <- "trap" + return nil } -*/ diff --git a/internal/tc2-hat-comms/service-monitor.go b/internal/tc2-hat-comms/service-monitor.go index c895254..220649e 100644 --- a/internal/tc2-hat-comms/service-monitor.go +++ b/internal/tc2-hat-comms/service-monitor.go @@ -38,6 +38,7 @@ type trackingEvent struct { Tracking bool LastPredictionFrame int32 ClipAgeSeconds int32 + TrackStartTime time.Time } func (t trackingEvent) isEvent() {} @@ -48,11 +49,17 @@ type batteryEvent struct { Percent float64 } +// Add tracking reprocessed events to the channel +// Tracking reprocessed events are sent once the track has finished and it gets reprocessed. +// This is useful if you are just using the events for reporting purposes and don't need to control +// something in real time. func addTrackingReprocessedEvents(eventsChan chan event) error { targetSignalName := "org.cacophony.thermalrecorder.TrackingReprocessed" return addTrackingEventsForSignal(eventsChan, targetSignalName) } +// Add tracking events to the channel +// Tracking events are sent while the track is in progress. func addTrackingEvents(eventsChan chan event) error { targetSignalName := "org.cacophony.thermalrecorder.Tracking" return addTrackingEventsForSignal(eventsChan, targetSignalName) @@ -91,7 +98,7 @@ func addTrackingEventsForSignal(eventsChan chan event, targetSignalName string) log.Debugf("Received tracking event [%v]:", signal.Name) // Reprocessed signals have an additional parameter 'clip_end_time' - if len(signal.Body) < 12 { + if len(signal.Body) != 13 { log.Errorf("Unexpected signal format in body: %v", signal.Body) continue } @@ -107,9 +114,7 @@ func addTrackingEventsForSignal(eventsChan chan event, targetSignalName string) log.Debugf("Tracking: %v", signal.Body[9]) log.Debugf("Last prediction frame: %v", signal.Body[10]) log.Debugf("Model Id: %v", signal.Body[11]) - if len(signal.Body) >= 13 { - log.Debugf("Clip End Time: %v", signal.Body[12]) - } + log.Debugf("Track Start Time: %v", signal.Body[12]) var modelId int32 var modelLabels []string @@ -153,16 +158,8 @@ func addTrackingEventsForSignal(eventsChan chan event, targetSignalName string) var region [4]int32 copy(region[:], signal.Body[5].([]int32)) - // See if we have a clip end time - clipAgeSeconds := int32(0) - if len(signal.Body) >= 13 { - ts := signal.Body[12].(float64) - now := time.Now() - target := time.Unix(int64(ts), int64((ts-float64(int64(ts)))*1e9)) - - clipAgeSeconds = int32(now.Sub(target).Seconds()) - log.Debugf("Clip is %d seconds old", clipAgeSeconds) - } + nanoSeconds := signal.Body[12].(int64) * 1e6 + trackStartTime := time.Unix(0, nanoSeconds) // Finally let's build our tracking event t := trackingEvent{ @@ -177,7 +174,7 @@ func addTrackingEventsForSignal(eventsChan chan event, targetSignalName string) BlankRegion: signal.Body[8].(bool), Tracking: signal.Body[9].(bool), LastPredictionFrame: signal.Body[10].(int32), - ClipAgeSeconds: clipAgeSeconds, + TrackStartTime: trackStartTime, } log.Debugf("Sending tracking event: %+v", t) @@ -189,6 +186,64 @@ func addTrackingEventsForSignal(eventsChan chan event, targetSignalName string) return nil } +// recordingEvent is a event that is made at the start and end of a recording +type recordingEvent struct { + event + Timestamp time.Time + Recording bool +} + +// Add recording events to the channel +// These are events that are made at the start and end of a recording +func addRecordingEvents(eventsChan chan event) error { + // Listen for signals + targetSignalName := "org.cacophony.thermalrecorder.Recording" + + // Connect to the system bus + conn, err := dbus.SystemBus() + if err != nil { + log.Fatalf("Failed to connect to system bus: %v", err) + } + + // Add a match rule to listen for our dbus signals + rule := "type='signal',interface='org.cacophony.thermalrecorder'" + call := conn.BusObject().Call("org.freedesktop.DBus.AddMatch", 0, rule) + if call.Err != nil { + log.Fatalf("Failed to add match rule: %v", call.Err) + } + + // Create a channel to receive signals + c := make(chan *dbus.Signal, 10) + conn.Signal(c) + + log.Infof("Listening for D-Bus signals: %s", targetSignalName) + // Listen for signals, process and send tracking events to the channel. + go func() { + for signal := range c { + if signal.Name == targetSignalName { + log.Debug("Received Recording event.") + if len(signal.Body) != 2 { + log.Errorf("Unexpected signal format in body: %v", signal.Body) + continue + } + + // Time is given in ms since epoch, we will convert to nanoseconds so we can use the time package. + nanoSeconds := signal.Body[0].(int64) * 1e6 + recordingStartTime := time.Unix(0, nanoSeconds) + + // Send the event to the channel. + eventsChan <- recordingEvent{ + Timestamp: recordingStartTime, + Recording: signal.Body[1].(bool), + } + } + } + }() + + return nil +} + +// Add battery events to the channel. func addBatteryEvents(eventsChan chan event) error { // Listen for signals targetSignalName := "org.cacophony.attiny.Battery" @@ -255,14 +310,15 @@ func getLabels() { } bodyMap := t_call.Body[0].(map[int32][]string) - // Out model labels have id '1' .. false-postitives are the other element. + // Out model labels have id '1' .. false-positive are the other element. // e.g. [map[1:[bird cat deer ... vehicle wallaby] 1004:[animal false-positive]]] for k, v := range bodyMap { - if k == animalsList.Id { + switch k { + case animalsList.Id: animalsList.Labels = v - } else if k == fpModelLabels.Id { + case fpModelLabels.Id: fpModelLabels.Labels = v - } else { + default: log.Warnf("Unexpected classification label id: %v, with labels: %v", k, bodyMap) } } @@ -287,7 +343,7 @@ func getThumbnail(clip_id int32, track_id int32) [][]uint16 { switch frame := t_call.Body[0].(type) { case [][]uint16: // Access row/col - log.Debugf("Thubnail (clip id: %d, track_id: %d) is: %d×%d", clip_id, track_id, len(frame), len(frame[0])) + log.Debugf("Thumbnail (clip id: %d, track_id: %d) is: %d×%d", clip_id, track_id, len(frame), len(frame[0])) return t_call.Body[0].([][]uint16) default: log.Warnf("GetThumbnail returned an unexpected 2D type: %T", frame) diff --git a/internal/tc2-hat-comms/service.go b/internal/tc2-hat-comms/service.go deleted file mode 100644 index b3fc9d4..0000000 --- a/internal/tc2-hat-comms/service.go +++ /dev/null @@ -1,101 +0,0 @@ -/* -attiny-controller - Communicates with ATtiny microcontroller -Copyright (C) 2018, The Cacophony Project - -This program is free software: you can redistribute it and/or modify -it under the terms of the GNU General Public License as published by -the Free Software Foundation, either version 3 of the License, or -(at your option) any later version. - -This program is distributed in the hope that it will be useful, -but WITHOUT ANY WARRANTY; without even the implied warranty of -MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -GNU General Public License for more details. - -You should have received a copy of the GNU General Public License -along with this program. If not, see . -*/ - -package comms - -/* -import ( - "errors" - "runtime" - "strings" - - "github.com/godbus/dbus" - "github.com/godbus/dbus/introspect" -) - -// TODO This is just using the beacon name at the moment so other things don't need to be updated. -const ( - dbusName = "org.cacophony.beacon" - dbusPath = "/org/cacophony/beacon" -) - -type service struct{} - -func startService() error { - conn, err := dbus.SystemBus() - if err != nil { - return err - } - reply, err := conn.RequestName(dbusName, dbus.NameFlagDoNotQueue) - if err != nil { - return err - } - if reply != dbus.RequestNameReplyPrimaryOwner { - return errors.New("name already taken") - } - - s := &service{} - conn.Export(s, dbusPath, dbusName) - conn.Export(genIntrospectable(s), dbusPath, "org.freedesktop.DBus.Introspectable") - return nil -} - -func genIntrospectable(v interface{}) introspect.Introspectable { - node := &introspect.Node{ - Interfaces: []introspect.Interface{{ - Name: dbusName, - Methods: introspect.Methods(v), - }}, - } - return introspect.NewIntrospectable(node) -} - -func (s service) Classification(classifications map[byte]byte) *dbus.Error { - log.Println("Got DBus message 'Classification'") - return errToDBusErr(checkClassification(classifications)) -} - -func (s service) Recording() *dbus.Error { - log.Println("Got DBus message 'Recording'") - return nil -} - -func errToDBusErr(err error) *dbus.Error { - if err == nil { - return nil - } - return &dbus.Error{ - Name: dbusName + "." + getCallerName(), - Body: []interface{}{err.Error()}, - } -} - -func getCallerName() string { - fpcs := make([]uintptr, 1) - n := runtime.Callers(3, fpcs) - if n == 0 { - return "" - } - caller := runtime.FuncForPC(fpcs[0] - 1) - if caller == nil { - return "" - } - funcNames := strings.Split(caller.Name(), ".") - return funcNames[len(funcNames)-1] -} -*/ diff --git a/internal/tc2-hat-comms/simple.go b/internal/tc2-hat-comms/simple.go index 23ab4e3..0cf8398 100644 --- a/internal/tc2-hat-comms/simple.go +++ b/internal/tc2-hat-comms/simple.go @@ -1,3 +1,5 @@ +// Output mode: outputs a simple HIGH or LOW signal over serial based on detection events. + package comms import ( diff --git a/internal/tc2-hat-comms/trap-control.go b/internal/tc2-hat-comms/trap-control.go new file mode 100644 index 0000000..19a251c --- /dev/null +++ b/internal/tc2-hat-comms/trap-control.go @@ -0,0 +1,441 @@ +// Output mode: connects to and controls a trap over serial. + +package comms + +import ( + "bytes" + "encoding/json" + "fmt" + "strconv" + "sync" + "time" + + "github.com/TheCacophonyProject/event-reporter/v3/eventclient" + "github.com/TheCacophonyProject/tc2-hat-controller/serialhelper" + "periph.io/x/conn/v3/gpio" +) + +// processTrapControl communicates the trap enabled/disabled state by writing +// the "enable" variable over UART instead of setting a digital pin. +func processTrapControl(config *CommsConfig, eventSignals chan event) error { + trapEnabled := false + previousTrapEnabled := false + lastProtectSpeciesSighting := time.Time{} + lastTrapSpeciesSighting := time.Time{} + enablingReason := "" + disablingReason := "" + + recordingStartTime := time.Time{} + trackStartTime := time.Time{} + triggerAnimal := "" + var confidence int32 + + // Open the serial port so we can send/receive messages from the trap. + port, err := serialhelper.OpenSerial(gpio.High, gpio.Low, config.BaudRate) + if err != nil { + return fmt.Errorf("failed to open serial port: %v", err) + } + defer port.Close() + + // Create the messenger that tracks sending/receiving messages + messenger := NewUartMessenger(port) + messenger.Start() + + for { + now := time.Now() + trapEnabled = config.TrapEnabledByDefault + + if lastProtectSpeciesSighting.Add(config.ProtectDuration).After(now) { + trapEnabled = false + } else if lastTrapSpeciesSighting.Add(config.TrapDuration).After(now) { + trapEnabled = true + } + + if trapEnabled != previousTrapEnabled { + if trapEnabled { + log.Infof("Enabling trap, reason: %s", enablingReason) + success, err := messenger.setEnable(true) + if err != nil { + return fmt.Errorf("failed to enable trap: %v", err) + } + trapEnableTime := time.Now() + log.Infof("Recording start time: %s", recordingStartTime.Format("15:04:05.999")) + log.Infof("Track start time: %s", trackStartTime.Format("15:04:05.999")) + log.Infof("TrapEnableTime: %s", trapEnableTime.Format("15:04:05.999")) // TODO, we can get better accuracy on when this actually + timeToEnableTrap := trapEnableTime.Sub(recordingStartTime).String() + log.Infof("Time to enable trap: %s", timeToEnableTrap) + + eventclient.AddEvent(eventclient.Event{ + Timestamp: time.Now(), + Type: "trapEnableCommand", + Details: map[string]any{ + "reason": enablingReason, + "recordingStartTime": recordingStartTime, + "trackStartTime": trackStartTime, + "trapEnableTime": trapEnableTime, + "timeToEnableTrap": timeToEnableTrap, + "animal": triggerAnimal, + "confidence": confidence, + "enableTrapSuccess": success, // If this fails that likely means the trap is not in a state to be enabled through the UART + }, + }) + } else { + log.Info("Disabling trap, reason: ", disablingReason) + success, err := messenger.setEnable(false) + if err != nil { + return fmt.Errorf("failed to disable trap: %v", err) + } + eventclient.AddEvent(eventclient.Event{ + Timestamp: time.Now(), + Type: "trapDisableCommand", + Details: map[string]any{ + "reason": disablingReason, + "disableTrapSuccess": success, + }, + }) + } + } + + previousTrapEnabled = trapEnabled + + var delay = 10 * time.Second + trapDisableTime := lastTrapSpeciesSighting.Add(config.TrapDuration) + if trapEnabled && time.Until(trapDisableTime) < delay { + delay = time.Until(trapDisableTime) + } + + disablingReason = "timeout" + enablingReason = "timeout" + log.Debug("Waiting") + select { + case t := <-eventSignals: + switch v := t.(type) { + case trackingEvent: + log.Debugf("Received tracking event: %+v", v) + trackStartTime = v.TrackStartTime + + protect, animal, conf := v.Species.MatchSpeciesWithConfidence(config.ProtectSpecies) + if protect { + disablingReason = fmt.Sprintf("Found an %s of confidence %d that needs to be protected", animal, conf) + log.Debug(disablingReason) + lastProtectSpeciesSighting = time.Now() + break + } + + trap, animal, conf := v.Species.MatchSpeciesWithConfidence(config.TrapSpecies) + if trap { + enablingReason = fmt.Sprintf("Found an %s of confidence %d that needs to be trapped", animal, conf) + triggerAnimal = animal + confidence = conf + log.Debug(enablingReason) + lastTrapSpeciesSighting = time.Now() + break + } + + log.Debug("No animals need to be protected or trapped, not changing trap state.") + + case recordingEvent: + log.Debugf("Received recording event: %+v", v) + if v.Recording { + recordingStartTime = v.Timestamp + } else { + recordingStartTime = time.Time{} + } + + default: + log.Debugf("Ignoring non tracking event: %+v", t) + continue + } + + case <-time.After(delay): + log.Debug("Scheduled check") + } + } +} + +// Message represents the data structure for communication with a device connected on UART. +// - ID: Identifier of the message being sent or the message being responded to. +// - Response: Indicates if the message is a response. +// - Type: Specifies the type of message (e.g., write, read, command, ACK, NACK). +// - Data: Contains the actual data payload, which varies depending on the type or response. +type Message struct { + ID int + Type string + Payload string + PayloadUnmarshaled any +} + +func (u *Message) String() string { + if u.PayloadUnmarshaled != nil { + return fmt.Sprintf("ID: %d, Type: %s, Payload: %v, PayloadUnmarshaled: %v", u.ID, u.Type, u.Payload, u.PayloadUnmarshaled) + } + return fmt.Sprintf("ID: %d, Type: %s, Payload: %v", u.ID, u.Type, u.Payload) +} + +func (m *Message) ToUARTLine() string { + if m == nil { + return "" + } + if m.PayloadUnmarshaled != nil { + marshaledPayload, err := json.Marshal(m.PayloadUnmarshaled) + if err != nil { + return "" + } + m.PayloadUnmarshaled = nil + m.Payload = string(marshaledPayload) + } + messageStr := fmt.Sprintf("<%d|%s|%s>", m.ID, m.Type, m.Payload) + return fmt.Sprintf("%s%d\n", messageStr, computeChecksum([]byte(messageStr))) +} + +func (m *Message) Response() bool { + return m.ID != 0 +} + +type Command struct { + Command string `json:"command"` + Args string `json:"args,omitempty"` +} + +type Write struct { + Var string `json:"var,omitempty"` + Val any `json:"val,omitempty"` +} + +// UartMessenger manages bidirectional communication with the RP2040 over UART. +// It holds a persistent serial port and routes incoming messages to either +// pending response waiters (matched by ID) or an unsolicited message channel. +type UartMessenger struct { + port *serialhelper.SerialPort + pendingMu sync.Mutex + pending map[int]chan *Message + nextID int + baudRate int +} + +// NewUartMessenger creates a UartMessenger using an already-open SerialPort. +func NewUartMessenger(port *serialhelper.SerialPort) *UartMessenger { + return &UartMessenger{ + port: port, + pending: make(map[int]chan *Message), + } +} + +// Start begins the background routing goroutine. Unsolicited messages from the RP2040 +// (i.e. not responses to a request we sent) are delivered to the unsolicited channel. +// Pass nil to discard unsolicited messages. +func (u *UartMessenger) Start() { + go u.routeMessages() +} + +// routeMessages reads lines from the serial port, parses them, and routes them: +// TODO: Maybe separate this for routing messages +// - Response messages are matched to a pending sendMessage call by ID. +// - If not a response then it is a notification from the trap. +func (u *UartMessenger) routeMessages() { + for line := range u.port.Lines { + // Parse the line + msg, err := ParseLine(line) + if err != nil { + log.Warnf("Failed to parse incoming message %q: %v", line, err) + continue + } + + // Check if the message was a response + if msg.Response() { + u.pendingMu.Lock() + ch, ok := u.pending[msg.ID] + if !ok && len(u.pending) == 1 { + // Fallback for RP2040 firmware that doesn't echo message IDs yet. + for _, c := range u.pending { + ch = c + ok = true + break + } + } + u.pendingMu.Unlock() + if ok { + ch <- msg + continue + } + } + + // If not a response then it is a notification from the trap. + parseMessageFromTrap(msg) + } +} + +func parseMessageFromTrap(msg *Message) { + log.Printf("Trap message: %+v", msg) + + // eventMessages maps trap message type to event type. + // For these events we will just make an event of the given type and add the payload in the details. + eventMessages := map[string]string{ + "MOTION": "trapMotion", + "ENABLED": "trapEnabled", + "DISABLED": "trapDisabled", + "SPOOL_RESET": "trapSpoolReset", + "TRIGGERED": "trapTriggered", + "RUNNING": "trapRunning", + "ERROR_CODE": "trapErrorCode", + "EXCEPTION": "trapException", + } + + // Messages that we want to trigger the events to be uploaded right away. + uploadEventsNowMessages := []string{ + "TRIGGERED", + "EXCEPTION", + "ERROR_CODE", + } + + // Handle messages that we want to make events for + if event, ok := eventMessages[msg.Type]; ok { + log.Info("Making event for: ", msg.Type) + details := map[string]any{} + if msg.Payload != "" { + // Try to unmarshal the payload, if not just use it as a string + err := json.Unmarshal([]byte(msg.Payload), &details) + if err != nil { + details["Payload"] = msg.Payload + } + } + err := eventclient.AddEvent(eventclient.Event{ + Timestamp: time.Now(), + Type: event, + Details: details, + }) + if err != nil { + log.Error("Error adding event:", err) + } + if contains(uploadEventsNowMessages, msg.Type) { + log.Info("Uploading events now") + err := eventclient.UploadEvents() + if err != nil { + log.Error("Error requesting events to be uploaded:", err) + } + } + return + } + + // Messages that we just want to make a log for, no event. + // logMessages := []string{} + // if contains(logMessages, msg.Type) { + // log.Infof("Trap message: %+v", msg) + // return + // } + + // Unknown messages + log.Warnf("Unknown trap message: %+v", msg) +} + +func contains(arr []string, item string) bool { + for _, v := range arr { + if v == item { + return true + } + } + return false +} + +// ParseLine parses a framed line of the form checksum. +func ParseLine(line []byte) (*Message, error) { + line = bytes.TrimLeft(line, "\x00") + lastIdx := bytes.LastIndexByte(line, '>') + if lastIdx < 0 || len(line) == 0 || line[0] != '<' { + return nil, fmt.Errorf("invalid frame: %q", line) + } + messageStr := line[:lastIdx+1] + checksumStr := line[lastIdx+1:] + + receivedChecksum, err := strconv.Atoi(string(checksumStr)) + if err != nil { + return nil, fmt.Errorf("invalid checksum in %q: %v", line, err) + } + if computeChecksum(messageStr) != receivedChecksum { + return nil, fmt.Errorf("checksum mismatch in %q", line) + } + + inner := messageStr[1 : len(messageStr)-1] + parts := bytes.SplitN(inner, []byte("|"), 3) + if len(parts) < 2 { + return nil, fmt.Errorf("invalid format: %q", line) + } + + id, err := strconv.Atoi(string(parts[0])) + if err != nil { + return nil, fmt.Errorf("invalid id in %q: %v", line, err) + } + + payload := "" + if len(parts) == 3 { + payload = string(parts[2]) + } + + return &Message{ + ID: id, + Type: string(parts[1]), + Payload: payload, + }, nil +} + +func computeChecksum(message []byte) int { + checksum := 0 + for _, b := range message { + checksum += int(b) + } + return checksum % 256 +} + +// sendMessage sends a request and waits for a matching response. +// It assigns a unique ID to the message for correlation. +func (u *UartMessenger) sendMessage(message Message) (*Message, error) { + u.pendingMu.Lock() + u.nextID++ + id := u.nextID + message.ID = id + ch := make(chan *Message, 1) + u.pending[id] = ch + u.pendingMu.Unlock() + + defer func() { + u.pendingMu.Lock() + delete(u.pending, id) + u.pendingMu.Unlock() + }() + + line := message.ToUARTLine() + log.Infof("Message: '%s'", line) + + if err := u.port.Write([]byte(line)); err != nil { + return nil, err + } + + select { + case response := <-ch: + log.Println("Response:", response) + return response, nil + case <-time.After(5 * time.Second): + return nil, fmt.Errorf("timeout waiting for response to message ID %d", id) + } +} + +func (u *UartMessenger) setEnable(enable bool) (bool, error) { + message := Message{} + if enable { + message.Type = "ENABLE" + } else { + message.Type = "DISABLE" + } + response, err := u.sendMessage(message) + if err != nil { + return false, err + } + if response.Type == "NACK" { + return false, fmt.Errorf("NACK response") + } + if response.Type == "BAD_KEY" { + log.Warn("Got BAD_KEY response, was trying to set a key that doesn't exist") + return false, nil + } + return true, nil +} diff --git a/internal/tc2-hat-comms/uart.go b/internal/tc2-hat-comms/uart.go deleted file mode 100644 index 8c1bbe3..0000000 --- a/internal/tc2-hat-comms/uart.go +++ /dev/null @@ -1,310 +0,0 @@ -// This section deals with communication with peripherals over uart. - -package comms - -import ( - "bytes" - "encoding/json" - "fmt" - "strconv" - "time" - - "github.com/TheCacophonyProject/tc2-hat-controller/serialhelper" - "github.com/TheCacophonyProject/tc2-hat-controller/tracks" - "periph.io/x/conn/v3/gpio" -) - -type UartMessenger struct { - baudRate int -} - -// TODO - -// UartMessage represents the data structure for communication with a device connected on UART. -// - ID: Identifier of the message being sent or the message being responded to. -// - Response: Indicates if the message is a response. -// - Type: Specifies the type of message (e.g., write, read, command, ACK, NACK). -// - Data: Contains the actual data payload, which varies depending on the type or response. -type UartMessage struct { - ID int `json:"id,omitempty"` - Response bool `json:"response,omitempty"` - Type string `json:"type,omitempty"` - Data interface{} `json:"data,omitempty"` -} - -type Command struct { - Command string `json:"command"` - Args string `json:"args,omitempty"` -} - -type Write struct { - Var string `json:"var,omitempty"` - Val interface{} `json:"val,omitempty"` -} - -func (u UartMessenger) sendTrapActiveState(active bool) error { - return u.sendWriteMessage("active", active) -} - -func processUart(config *CommsConfig, testClassification *TestClassification, trackingSignals chan event) error { - if testClassification != nil { - log.Println("Sending a test classification over UART") - - species := tracks.Species{ - testClassification.Animal: int32(testClassification.Confidence), - } - - classificationData := ClassificationData{ - Species: species, - Confidence: int32(testClassification.Confidence), - } - - message := UartMessage{ - Type: "classification", - Data: classificationData, - } - payload, err := json.Marshal(message) - if err != nil { - return err - } - - log.Printf("Sending payload: '%s'", payload) - - serialhelper.SerialSend(3, gpio.High, gpio.Low, time.Second, append(payload, byte('\r'), byte('\n')), config.BaudRate) - - return nil - } - - messenger := UartMessenger{ - baudRate: config.BaudRate, - } - - for { - log.Debug("Waiting") - for e := range trackingSignals { - switch v := e.(type) { - case trackingEvent: - fmt.Println("Tracking event:", v.Species) - err := messenger.processTrackingEvent(v) - if err != nil { - log.Error("Error processing tracking event:", err) - } - default: - log.Debug("Not processing event:", v) - continue - } - } - } -} - -func (u UartMessenger) processTrackingEvent(t trackingEvent) error { - log.Debugf("Found new track: %+v", t) - - species := tracks.Species{} - for k, v := range t.Species { - if v > 0 { - species[k] = v - } - } - - message := UartMessage{ - Type: "classification", - Data: ClassificationData{ - Species: species, - Confidence: t.Confidence, - }, - } - - payload, err := json.Marshal(message) - if err != nil { - return err - } - - log.Printf("Sending payload: '%s'", payload) - start := time.Now() - - serialhelper.SerialSend(3, gpio.High, gpio.Low, time.Second, append(payload, byte('\r'), byte('\n')), u.baudRate) - - log.Printf("Sent payload in %s", time.Since(start)) - - return nil -} - -type ClassificationData struct { - Species tracks.Species - Confidence int32 -} - -func (u UartMessenger) sendClassification(event trackingEvent) { - - data := map[string]interface{}{ - "species": event.Species, - "confidence": event.Confidence, - } - - jsonBytes, err := json.Marshal(data) - if err != nil { - fmt.Println("Error converting to JSON:", err) - return - } - - println(string(jsonBytes)) - //data["data"] = trackingEvent.Data - - sendMessage(UartMessage{ - Type: "classification", - Data: string(jsonBytes), - }, u.baudRate) -} - -func (u UartMessenger) sendWriteMessage(varName string, val interface{}) error { - data, err := json.Marshal(&Write{ - Var: varName, - Val: val, - }) - if err != nil { - return err - } - message := UartMessage{ - Type: "write", - Data: string(data), - } - response, err := sendMessage(message, u.baudRate) - if err != nil { - return err - } - if response.Type == "NACK" { - return fmt.Errorf("NACK response") - } - return nil -} - -func beep(baudRate int) error { - log.Println("beep") - return sendCommandMessage("beep", baudRate) -} - -func sendCommandMessage(cmd string, baudRate int) error { - data, err := json.Marshal(&Command{ - Command: cmd, - }) - if err != nil { - return err - } - message := UartMessage{ - Type: "command", - Data: string(data), - } - response, err := sendMessage(message, baudRate) - if err != nil { - return err - } - if response.Type == "NACK" { - return fmt.Errorf("NACK response") - } - return nil -} - -type Read struct { - Var string `json:"var,omitempty"` -} - -type ReadResponse struct { - Val string `json:"var,omitempty"` -} - -func sendReadMessage(varName string) (string, error) { - return "", nil - /* - data, err := json.Marshal(&Read{ - Var: varName, - }) - if err != nil { - return "", err - } - message := UartMessage{ - Type: "read", - Data: string(data), - } - response, err := sendMessage(message) - if err != nil { - return "", err - } - if response.Type == "NACK" { - return "", fmt.Errorf("NACK response") - } - readResponse := &ReadResponse{} - if err := json.Unmarshal([]byte(response.Data), readResponse); err != nil { - return "", err - } - return readResponse.Val, nil - */ - -} - -func checkPIR(oldPirVal int) (int, error) { - valStr, err := sendReadMessage("pir") - if err != nil { - return 0, err - } - newPirVal, err := strconv.Atoi(valStr) - if err != nil { - return 0, err - } - if oldPirVal != newPirVal { - //TODO Make event - log.Println("New pir value:", newPirVal) - } - return newPirVal, nil -} - -func computeChecksum(message []byte) int { - checksum := 0 - for _, b := range message { - checksum += int(b) - } - return checksum % 256 -} - -func sendMessage(cmd UartMessage, baudRate int) (*UartMessage, error) { - cmdData, err := json.Marshal(cmd) - if err != nil { - return nil, err - } - message := fmt.Sprintf("<%s|%d>", cmdData, computeChecksum(cmdData)) - - log.Println("Message: ", message) - responseData, err := serialhelper.SerialSendReceive(3, gpio.High, gpio.Low, time.Second, []byte(message), baudRate) - - if err != nil { - return nil, err - } - log.Println("Response: ", string(responseData)) - - if responseData[0] != '<' { - return nil, fmt.Errorf("response doesn't start with '<'") - } - if responseData[len(responseData)-1] != '>' { - return nil, fmt.Errorf("response doesn't end with '>'") - } - - // Extract and verify message and checksum - responseData = responseData[1 : len(responseData)-1] - parts := bytes.Split(responseData, []byte("|")) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid response format") - } - log.Println("Response:", string(parts[0])) - receivedChecksum, err := strconv.Atoi(string(parts[1])) - if err != nil { - return nil, err - } - if computeChecksum(parts[0]) != receivedChecksum { - return nil, fmt.Errorf("checksum mismatch") - } - - // Unmarshal response to a Message - responseMessage := &UartMessage{} - log.Println(string(parts[0])) - return responseMessage, json.Unmarshal(parts[0], responseMessage) -} diff --git a/internal/tc2-hat-rp2040/main.go b/internal/tc2-hat-rp2040/main.go index 3271544..f41ddd2 100644 --- a/internal/tc2-hat-rp2040/main.go +++ b/internal/tc2-hat-rp2040/main.go @@ -2,10 +2,13 @@ package rp2040 import ( "bufio" + "crypto/sha256" "errors" "fmt" + "io" "os" "os/exec" + "strings" "time" "github.com/TheCacophonyProject/event-reporter/v3/eventclient" @@ -25,13 +28,18 @@ type Args struct { ELF string `arg:"--elf" help:".elf file to program the RP2040 with."` RunPin string `arg:"--run-pin" help:"Run GPIO pin for the RP2040."` BootModePin string `arg:"--boot-mode-pin" help:"Boot mode GPIO pin for the RP2040."` + EraseFlash bool `arg:"--erase-flash" help:"Upload the program that will erase the flash on the RP2040"` logging.LogArgs } -const openOCDNotFoundMessage = `'openocd' was not found. Can be installed using apt 'sudo apt install openocd' or +const ( + openOCDNotFoundMessage = `'openocd' was not found. Can be installed using apt 'sudo apt install openocd' or following section 5.1 at https://datasheets.raspberrypi.com/pico/getting-started-with-pico.pdf. If installed using apt then use the config file '/etc/cacophony/raspberrypi-swd.cfg' as 'interface/raspberrypi-swd.cfg' is not available with the current version provided by apt.` + eraseFlashFirmware = "/etc/cacophony/rp2040-firmware-erase-flash.elf" + eraseFlashHash = "/etc/cacophony/rp2040-firmware-erase-flash.sha256" +) var defaultArgs = Args{ RunPin: "GPIO23", @@ -57,6 +65,31 @@ func procArgs(input []string) (Args, error) { return args, err } +func verifyFileHash(filePath, hashFilePath string) error { + expectedHashBytes, err := os.ReadFile(hashFilePath) + if err != nil { + return fmt.Errorf("failed to read hash file: %v", err) + } + expectedHash := strings.TrimSpace(string(expectedHashBytes)) + + f, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("failed to open file: %v", err) + } + defer f.Close() + + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return fmt.Errorf("failed to hash file: %v", err) + } + actualHash := fmt.Sprintf("%x", h.Sum(nil)) + + if actualHash != expectedHash { + return fmt.Errorf("hash mismatch: expected %s, got %s", expectedHash, actualHash) + } + return nil +} + func Run(inputArgs []string, ver string) error { version = ver args, err := procArgs(inputArgs) @@ -67,8 +100,26 @@ func Run(inputArgs []string, ver string) error { log.Printf("Running version: %s", version) - // Check if openocd is installed - if args.ELF != "" { + if args.ELF != "" && args.EraseFlash { + return fmt.Errorf("must specify either --elf or --erase-flash, not both") + } + + // Check if openocd is installed and ELF file exists + if args.ELF != "" || args.EraseFlash { + if args.ELF != "" { + if _, err := os.Stat(args.ELF); err != nil { + return fmt.Errorf("elf file not found: %v", err) + } + } + if args.EraseFlash { + if _, err := os.Stat(eraseFlashFirmware); err != nil { + return fmt.Errorf("erase flash firmware not found: %v", err) + } + if err := verifyFileHash(eraseFlashFirmware, eraseFlashHash); err != nil { + return fmt.Errorf("erase flash firmware hash check failed: %v", err) + } + log.Println("Erase flash firmware hash verified.") + } cmd := exec.Command("openocd", "--version") if err := cmd.Run(); err != nil { log.Println(openOCDNotFoundMessage) @@ -110,17 +161,21 @@ func Run(inputArgs []string, ver string) error { return err } - log.Println("RP2400 read for programming.") + log.Println("RP2040 is ready for programming.") success := true - if args.ELF == "" { + elfToFlash := args.ELF + if args.EraseFlash { + elfToFlash = eraseFlashFirmware + } + if elfToFlash == "" { log.Println("No elf program provided so assuming programming is done manually.") log.Println("Press enter when programming is done.") _, _ = bufio.NewReader(os.Stdin).ReadString('\n') } else { - log.Printf("Programming '%s' using 'openocd' file to RP2040\n", args.ELF) + log.Printf("Programming '%s' using 'openocd' to RP2040\n", elfToFlash) cmd := exec.Command("openocd", "-f", "/etc/cacophony/raspberrypi-swd.cfg", "-f", "/target/rp2040.cfg", "-c", - fmt.Sprintf("program %s verify reset exit", args.ELF)) + fmt.Sprintf("program %s verify reset exit", elfToFlash)) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { @@ -129,6 +184,13 @@ func Run(inputArgs []string, ver string) error { } } + if args.EraseFlash { + // make file to trigger a reprogram next time tc2-agent starts + if err := os.WriteFile("/etc/cacophony/program_rp2040", []byte{}, 0644); err != nil { + return err + } + } + log.Println("Releasing Run and Boot mode pins.") if err := runPin.In(gpio.Float, gpio.NoEdge); err != nil { return err @@ -137,10 +199,14 @@ func Run(inputArgs []string, ver string) error { return err } + details := map[string]any{"success": success} + if args.EraseFlash { + details["eraseFlash"] = true + } eventclient.AddEvent(eventclient.Event{ Timestamp: time.Now(), Type: "programmingRP2040", - Details: map[string]interface{}{"success": success}, + Details: details, }) if !success { return errors.New("failed to program RP2040") diff --git a/internal/tc2-hat-trap-cli/main.go b/internal/tc2-hat-trap-cli/main.go new file mode 100644 index 0000000..7bddc2f --- /dev/null +++ b/internal/tc2-hat-trap-cli/main.go @@ -0,0 +1,167 @@ +package trapcli + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "strings" + "time" + + comms "github.com/TheCacophonyProject/tc2-hat-controller/internal/tc2-hat-comms" + + goconfig "github.com/TheCacophonyProject/go-config" + "github.com/TheCacophonyProject/go-utils/logging" + "github.com/TheCacophonyProject/tc2-hat-controller/serialhelper" + "github.com/alexflint/go-arg" + "periph.io/x/conn/v3/gpio" +) + +var ( + version = "" + log = logging.NewLogger("info") +) + +type Args struct { + Command *Command `arg:"subcommand:command" help:"Send a command."` + Read *Read `arg:"subcommand:read" help:"Read from a variable."` + Write *Write `arg:"subcommand:write" help:"Write to a variable."` + Listen *Listen `arg:"subcommand:listen" help:"Continuously listen for messages from the RP2040."` + Message *CMDMessage `arg:"subcommand:msg" help:"Send a message to the RP2040."` + BaudRate int `arg:"--baud-rate" help:"Baud rate for UART communication."` + goconfig.ConfigArgs + logging.LogArgs +} + +type CMDMessage struct { + ID int `arg:"--id,required" help:"The ID of the message to send."` + Type string `arg:"--type,required" help:"The type of message to send."` + Payload string `arg:"--payload,required" help:"The payload of the message to send."` +} + +type Command struct { + Command string `arg:"--command,required" help:"The command to run."` +} + +type Read struct { + Variable string `arg:"--variable,required" help:"The variable to read from."` +} + +type Write struct { + Variable string `arg:"--variable,required" help:"The variable to write to."` + Value string `arg:"--value,required" help:"The value to write."` +} + +type Listen struct{} + +var defaultArgs = Args{ + BaudRate: 9600, +} + +func sendMessage(msg comms.Message, port *serialhelper.SerialPort) (*comms.Message, error) { + line := msg.ToUARTLine() + log.Println("Sending:", strings.TrimSpace(line)) + + if err := port.Write([]byte(line)); err != nil { + return nil, err + } + + select { + case line, ok := <-port.Lines: + if !ok { + return nil, fmt.Errorf("serial port closed while waiting for response") + } + return comms.ParseLine(line) + case <-time.After(5 * time.Second): + return nil, fmt.Errorf("timeout waiting for response") + } +} + +func procArgs(input []string) (Args, error) { + args := defaultArgs + + parser, err := arg.NewParser(arg.Config{}, &args) + if err != nil { + return Args{}, err + } + err = parser.Parse(input) + if errors.Is(err, arg.ErrHelp) { + parser.WriteHelp(os.Stdout) + os.Exit(0) + } + if errors.Is(err, arg.ErrVersion) { + fmt.Println(version) + os.Exit(0) + } + return args, err +} + +func Run(inputArgs []string, ver string) error { + version = ver + args, err := procArgs(inputArgs) + if err != nil { + return fmt.Errorf("failed to parse args: %v", err) + } + log = logging.NewLogger(args.LogLevel) + log.Printf("Running version: %s", version) + + port, err := serialhelper.OpenSerial(gpio.High, gpio.Low, args.BaudRate) + if err != nil { + return fmt.Errorf("failed to open serial port: %v", err) + } + defer port.Close() + + switch { + case args.Listen != nil: + fmt.Println("Listening for messages from RP2040 (Ctrl+C to stop)...") + for line := range port.Lines { + msg, err := comms.ParseLine(line) + if err != nil { + fmt.Printf("raw: %s\n", line) + log.Warnf("Failed to parse incoming message %q: %v", line, err) + continue + } + log.Println("Received:", msg) + } + return nil + + case args.Command != nil: + data, err := json.Marshal(map[string]string{"command": args.Command.Command}) + if err != nil { + return err + } + return respond(sendMessage(comms.Message{Type: "command", Payload: string(data)}, port)) + + case args.Read != nil: + data, err := json.Marshal(map[string]string{"var": args.Read.Variable}) + if err != nil { + return err + } + return respond(sendMessage(comms.Message{Type: "read", Payload: string(data)}, port)) + + case args.Write != nil: + data, err := json.Marshal(map[string]string{"var": args.Write.Variable, "val": args.Write.Value}) + if err != nil { + return err + } + return respond(sendMessage(comms.Message{Type: "write", Payload: string(data)}, port)) + + case args.Message != nil: + message := comms.Message{ID: args.Message.ID, Type: args.Message.Type, Payload: args.Message.Payload} + return respond(sendMessage(message, port)) + + default: + return fmt.Errorf("no subcommand given") + } +} + +func respond(response *comms.Message, err error) error { + if err != nil { + return err + } + if response.Type == "NACK" { + return fmt.Errorf("NACK response: %s", response.Payload) + } + fmt.Printf("type=%s payload=%s\n", response.Type, response.Payload) + return nil +} diff --git a/serialhelper/serialhelper.go b/serialhelper/serialhelper.go index 176493a..85033d1 100644 --- a/serialhelper/serialhelper.go +++ b/serialhelper/serialhelper.go @@ -6,6 +6,7 @@ import ( "os" "os/exec" "strings" + "sync" "syscall" "time" @@ -158,9 +159,7 @@ func SerialSendReceive(retries int, mul0, mul1 gpio.Level, wait time.Duration, d if err != nil { return nil, err } - defer ReleaseSerial(serialFile) - c := &serial.Config{Name: "/dev/serial0", Baud: baud, ReadTimeout: time.Second * 5} serialPort, err := serial.OpenPort(c) if err != nil { @@ -168,23 +167,46 @@ func SerialSendReceive(retries int, mul0, mul1 gpio.Level, wait time.Duration, d } defer serialPort.Close() + start := time.Now() + // add a newline at and of data if it is not there already + if data[len(data)-1] != '\n' { + data = append(data, '\n') + } n, err := serialPort.Write(data) if err != nil { return nil, err } + if n != len(data) { return nil, fmt.Errorf("wrote %d bytes, expected %d", n, len(data)) } - time.Sleep(time.Second) - buf := make([]byte, 256) - n, err = serialPort.Read(buf) - log.Infof("Received %d bytes", n) - log.Info("Received:", buf[:n]) - if err != nil { - return nil, err - } - return buf[:n], nil + var response []byte + var responseTime time.Time + firstBits := true + buf := make([]byte, 1) + for { + n, err = serialPort.Read(buf) + if err != nil { + return nil, err + } + if n == 0 { + continue + } + if firstBits { + responseTime = time.Now() + firstBits = false + } + if buf[0] == '\n' { + break + } + response = append(response, buf[0]) + } + log.Infof("Sent message at %s", start.Format("15:04:05.999")) + log.Infof("Received message at %s", responseTime.Format("15:04:05.999")) + log.Debugf("Received %d bytes", len(response)) + log.Debugf("Response time: %s", responseTime) + return response, nil } func SerialSend(retries int, mul0, mul1 gpio.Level, wait time.Duration, data []byte, baud int) error { @@ -222,3 +244,93 @@ func SerialSend(retries int, mul0, mul1 gpio.Level, wait time.Duration, data []b return nil } + +// SerialPort represents a persistent, open serial connection with a background line reader. +type SerialPort struct { + writeMu sync.Mutex + port *serial.Port + file *os.File + Lines chan []byte + done chan struct{} +} + +// OpenSerial opens the serial port persistently and starts a background line reader goroutine. +func OpenSerial(mul0, mul1 gpio.Level, baud int) (*SerialPort, error) { + file, err := GetSerial(3, mul0, mul1, time.Second) + if err != nil { + return nil, err + } + c := &serial.Config{Name: "/dev/serial0", Baud: baud, ReadTimeout: 100 * time.Millisecond} + port, err := serial.OpenPort(c) + if err != nil { + if rerr := ReleaseSerial(file); rerr != nil { + log.Printf("Failed to release serial: %v", rerr) + } + return nil, err + } + sp := &SerialPort{ + port: port, + file: file, + Lines: make(chan []byte, 16), + done: make(chan struct{}), + } + go sp.readLoop() + return sp, nil +} + +// readLoop continuously reads lines from the serial port and sends them to Lines. +// It exits when Close is called. +func (s *SerialPort) readLoop() { + defer close(s.Lines) + buf := make([]byte, 1) + var line []byte + for { + select { + case <-s.done: + return + default: + } + n, err := s.port.Read(buf) + if err != nil || n == 0 { + continue + } + if buf[0] == '\n' { + if len(line) > 0 { + msg := make([]byte, len(line)) + copy(msg, line) + select { + case s.Lines <- msg: + case <-s.done: + return + } + line = line[:0] + } + } else { + line = append(line, buf[0]) + } + } +} + +// Write sends data over the serial port. Appends a newline if not already present. +func (s *SerialPort) Write(data []byte) error { + s.writeMu.Lock() + defer s.writeMu.Unlock() + if len(data) == 0 || data[len(data)-1] != '\n' { + data = append(data, '\n') + } + n, err := s.port.Write(data) + if err != nil { + return err + } + if n != len(data) { + return fmt.Errorf("wrote %d bytes, expected %d", n, len(data)) + } + return nil +} + +// Close stops the background reader and releases the serial port. +func (s *SerialPort) Close() error { + close(s.done) + s.port.Close() + return ReleaseSerial(s.file) +}