diff --git a/.gitignore b/.gitignore
index 1b167ff..f0b6b21 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,4 @@
/dist
_release/attiny-firmware.hex
_release/attiny-firmware.hex.sha256
-
+/tc2-hat-controller
diff --git a/.goreleaser.yml b/.goreleaser.yml
index 8909a10..b88333a 100644
--- a/.goreleaser.yml
+++ b/.goreleaser.yml
@@ -77,6 +77,8 @@ nfpms:
dst: /usr/bin/tc2-hat-rtc
- src: _release/tc2-hat-temp
dst: /usr/bin/tc2-hat-temp
+ - src: _release/tc2-hat-trap-cli
+ dst: /usr/bin/tc2-hat-trap-cli
dependencies:
#- python3-pip
diff --git a/_release/tc2-hat-trap-cli b/_release/tc2-hat-trap-cli
new file mode 100644
index 0000000..eb73af0
--- /dev/null
+++ b/_release/tc2-hat-trap-cli
@@ -0,0 +1,2 @@
+#!/bin/bash
+exec /usr/bin/tc2-hat-controller trap-cli "$@"
diff --git a/cmd/tc2-hat-controller/main.go b/cmd/tc2-hat-controller/main.go
index f495ab6..aab802a 100644
--- a/cmd/tc2-hat-controller/main.go
+++ b/cmd/tc2-hat-controller/main.go
@@ -12,6 +12,7 @@ import (
rp2040 "github.com/TheCacophonyProject/tc2-hat-controller/internal/tc2-hat-rp2040"
rtc "github.com/TheCacophonyProject/tc2-hat-controller/internal/tc2-hat-rtc"
temp "github.com/TheCacophonyProject/tc2-hat-controller/internal/tc2-hat-temp"
+ trapcli "github.com/TheCacophonyProject/tc2-hat-controller/internal/tc2-hat-trap-cli"
)
var (
@@ -60,6 +61,8 @@ func runMain() error {
err = rtc.Run(args, version)
case "temp":
err = temp.Run(args, version)
+ case "trap-cli":
+ err = trapcli.Run(args, version)
default:
err = fmt.Errorf("unknown subcommand: %s", subcommand)
}
diff --git a/internal/tc2-hat-comms/bluetooth.go b/internal/tc2-hat-comms/bluetooth.go
deleted file mode 100644
index b2d7f9d..0000000
--- a/internal/tc2-hat-comms/bluetooth.go
+++ /dev/null
@@ -1,10 +0,0 @@
-// This section deals with communication with peripherals over bluetooth.
-
-package comms
-
-/*
-func processBluetooth() error {
- // TODO
- return nil
-}
-*/
diff --git a/internal/tc2-hat-comms/json-out.go b/internal/tc2-hat-comms/json-out.go
new file mode 100644
index 0000000..7f0af79
--- /dev/null
+++ b/internal/tc2-hat-comms/json-out.go
@@ -0,0 +1,98 @@
+// Output mode: sends events out over serial in JSON format.
+
+package comms
+
+import (
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/TheCacophonyProject/tc2-hat-controller/serialhelper"
+ "github.com/TheCacophonyProject/tc2-hat-controller/tracks"
+)
+
+type ClassificationData struct {
+ Species tracks.Species `json:"species"`
+ Confidence int32 `json:"confidence"`
+}
+
+type jsonOut struct {
+ Type string `json:"type,omitempty"`
+ Data ClassificationData `json:"data"`
+}
+
+func processJSONOut(config *CommsConfig, testClassification *TestClassification, trackingSignals chan event, port *serialhelper.SerialPort) error {
+ if testClassification != nil {
+ log.Println("Sending a test classification over UART")
+
+ species := tracks.Species{
+ testClassification.Animal: int32(testClassification.Confidence),
+ }
+
+ classificationData := ClassificationData{
+ Species: species,
+ Confidence: int32(testClassification.Confidence),
+ }
+
+ message := jsonOut{
+ Type: "classification",
+ Data: classificationData,
+ }
+ payload, err := json.Marshal(message)
+ if err != nil {
+ return err
+ }
+
+ log.Printf("Sending payload: '%s'", payload)
+ return port.Write(append(payload, '\r', '\n'))
+ }
+
+ for {
+ log.Debug("Waiting")
+ for e := range trackingSignals {
+ switch v := e.(type) {
+ case trackingEvent:
+ fmt.Println("Tracking event:", v.Species)
+ err := processTrackingEvent(v, port)
+ if err != nil {
+ log.Error("Error processing tracking event:", err)
+ }
+ default:
+ log.Debug("Not processing event:", v)
+ continue
+ }
+ }
+ }
+}
+
+func processTrackingEvent(t trackingEvent, port *serialhelper.SerialPort) error {
+ log.Debugf("Found new track: %+v", t)
+
+ species := tracks.Species{}
+ for k, v := range t.Species {
+ if v > 0 {
+ species[k] = v
+ }
+ }
+
+ message := jsonOut{
+ Type: "classification",
+ Data: ClassificationData{
+ Species: species,
+ Confidence: t.Confidence,
+ },
+ }
+
+ payload, err := json.Marshal(message)
+ if err != nil {
+ return err
+ }
+
+ log.Printf("Sending payload: '%s'", payload)
+ start := time.Now()
+
+ err = port.Write(append(payload, '\r', '\n'))
+
+ log.Printf("Sent payload in %s", time.Since(start))
+ return err
+}
diff --git a/internal/tc2-hat-comms/main.go b/internal/tc2-hat-comms/main.go
index 6d5a5f4..25f4dcf 100644
--- a/internal/tc2-hat-comms/main.go
+++ b/internal/tc2-hat-comms/main.go
@@ -9,10 +9,12 @@ import (
goconfig "github.com/TheCacophonyProject/go-config"
"github.com/TheCacophonyProject/go-utils/logging"
+ "github.com/TheCacophonyProject/tc2-hat-controller/serialhelper"
"github.com/TheCacophonyProject/tc2-hat-controller/tracks"
"github.com/alexflint/go-arg"
"github.com/google/go-cmp/cmp"
"github.com/rjeczalik/notify"
+ "periph.io/x/conn/v3/gpio"
)
var (
@@ -110,11 +112,6 @@ func Run(inputArgs []string, ver string) error {
}
}
- if config.CommsOut == "uart" && config.Bluetooth {
- log.Error("Can't have output set to UART and Bluetooth enabled at the same time.")
- return fmt.Errorf("can't have output set to UART and Bluetooth enabled at the same time")
- }
-
log.Info("Species to trap:\n", tracks.Species(config.TrapSpecies))
log.Info("Species to protect:\n", tracks.Species(config.ProtectSpecies))
@@ -127,8 +124,8 @@ func Run(inputArgs []string, ver string) error {
}
switch config.CommsOut {
- case "uart":
- log.Info("Running UART output.")
+ case "uart", "json-out":
+ log.Info("Running UART/json-out.")
// uart comms channel listens for tracking events
err = addTrackingEvents(eventsChan)
@@ -136,7 +133,13 @@ func Run(inputArgs []string, ver string) error {
return err
}
- if err := processUart(config, args.SendTestClassification, eventsChan); err != nil {
+ port, err := serialhelper.OpenSerial(gpio.High, gpio.Low, config.BaudRate)
+ if err != nil {
+ return fmt.Errorf("failed to open serial port: %v", err)
+ }
+ defer port.Close()
+
+ if err := processJSONOut(config, args.SendTestClassification, eventsChan, port); err != nil {
return err
}
case "simple":
@@ -151,182 +154,47 @@ func Run(inputArgs []string, ver string) error {
if err := processSimpleOutput(config, eventsChan); err != nil {
return err
}
- case "at-esl":
- log.Info("Running AT-ESL output.")
- config.BaudRate = 4800 // Force AT-ESL baud rate to be 4800
+ case "trap-control":
+ log.Info("Running trap-control output.")
- // at-esl comms channel listens for tracking *reprocessed* events
- err = addTrackingReprocessedEvents(eventsChan)
- if err != nil {
- return err
- }
+ // TODO, check what speed we want for this
+ log.Info("Forcing baud rate to 9600, this will likely change in the future.")
+ config.BaudRate = 9600
- if err := processATESL(config, args.SendTestClassification, eventsChan); err != nil {
+ // Add tracking events to the channel.
+ // This is so we can activate and deactivate the trap depending on the track classification.
+ err = addTrackingEvents(eventsChan)
+ if err != nil {
return err
}
- default:
- return fmt.Errorf("unknown output type '%s'", config.CommsOut)
- }
-
- return nil
- /*
-
- trapActiveUntil := time.Time{}
- trapActive := false
-
- // Initialize the periph host drivers
- if _, err := host.Init(); err != nil {
- log.Printf("Failed to initialize periph: %v\n", err)
+ // Add recording start/stop events to the channel.
+ // This is so we can see how long it takes from a recording started to getting a track classification.
+ err = addRecordingEvents(eventsChan)
+ if err != nil {
return err
}
- log.Info("Get lock on serial port")
- if config.CommsOut == "uart" || config.CommsOut == "simple" {
- serialFile, err := serialhelper.GetSerial(3, gpio.High, gpio.Low, time.Second)
- if err != nil {
- return err
- }
- defer serialhelper.ReleaseSerial(serialFile)
- }
- log.Info("Done")
-
- protectDuration := time.Minute
- trapDuration := time.Duration(args.TrapStayActiveDuration) * time.Second
-
- var newTrack *trackingEvent
- lastProtectSpeciesSighting := time.Time{}
- lastTrapSpeciesSighting := time.Time{}
-
- for {
-
- now := time.Now()
- newTrapActive :=
- (lastProtectSpeciesSighting.Add(protectDuration).Before(now) && // Nothing to protect has been seen recently.
- lastTrapSpeciesSighting.Add(trapDuration).After(now)) // And something to trap has been sighted recently.
-
- if trapActive != newTrapActive {
- trapActive = newTrapActive
-
- if trapActive {
- log.Println("Activating trap")
- } else {
- log.Println("Deactivating trap")
- }
-
- switch args.OutputType {
- case "uart":
- log.Info("Outputting trap active state via UART")
- if err := processUart(); err != nil {
- return err
- }
- // TODO
-
- case "bluetooth":
- log.Info("Outputting trap active state via bluetooth")
- if err := processBluetooth(); err != nil {
- return err
- }
- // TODO
-
- case "digital":
- log.Info("Outputting trap active state via digital signals")
- //if err := processDigital(); err != nil {
- // return err
- //}
-
- default:
- return fmt.Errorf("unhandled output type: %s", args.OutputType)
- }
- }
-
- var delay = 10 * time.Second
- if trapActive && time.Until(trapActiveUntil) < delay {
- delay = time.Until(trapActiveUntil)
- }
-
- newTrack = nil
- log.Debug("Waiting ")
- select {
- case t := <-trackingSignals:
- newTrack = &t
- log.Debugf("Found new track: %+v", newTrack)
-
- if newTrack.species.MatchSpeciesWithConfidence(protectSpecies) {
- log.Debug("Found an animal that needs to be protected, deactivating trap")
- lastProtectSpeciesSighting = time.Now()
- //trapActiveUntil = time.Time{}
- } else if newTrack.species.MatchSpeciesWithConfidence(trapSpecies) {
- log.Debug("Found an animal that needs to be trapped, activating trap")
- lastTrapSpeciesSighting = time.Now()
-
- //trapActiveUntil = time.Now().Add(time.Duration(args.TrapStayActiveDuration) * time.Second)
- } else {
- log.Debug("No animals need to be protected or trapped, not changing trap state.")
- }
-
- case <-time.After(delay):
- log.Debug("Scheduled check")
- }
- }
- */
-
- /*
- for t := range tracks {
- log.Infof("Found track: %+v", t)
- }
-
- // Start dbus to listen for classification messages.
-
- if err := beep(); err != nil {
+ // Run the trap control process
+ if err := processTrapControl(config, eventsChan); err != nil {
return err
}
+ case "at-esl":
+ log.Info("Running AT-ESL output.")
+ config.BaudRate = 4800 // Force AT-ESL baud rate to be 4800
- log.Println("Starting UART service")
- if err := startService(); err != nil {
+ // at-esl comms channel listens for tracking *reprocessed* events
+ err = addTrackingReprocessedEvents(eventsChan)
+ if err != nil {
return err
}
- trapActive = false
- if err := sendTrapActiveState(trapActive); err != nil {
+ if err := processATESL(config, args.SendTestClassification, eventsChan); err != nil {
return err
}
-
- for {
- waitUntil := time.Now().Add(5 * time.Second)
- if trapActive {
- waitUntil = activateTrapUntil
- }
-
- select {
- case <-activeTrapSig:
- case <-time.After(time.Until(waitUntil)):
- }
- trapActive = time.Now().Before(activateTrapUntil)
-
- if err := sendTrapActiveState(trapActive); err != nil {
- return err
- }
- }
- */
-}
-
-/*
-func checkClassification(data map[byte]byte) error {
- for k, v := range data {
- if k == 1 && v > 80 {
- activateTrap()
- }
- if k == 7 && v > 80 {
- activateTrap()
- }
+ default:
+ return fmt.Errorf("unknown output type '%s'", config.CommsOut)
}
- return nil
-}
-func activateTrap() {
- log.Println("Activating trap")
- activateTrapUntil = time.Now().Add(time.Minute)
- activeTrapSig <- "trap"
+ return nil
}
-*/
diff --git a/internal/tc2-hat-comms/service-monitor.go b/internal/tc2-hat-comms/service-monitor.go
index c895254..220649e 100644
--- a/internal/tc2-hat-comms/service-monitor.go
+++ b/internal/tc2-hat-comms/service-monitor.go
@@ -38,6 +38,7 @@ type trackingEvent struct {
Tracking bool
LastPredictionFrame int32
ClipAgeSeconds int32
+ TrackStartTime time.Time
}
func (t trackingEvent) isEvent() {}
@@ -48,11 +49,17 @@ type batteryEvent struct {
Percent float64
}
+// Add tracking reprocessed events to the channel
+// Tracking reprocessed events are sent once the track has finished and it gets reprocessed.
+// This is useful if you are just using the events for reporting purposes and don't need to control
+// something in real time.
func addTrackingReprocessedEvents(eventsChan chan event) error {
targetSignalName := "org.cacophony.thermalrecorder.TrackingReprocessed"
return addTrackingEventsForSignal(eventsChan, targetSignalName)
}
+// Add tracking events to the channel
+// Tracking events are sent while the track is in progress.
func addTrackingEvents(eventsChan chan event) error {
targetSignalName := "org.cacophony.thermalrecorder.Tracking"
return addTrackingEventsForSignal(eventsChan, targetSignalName)
@@ -91,7 +98,7 @@ func addTrackingEventsForSignal(eventsChan chan event, targetSignalName string)
log.Debugf("Received tracking event [%v]:", signal.Name)
// Reprocessed signals have an additional parameter 'clip_end_time'
- if len(signal.Body) < 12 {
+ if len(signal.Body) != 13 {
log.Errorf("Unexpected signal format in body: %v", signal.Body)
continue
}
@@ -107,9 +114,7 @@ func addTrackingEventsForSignal(eventsChan chan event, targetSignalName string)
log.Debugf("Tracking: %v", signal.Body[9])
log.Debugf("Last prediction frame: %v", signal.Body[10])
log.Debugf("Model Id: %v", signal.Body[11])
- if len(signal.Body) >= 13 {
- log.Debugf("Clip End Time: %v", signal.Body[12])
- }
+ log.Debugf("Track Start Time: %v", signal.Body[12])
var modelId int32
var modelLabels []string
@@ -153,16 +158,8 @@ func addTrackingEventsForSignal(eventsChan chan event, targetSignalName string)
var region [4]int32
copy(region[:], signal.Body[5].([]int32))
- // See if we have a clip end time
- clipAgeSeconds := int32(0)
- if len(signal.Body) >= 13 {
- ts := signal.Body[12].(float64)
- now := time.Now()
- target := time.Unix(int64(ts), int64((ts-float64(int64(ts)))*1e9))
-
- clipAgeSeconds = int32(now.Sub(target).Seconds())
- log.Debugf("Clip is %d seconds old", clipAgeSeconds)
- }
+ nanoSeconds := signal.Body[12].(int64) * 1e6
+ trackStartTime := time.Unix(0, nanoSeconds)
// Finally let's build our tracking event
t := trackingEvent{
@@ -177,7 +174,7 @@ func addTrackingEventsForSignal(eventsChan chan event, targetSignalName string)
BlankRegion: signal.Body[8].(bool),
Tracking: signal.Body[9].(bool),
LastPredictionFrame: signal.Body[10].(int32),
- ClipAgeSeconds: clipAgeSeconds,
+ TrackStartTime: trackStartTime,
}
log.Debugf("Sending tracking event: %+v", t)
@@ -189,6 +186,64 @@ func addTrackingEventsForSignal(eventsChan chan event, targetSignalName string)
return nil
}
+// recordingEvent is a event that is made at the start and end of a recording
+type recordingEvent struct {
+ event
+ Timestamp time.Time
+ Recording bool
+}
+
+// Add recording events to the channel
+// These are events that are made at the start and end of a recording
+func addRecordingEvents(eventsChan chan event) error {
+ // Listen for signals
+ targetSignalName := "org.cacophony.thermalrecorder.Recording"
+
+ // Connect to the system bus
+ conn, err := dbus.SystemBus()
+ if err != nil {
+ log.Fatalf("Failed to connect to system bus: %v", err)
+ }
+
+ // Add a match rule to listen for our dbus signals
+ rule := "type='signal',interface='org.cacophony.thermalrecorder'"
+ call := conn.BusObject().Call("org.freedesktop.DBus.AddMatch", 0, rule)
+ if call.Err != nil {
+ log.Fatalf("Failed to add match rule: %v", call.Err)
+ }
+
+ // Create a channel to receive signals
+ c := make(chan *dbus.Signal, 10)
+ conn.Signal(c)
+
+ log.Infof("Listening for D-Bus signals: %s", targetSignalName)
+ // Listen for signals, process and send tracking events to the channel.
+ go func() {
+ for signal := range c {
+ if signal.Name == targetSignalName {
+ log.Debug("Received Recording event.")
+ if len(signal.Body) != 2 {
+ log.Errorf("Unexpected signal format in body: %v", signal.Body)
+ continue
+ }
+
+ // Time is given in ms since epoch, we will convert to nanoseconds so we can use the time package.
+ nanoSeconds := signal.Body[0].(int64) * 1e6
+ recordingStartTime := time.Unix(0, nanoSeconds)
+
+ // Send the event to the channel.
+ eventsChan <- recordingEvent{
+ Timestamp: recordingStartTime,
+ Recording: signal.Body[1].(bool),
+ }
+ }
+ }
+ }()
+
+ return nil
+}
+
+// Add battery events to the channel.
func addBatteryEvents(eventsChan chan event) error {
// Listen for signals
targetSignalName := "org.cacophony.attiny.Battery"
@@ -255,14 +310,15 @@ func getLabels() {
}
bodyMap := t_call.Body[0].(map[int32][]string)
- // Out model labels have id '1' .. false-postitives are the other element.
+ // Out model labels have id '1' .. false-positive are the other element.
// e.g. [map[1:[bird cat deer ... vehicle wallaby] 1004:[animal false-positive]]]
for k, v := range bodyMap {
- if k == animalsList.Id {
+ switch k {
+ case animalsList.Id:
animalsList.Labels = v
- } else if k == fpModelLabels.Id {
+ case fpModelLabels.Id:
fpModelLabels.Labels = v
- } else {
+ default:
log.Warnf("Unexpected classification label id: %v, with labels: %v", k, bodyMap)
}
}
@@ -287,7 +343,7 @@ func getThumbnail(clip_id int32, track_id int32) [][]uint16 {
switch frame := t_call.Body[0].(type) {
case [][]uint16:
// Access row/col
- log.Debugf("Thubnail (clip id: %d, track_id: %d) is: %d×%d", clip_id, track_id, len(frame), len(frame[0]))
+ log.Debugf("Thumbnail (clip id: %d, track_id: %d) is: %d×%d", clip_id, track_id, len(frame), len(frame[0]))
return t_call.Body[0].([][]uint16)
default:
log.Warnf("GetThumbnail returned an unexpected 2D type: %T", frame)
diff --git a/internal/tc2-hat-comms/service.go b/internal/tc2-hat-comms/service.go
deleted file mode 100644
index b3fc9d4..0000000
--- a/internal/tc2-hat-comms/service.go
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
-attiny-controller - Communicates with ATtiny microcontroller
-Copyright (C) 2018, The Cacophony Project
-
-This program is free software: you can redistribute it and/or modify
-it under the terms of the GNU General Public License as published by
-the Free Software Foundation, either version 3 of the License, or
-(at your option) any later version.
-
-This program is distributed in the hope that it will be useful,
-but WITHOUT ANY WARRANTY; without even the implied warranty of
-MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-GNU General Public License for more details.
-
-You should have received a copy of the GNU General Public License
-along with this program. If not, see .
-*/
-
-package comms
-
-/*
-import (
- "errors"
- "runtime"
- "strings"
-
- "github.com/godbus/dbus"
- "github.com/godbus/dbus/introspect"
-)
-
-// TODO This is just using the beacon name at the moment so other things don't need to be updated.
-const (
- dbusName = "org.cacophony.beacon"
- dbusPath = "/org/cacophony/beacon"
-)
-
-type service struct{}
-
-func startService() error {
- conn, err := dbus.SystemBus()
- if err != nil {
- return err
- }
- reply, err := conn.RequestName(dbusName, dbus.NameFlagDoNotQueue)
- if err != nil {
- return err
- }
- if reply != dbus.RequestNameReplyPrimaryOwner {
- return errors.New("name already taken")
- }
-
- s := &service{}
- conn.Export(s, dbusPath, dbusName)
- conn.Export(genIntrospectable(s), dbusPath, "org.freedesktop.DBus.Introspectable")
- return nil
-}
-
-func genIntrospectable(v interface{}) introspect.Introspectable {
- node := &introspect.Node{
- Interfaces: []introspect.Interface{{
- Name: dbusName,
- Methods: introspect.Methods(v),
- }},
- }
- return introspect.NewIntrospectable(node)
-}
-
-func (s service) Classification(classifications map[byte]byte) *dbus.Error {
- log.Println("Got DBus message 'Classification'")
- return errToDBusErr(checkClassification(classifications))
-}
-
-func (s service) Recording() *dbus.Error {
- log.Println("Got DBus message 'Recording'")
- return nil
-}
-
-func errToDBusErr(err error) *dbus.Error {
- if err == nil {
- return nil
- }
- return &dbus.Error{
- Name: dbusName + "." + getCallerName(),
- Body: []interface{}{err.Error()},
- }
-}
-
-func getCallerName() string {
- fpcs := make([]uintptr, 1)
- n := runtime.Callers(3, fpcs)
- if n == 0 {
- return ""
- }
- caller := runtime.FuncForPC(fpcs[0] - 1)
- if caller == nil {
- return ""
- }
- funcNames := strings.Split(caller.Name(), ".")
- return funcNames[len(funcNames)-1]
-}
-*/
diff --git a/internal/tc2-hat-comms/simple.go b/internal/tc2-hat-comms/simple.go
index 23ab4e3..0cf8398 100644
--- a/internal/tc2-hat-comms/simple.go
+++ b/internal/tc2-hat-comms/simple.go
@@ -1,3 +1,5 @@
+// Output mode: outputs a simple HIGH or LOW signal over serial based on detection events.
+
package comms
import (
diff --git a/internal/tc2-hat-comms/trap-control.go b/internal/tc2-hat-comms/trap-control.go
new file mode 100644
index 0000000..19a251c
--- /dev/null
+++ b/internal/tc2-hat-comms/trap-control.go
@@ -0,0 +1,441 @@
+// Output mode: connects to and controls a trap over serial.
+
+package comms
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/TheCacophonyProject/event-reporter/v3/eventclient"
+ "github.com/TheCacophonyProject/tc2-hat-controller/serialhelper"
+ "periph.io/x/conn/v3/gpio"
+)
+
+// processTrapControl communicates the trap enabled/disabled state by writing
+// the "enable" variable over UART instead of setting a digital pin.
+func processTrapControl(config *CommsConfig, eventSignals chan event) error {
+ trapEnabled := false
+ previousTrapEnabled := false
+ lastProtectSpeciesSighting := time.Time{}
+ lastTrapSpeciesSighting := time.Time{}
+ enablingReason := ""
+ disablingReason := ""
+
+ recordingStartTime := time.Time{}
+ trackStartTime := time.Time{}
+ triggerAnimal := ""
+ var confidence int32
+
+ // Open the serial port so we can send/receive messages from the trap.
+ port, err := serialhelper.OpenSerial(gpio.High, gpio.Low, config.BaudRate)
+ if err != nil {
+ return fmt.Errorf("failed to open serial port: %v", err)
+ }
+ defer port.Close()
+
+ // Create the messenger that tracks sending/receiving messages
+ messenger := NewUartMessenger(port)
+ messenger.Start()
+
+ for {
+ now := time.Now()
+ trapEnabled = config.TrapEnabledByDefault
+
+ if lastProtectSpeciesSighting.Add(config.ProtectDuration).After(now) {
+ trapEnabled = false
+ } else if lastTrapSpeciesSighting.Add(config.TrapDuration).After(now) {
+ trapEnabled = true
+ }
+
+ if trapEnabled != previousTrapEnabled {
+ if trapEnabled {
+ log.Infof("Enabling trap, reason: %s", enablingReason)
+ success, err := messenger.setEnable(true)
+ if err != nil {
+ return fmt.Errorf("failed to enable trap: %v", err)
+ }
+ trapEnableTime := time.Now()
+ log.Infof("Recording start time: %s", recordingStartTime.Format("15:04:05.999"))
+ log.Infof("Track start time: %s", trackStartTime.Format("15:04:05.999"))
+ log.Infof("TrapEnableTime: %s", trapEnableTime.Format("15:04:05.999")) // TODO, we can get better accuracy on when this actually
+ timeToEnableTrap := trapEnableTime.Sub(recordingStartTime).String()
+ log.Infof("Time to enable trap: %s", timeToEnableTrap)
+
+ eventclient.AddEvent(eventclient.Event{
+ Timestamp: time.Now(),
+ Type: "trapEnableCommand",
+ Details: map[string]any{
+ "reason": enablingReason,
+ "recordingStartTime": recordingStartTime,
+ "trackStartTime": trackStartTime,
+ "trapEnableTime": trapEnableTime,
+ "timeToEnableTrap": timeToEnableTrap,
+ "animal": triggerAnimal,
+ "confidence": confidence,
+ "enableTrapSuccess": success, // If this fails that likely means the trap is not in a state to be enabled through the UART
+ },
+ })
+ } else {
+ log.Info("Disabling trap, reason: ", disablingReason)
+ success, err := messenger.setEnable(false)
+ if err != nil {
+ return fmt.Errorf("failed to disable trap: %v", err)
+ }
+ eventclient.AddEvent(eventclient.Event{
+ Timestamp: time.Now(),
+ Type: "trapDisableCommand",
+ Details: map[string]any{
+ "reason": disablingReason,
+ "disableTrapSuccess": success,
+ },
+ })
+ }
+ }
+
+ previousTrapEnabled = trapEnabled
+
+ var delay = 10 * time.Second
+ trapDisableTime := lastTrapSpeciesSighting.Add(config.TrapDuration)
+ if trapEnabled && time.Until(trapDisableTime) < delay {
+ delay = time.Until(trapDisableTime)
+ }
+
+ disablingReason = "timeout"
+ enablingReason = "timeout"
+ log.Debug("Waiting")
+ select {
+ case t := <-eventSignals:
+ switch v := t.(type) {
+ case trackingEvent:
+ log.Debugf("Received tracking event: %+v", v)
+ trackStartTime = v.TrackStartTime
+
+ protect, animal, conf := v.Species.MatchSpeciesWithConfidence(config.ProtectSpecies)
+ if protect {
+ disablingReason = fmt.Sprintf("Found an %s of confidence %d that needs to be protected", animal, conf)
+ log.Debug(disablingReason)
+ lastProtectSpeciesSighting = time.Now()
+ break
+ }
+
+ trap, animal, conf := v.Species.MatchSpeciesWithConfidence(config.TrapSpecies)
+ if trap {
+ enablingReason = fmt.Sprintf("Found an %s of confidence %d that needs to be trapped", animal, conf)
+ triggerAnimal = animal
+ confidence = conf
+ log.Debug(enablingReason)
+ lastTrapSpeciesSighting = time.Now()
+ break
+ }
+
+ log.Debug("No animals need to be protected or trapped, not changing trap state.")
+
+ case recordingEvent:
+ log.Debugf("Received recording event: %+v", v)
+ if v.Recording {
+ recordingStartTime = v.Timestamp
+ } else {
+ recordingStartTime = time.Time{}
+ }
+
+ default:
+ log.Debugf("Ignoring non tracking event: %+v", t)
+ continue
+ }
+
+ case <-time.After(delay):
+ log.Debug("Scheduled check")
+ }
+ }
+}
+
+// Message represents the data structure for communication with a device connected on UART.
+// - ID: Identifier of the message being sent or the message being responded to.
+// - Response: Indicates if the message is a response.
+// - Type: Specifies the type of message (e.g., write, read, command, ACK, NACK).
+// - Data: Contains the actual data payload, which varies depending on the type or response.
+type Message struct {
+ ID int
+ Type string
+ Payload string
+ PayloadUnmarshaled any
+}
+
+func (u *Message) String() string {
+ if u.PayloadUnmarshaled != nil {
+ return fmt.Sprintf("ID: %d, Type: %s, Payload: %v, PayloadUnmarshaled: %v", u.ID, u.Type, u.Payload, u.PayloadUnmarshaled)
+ }
+ return fmt.Sprintf("ID: %d, Type: %s, Payload: %v", u.ID, u.Type, u.Payload)
+}
+
+func (m *Message) ToUARTLine() string {
+ if m == nil {
+ return ""
+ }
+ if m.PayloadUnmarshaled != nil {
+ marshaledPayload, err := json.Marshal(m.PayloadUnmarshaled)
+ if err != nil {
+ return ""
+ }
+ m.PayloadUnmarshaled = nil
+ m.Payload = string(marshaledPayload)
+ }
+ messageStr := fmt.Sprintf("<%d|%s|%s>", m.ID, m.Type, m.Payload)
+ return fmt.Sprintf("%s%d\n", messageStr, computeChecksum([]byte(messageStr)))
+}
+
+func (m *Message) Response() bool {
+ return m.ID != 0
+}
+
+type Command struct {
+ Command string `json:"command"`
+ Args string `json:"args,omitempty"`
+}
+
+type Write struct {
+ Var string `json:"var,omitempty"`
+ Val any `json:"val,omitempty"`
+}
+
+// UartMessenger manages bidirectional communication with the RP2040 over UART.
+// It holds a persistent serial port and routes incoming messages to either
+// pending response waiters (matched by ID) or an unsolicited message channel.
+type UartMessenger struct {
+ port *serialhelper.SerialPort
+ pendingMu sync.Mutex
+ pending map[int]chan *Message
+ nextID int
+ baudRate int
+}
+
+// NewUartMessenger creates a UartMessenger using an already-open SerialPort.
+func NewUartMessenger(port *serialhelper.SerialPort) *UartMessenger {
+ return &UartMessenger{
+ port: port,
+ pending: make(map[int]chan *Message),
+ }
+}
+
+// Start begins the background routing goroutine. Unsolicited messages from the RP2040
+// (i.e. not responses to a request we sent) are delivered to the unsolicited channel.
+// Pass nil to discard unsolicited messages.
+func (u *UartMessenger) Start() {
+ go u.routeMessages()
+}
+
+// routeMessages reads lines from the serial port, parses them, and routes them:
+// TODO: Maybe separate this for routing messages
+// - Response messages are matched to a pending sendMessage call by ID.
+// - If not a response then it is a notification from the trap.
+func (u *UartMessenger) routeMessages() {
+ for line := range u.port.Lines {
+ // Parse the line
+ msg, err := ParseLine(line)
+ if err != nil {
+ log.Warnf("Failed to parse incoming message %q: %v", line, err)
+ continue
+ }
+
+ // Check if the message was a response
+ if msg.Response() {
+ u.pendingMu.Lock()
+ ch, ok := u.pending[msg.ID]
+ if !ok && len(u.pending) == 1 {
+ // Fallback for RP2040 firmware that doesn't echo message IDs yet.
+ for _, c := range u.pending {
+ ch = c
+ ok = true
+ break
+ }
+ }
+ u.pendingMu.Unlock()
+ if ok {
+ ch <- msg
+ continue
+ }
+ }
+
+ // If not a response then it is a notification from the trap.
+ parseMessageFromTrap(msg)
+ }
+}
+
+func parseMessageFromTrap(msg *Message) {
+ log.Printf("Trap message: %+v", msg)
+
+ // eventMessages maps trap message type to event type.
+ // For these events we will just make an event of the given type and add the payload in the details.
+ eventMessages := map[string]string{
+ "MOTION": "trapMotion",
+ "ENABLED": "trapEnabled",
+ "DISABLED": "trapDisabled",
+ "SPOOL_RESET": "trapSpoolReset",
+ "TRIGGERED": "trapTriggered",
+ "RUNNING": "trapRunning",
+ "ERROR_CODE": "trapErrorCode",
+ "EXCEPTION": "trapException",
+ }
+
+ // Messages that we want to trigger the events to be uploaded right away.
+ uploadEventsNowMessages := []string{
+ "TRIGGERED",
+ "EXCEPTION",
+ "ERROR_CODE",
+ }
+
+ // Handle messages that we want to make events for
+ if event, ok := eventMessages[msg.Type]; ok {
+ log.Info("Making event for: ", msg.Type)
+ details := map[string]any{}
+ if msg.Payload != "" {
+ // Try to unmarshal the payload, if not just use it as a string
+ err := json.Unmarshal([]byte(msg.Payload), &details)
+ if err != nil {
+ details["Payload"] = msg.Payload
+ }
+ }
+ err := eventclient.AddEvent(eventclient.Event{
+ Timestamp: time.Now(),
+ Type: event,
+ Details: details,
+ })
+ if err != nil {
+ log.Error("Error adding event:", err)
+ }
+ if contains(uploadEventsNowMessages, msg.Type) {
+ log.Info("Uploading events now")
+ err := eventclient.UploadEvents()
+ if err != nil {
+ log.Error("Error requesting events to be uploaded:", err)
+ }
+ }
+ return
+ }
+
+ // Messages that we just want to make a log for, no event.
+ // logMessages := []string{}
+ // if contains(logMessages, msg.Type) {
+ // log.Infof("Trap message: %+v", msg)
+ // return
+ // }
+
+ // Unknown messages
+ log.Warnf("Unknown trap message: %+v", msg)
+}
+
+func contains(arr []string, item string) bool {
+ for _, v := range arr {
+ if v == item {
+ return true
+ }
+ }
+ return false
+}
+
+// ParseLine parses a framed line of the form checksum.
+func ParseLine(line []byte) (*Message, error) {
+ line = bytes.TrimLeft(line, "\x00")
+ lastIdx := bytes.LastIndexByte(line, '>')
+ if lastIdx < 0 || len(line) == 0 || line[0] != '<' {
+ return nil, fmt.Errorf("invalid frame: %q", line)
+ }
+ messageStr := line[:lastIdx+1]
+ checksumStr := line[lastIdx+1:]
+
+ receivedChecksum, err := strconv.Atoi(string(checksumStr))
+ if err != nil {
+ return nil, fmt.Errorf("invalid checksum in %q: %v", line, err)
+ }
+ if computeChecksum(messageStr) != receivedChecksum {
+ return nil, fmt.Errorf("checksum mismatch in %q", line)
+ }
+
+ inner := messageStr[1 : len(messageStr)-1]
+ parts := bytes.SplitN(inner, []byte("|"), 3)
+ if len(parts) < 2 {
+ return nil, fmt.Errorf("invalid format: %q", line)
+ }
+
+ id, err := strconv.Atoi(string(parts[0]))
+ if err != nil {
+ return nil, fmt.Errorf("invalid id in %q: %v", line, err)
+ }
+
+ payload := ""
+ if len(parts) == 3 {
+ payload = string(parts[2])
+ }
+
+ return &Message{
+ ID: id,
+ Type: string(parts[1]),
+ Payload: payload,
+ }, nil
+}
+
+func computeChecksum(message []byte) int {
+ checksum := 0
+ for _, b := range message {
+ checksum += int(b)
+ }
+ return checksum % 256
+}
+
+// sendMessage sends a request and waits for a matching response.
+// It assigns a unique ID to the message for correlation.
+func (u *UartMessenger) sendMessage(message Message) (*Message, error) {
+ u.pendingMu.Lock()
+ u.nextID++
+ id := u.nextID
+ message.ID = id
+ ch := make(chan *Message, 1)
+ u.pending[id] = ch
+ u.pendingMu.Unlock()
+
+ defer func() {
+ u.pendingMu.Lock()
+ delete(u.pending, id)
+ u.pendingMu.Unlock()
+ }()
+
+ line := message.ToUARTLine()
+ log.Infof("Message: '%s'", line)
+
+ if err := u.port.Write([]byte(line)); err != nil {
+ return nil, err
+ }
+
+ select {
+ case response := <-ch:
+ log.Println("Response:", response)
+ return response, nil
+ case <-time.After(5 * time.Second):
+ return nil, fmt.Errorf("timeout waiting for response to message ID %d", id)
+ }
+}
+
+func (u *UartMessenger) setEnable(enable bool) (bool, error) {
+ message := Message{}
+ if enable {
+ message.Type = "ENABLE"
+ } else {
+ message.Type = "DISABLE"
+ }
+ response, err := u.sendMessage(message)
+ if err != nil {
+ return false, err
+ }
+ if response.Type == "NACK" {
+ return false, fmt.Errorf("NACK response")
+ }
+ if response.Type == "BAD_KEY" {
+ log.Warn("Got BAD_KEY response, was trying to set a key that doesn't exist")
+ return false, nil
+ }
+ return true, nil
+}
diff --git a/internal/tc2-hat-comms/uart.go b/internal/tc2-hat-comms/uart.go
deleted file mode 100644
index 8c1bbe3..0000000
--- a/internal/tc2-hat-comms/uart.go
+++ /dev/null
@@ -1,310 +0,0 @@
-// This section deals with communication with peripherals over uart.
-
-package comms
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "strconv"
- "time"
-
- "github.com/TheCacophonyProject/tc2-hat-controller/serialhelper"
- "github.com/TheCacophonyProject/tc2-hat-controller/tracks"
- "periph.io/x/conn/v3/gpio"
-)
-
-type UartMessenger struct {
- baudRate int
-}
-
-// TODO
-
-// UartMessage represents the data structure for communication with a device connected on UART.
-// - ID: Identifier of the message being sent or the message being responded to.
-// - Response: Indicates if the message is a response.
-// - Type: Specifies the type of message (e.g., write, read, command, ACK, NACK).
-// - Data: Contains the actual data payload, which varies depending on the type or response.
-type UartMessage struct {
- ID int `json:"id,omitempty"`
- Response bool `json:"response,omitempty"`
- Type string `json:"type,omitempty"`
- Data interface{} `json:"data,omitempty"`
-}
-
-type Command struct {
- Command string `json:"command"`
- Args string `json:"args,omitempty"`
-}
-
-type Write struct {
- Var string `json:"var,omitempty"`
- Val interface{} `json:"val,omitempty"`
-}
-
-func (u UartMessenger) sendTrapActiveState(active bool) error {
- return u.sendWriteMessage("active", active)
-}
-
-func processUart(config *CommsConfig, testClassification *TestClassification, trackingSignals chan event) error {
- if testClassification != nil {
- log.Println("Sending a test classification over UART")
-
- species := tracks.Species{
- testClassification.Animal: int32(testClassification.Confidence),
- }
-
- classificationData := ClassificationData{
- Species: species,
- Confidence: int32(testClassification.Confidence),
- }
-
- message := UartMessage{
- Type: "classification",
- Data: classificationData,
- }
- payload, err := json.Marshal(message)
- if err != nil {
- return err
- }
-
- log.Printf("Sending payload: '%s'", payload)
-
- serialhelper.SerialSend(3, gpio.High, gpio.Low, time.Second, append(payload, byte('\r'), byte('\n')), config.BaudRate)
-
- return nil
- }
-
- messenger := UartMessenger{
- baudRate: config.BaudRate,
- }
-
- for {
- log.Debug("Waiting")
- for e := range trackingSignals {
- switch v := e.(type) {
- case trackingEvent:
- fmt.Println("Tracking event:", v.Species)
- err := messenger.processTrackingEvent(v)
- if err != nil {
- log.Error("Error processing tracking event:", err)
- }
- default:
- log.Debug("Not processing event:", v)
- continue
- }
- }
- }
-}
-
-func (u UartMessenger) processTrackingEvent(t trackingEvent) error {
- log.Debugf("Found new track: %+v", t)
-
- species := tracks.Species{}
- for k, v := range t.Species {
- if v > 0 {
- species[k] = v
- }
- }
-
- message := UartMessage{
- Type: "classification",
- Data: ClassificationData{
- Species: species,
- Confidence: t.Confidence,
- },
- }
-
- payload, err := json.Marshal(message)
- if err != nil {
- return err
- }
-
- log.Printf("Sending payload: '%s'", payload)
- start := time.Now()
-
- serialhelper.SerialSend(3, gpio.High, gpio.Low, time.Second, append(payload, byte('\r'), byte('\n')), u.baudRate)
-
- log.Printf("Sent payload in %s", time.Since(start))
-
- return nil
-}
-
-type ClassificationData struct {
- Species tracks.Species
- Confidence int32
-}
-
-func (u UartMessenger) sendClassification(event trackingEvent) {
-
- data := map[string]interface{}{
- "species": event.Species,
- "confidence": event.Confidence,
- }
-
- jsonBytes, err := json.Marshal(data)
- if err != nil {
- fmt.Println("Error converting to JSON:", err)
- return
- }
-
- println(string(jsonBytes))
- //data["data"] = trackingEvent.Data
-
- sendMessage(UartMessage{
- Type: "classification",
- Data: string(jsonBytes),
- }, u.baudRate)
-}
-
-func (u UartMessenger) sendWriteMessage(varName string, val interface{}) error {
- data, err := json.Marshal(&Write{
- Var: varName,
- Val: val,
- })
- if err != nil {
- return err
- }
- message := UartMessage{
- Type: "write",
- Data: string(data),
- }
- response, err := sendMessage(message, u.baudRate)
- if err != nil {
- return err
- }
- if response.Type == "NACK" {
- return fmt.Errorf("NACK response")
- }
- return nil
-}
-
-func beep(baudRate int) error {
- log.Println("beep")
- return sendCommandMessage("beep", baudRate)
-}
-
-func sendCommandMessage(cmd string, baudRate int) error {
- data, err := json.Marshal(&Command{
- Command: cmd,
- })
- if err != nil {
- return err
- }
- message := UartMessage{
- Type: "command",
- Data: string(data),
- }
- response, err := sendMessage(message, baudRate)
- if err != nil {
- return err
- }
- if response.Type == "NACK" {
- return fmt.Errorf("NACK response")
- }
- return nil
-}
-
-type Read struct {
- Var string `json:"var,omitempty"`
-}
-
-type ReadResponse struct {
- Val string `json:"var,omitempty"`
-}
-
-func sendReadMessage(varName string) (string, error) {
- return "", nil
- /*
- data, err := json.Marshal(&Read{
- Var: varName,
- })
- if err != nil {
- return "", err
- }
- message := UartMessage{
- Type: "read",
- Data: string(data),
- }
- response, err := sendMessage(message)
- if err != nil {
- return "", err
- }
- if response.Type == "NACK" {
- return "", fmt.Errorf("NACK response")
- }
- readResponse := &ReadResponse{}
- if err := json.Unmarshal([]byte(response.Data), readResponse); err != nil {
- return "", err
- }
- return readResponse.Val, nil
- */
-
-}
-
-func checkPIR(oldPirVal int) (int, error) {
- valStr, err := sendReadMessage("pir")
- if err != nil {
- return 0, err
- }
- newPirVal, err := strconv.Atoi(valStr)
- if err != nil {
- return 0, err
- }
- if oldPirVal != newPirVal {
- //TODO Make event
- log.Println("New pir value:", newPirVal)
- }
- return newPirVal, nil
-}
-
-func computeChecksum(message []byte) int {
- checksum := 0
- for _, b := range message {
- checksum += int(b)
- }
- return checksum % 256
-}
-
-func sendMessage(cmd UartMessage, baudRate int) (*UartMessage, error) {
- cmdData, err := json.Marshal(cmd)
- if err != nil {
- return nil, err
- }
- message := fmt.Sprintf("<%s|%d>", cmdData, computeChecksum(cmdData))
-
- log.Println("Message: ", message)
- responseData, err := serialhelper.SerialSendReceive(3, gpio.High, gpio.Low, time.Second, []byte(message), baudRate)
-
- if err != nil {
- return nil, err
- }
- log.Println("Response: ", string(responseData))
-
- if responseData[0] != '<' {
- return nil, fmt.Errorf("response doesn't start with '<'")
- }
- if responseData[len(responseData)-1] != '>' {
- return nil, fmt.Errorf("response doesn't end with '>'")
- }
-
- // Extract and verify message and checksum
- responseData = responseData[1 : len(responseData)-1]
- parts := bytes.Split(responseData, []byte("|"))
- if len(parts) != 2 {
- return nil, fmt.Errorf("invalid response format")
- }
- log.Println("Response:", string(parts[0]))
- receivedChecksum, err := strconv.Atoi(string(parts[1]))
- if err != nil {
- return nil, err
- }
- if computeChecksum(parts[0]) != receivedChecksum {
- return nil, fmt.Errorf("checksum mismatch")
- }
-
- // Unmarshal response to a Message
- responseMessage := &UartMessage{}
- log.Println(string(parts[0]))
- return responseMessage, json.Unmarshal(parts[0], responseMessage)
-}
diff --git a/internal/tc2-hat-rp2040/main.go b/internal/tc2-hat-rp2040/main.go
index 3271544..f41ddd2 100644
--- a/internal/tc2-hat-rp2040/main.go
+++ b/internal/tc2-hat-rp2040/main.go
@@ -2,10 +2,13 @@ package rp2040
import (
"bufio"
+ "crypto/sha256"
"errors"
"fmt"
+ "io"
"os"
"os/exec"
+ "strings"
"time"
"github.com/TheCacophonyProject/event-reporter/v3/eventclient"
@@ -25,13 +28,18 @@ type Args struct {
ELF string `arg:"--elf" help:".elf file to program the RP2040 with."`
RunPin string `arg:"--run-pin" help:"Run GPIO pin for the RP2040."`
BootModePin string `arg:"--boot-mode-pin" help:"Boot mode GPIO pin for the RP2040."`
+ EraseFlash bool `arg:"--erase-flash" help:"Upload the program that will erase the flash on the RP2040"`
logging.LogArgs
}
-const openOCDNotFoundMessage = `'openocd' was not found. Can be installed using apt 'sudo apt install openocd' or
+const (
+ openOCDNotFoundMessage = `'openocd' was not found. Can be installed using apt 'sudo apt install openocd' or
following section 5.1 at https://datasheets.raspberrypi.com/pico/getting-started-with-pico.pdf.
If installed using apt then use the config file '/etc/cacophony/raspberrypi-swd.cfg' as
'interface/raspberrypi-swd.cfg' is not available with the current version provided by apt.`
+ eraseFlashFirmware = "/etc/cacophony/rp2040-firmware-erase-flash.elf"
+ eraseFlashHash = "/etc/cacophony/rp2040-firmware-erase-flash.sha256"
+)
var defaultArgs = Args{
RunPin: "GPIO23",
@@ -57,6 +65,31 @@ func procArgs(input []string) (Args, error) {
return args, err
}
+func verifyFileHash(filePath, hashFilePath string) error {
+ expectedHashBytes, err := os.ReadFile(hashFilePath)
+ if err != nil {
+ return fmt.Errorf("failed to read hash file: %v", err)
+ }
+ expectedHash := strings.TrimSpace(string(expectedHashBytes))
+
+ f, err := os.Open(filePath)
+ if err != nil {
+ return fmt.Errorf("failed to open file: %v", err)
+ }
+ defer f.Close()
+
+ h := sha256.New()
+ if _, err := io.Copy(h, f); err != nil {
+ return fmt.Errorf("failed to hash file: %v", err)
+ }
+ actualHash := fmt.Sprintf("%x", h.Sum(nil))
+
+ if actualHash != expectedHash {
+ return fmt.Errorf("hash mismatch: expected %s, got %s", expectedHash, actualHash)
+ }
+ return nil
+}
+
func Run(inputArgs []string, ver string) error {
version = ver
args, err := procArgs(inputArgs)
@@ -67,8 +100,26 @@ func Run(inputArgs []string, ver string) error {
log.Printf("Running version: %s", version)
- // Check if openocd is installed
- if args.ELF != "" {
+ if args.ELF != "" && args.EraseFlash {
+ return fmt.Errorf("must specify either --elf or --erase-flash, not both")
+ }
+
+ // Check if openocd is installed and ELF file exists
+ if args.ELF != "" || args.EraseFlash {
+ if args.ELF != "" {
+ if _, err := os.Stat(args.ELF); err != nil {
+ return fmt.Errorf("elf file not found: %v", err)
+ }
+ }
+ if args.EraseFlash {
+ if _, err := os.Stat(eraseFlashFirmware); err != nil {
+ return fmt.Errorf("erase flash firmware not found: %v", err)
+ }
+ if err := verifyFileHash(eraseFlashFirmware, eraseFlashHash); err != nil {
+ return fmt.Errorf("erase flash firmware hash check failed: %v", err)
+ }
+ log.Println("Erase flash firmware hash verified.")
+ }
cmd := exec.Command("openocd", "--version")
if err := cmd.Run(); err != nil {
log.Println(openOCDNotFoundMessage)
@@ -110,17 +161,21 @@ func Run(inputArgs []string, ver string) error {
return err
}
- log.Println("RP2400 read for programming.")
+ log.Println("RP2040 is ready for programming.")
success := true
- if args.ELF == "" {
+ elfToFlash := args.ELF
+ if args.EraseFlash {
+ elfToFlash = eraseFlashFirmware
+ }
+ if elfToFlash == "" {
log.Println("No elf program provided so assuming programming is done manually.")
log.Println("Press enter when programming is done.")
_, _ = bufio.NewReader(os.Stdin).ReadString('\n')
} else {
- log.Printf("Programming '%s' using 'openocd' file to RP2040\n", args.ELF)
+ log.Printf("Programming '%s' using 'openocd' to RP2040\n", elfToFlash)
cmd := exec.Command("openocd", "-f", "/etc/cacophony/raspberrypi-swd.cfg", "-f", "/target/rp2040.cfg", "-c",
- fmt.Sprintf("program %s verify reset exit", args.ELF))
+ fmt.Sprintf("program %s verify reset exit", elfToFlash))
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
@@ -129,6 +184,13 @@ func Run(inputArgs []string, ver string) error {
}
}
+ if args.EraseFlash {
+ // make file to trigger a reprogram next time tc2-agent starts
+ if err := os.WriteFile("/etc/cacophony/program_rp2040", []byte{}, 0644); err != nil {
+ return err
+ }
+ }
+
log.Println("Releasing Run and Boot mode pins.")
if err := runPin.In(gpio.Float, gpio.NoEdge); err != nil {
return err
@@ -137,10 +199,14 @@ func Run(inputArgs []string, ver string) error {
return err
}
+ details := map[string]any{"success": success}
+ if args.EraseFlash {
+ details["eraseFlash"] = true
+ }
eventclient.AddEvent(eventclient.Event{
Timestamp: time.Now(),
Type: "programmingRP2040",
- Details: map[string]interface{}{"success": success},
+ Details: details,
})
if !success {
return errors.New("failed to program RP2040")
diff --git a/internal/tc2-hat-trap-cli/main.go b/internal/tc2-hat-trap-cli/main.go
new file mode 100644
index 0000000..7bddc2f
--- /dev/null
+++ b/internal/tc2-hat-trap-cli/main.go
@@ -0,0 +1,167 @@
+package trapcli
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+ "strings"
+ "time"
+
+ comms "github.com/TheCacophonyProject/tc2-hat-controller/internal/tc2-hat-comms"
+
+ goconfig "github.com/TheCacophonyProject/go-config"
+ "github.com/TheCacophonyProject/go-utils/logging"
+ "github.com/TheCacophonyProject/tc2-hat-controller/serialhelper"
+ "github.com/alexflint/go-arg"
+ "periph.io/x/conn/v3/gpio"
+)
+
+var (
+ version = ""
+ log = logging.NewLogger("info")
+)
+
+type Args struct {
+ Command *Command `arg:"subcommand:command" help:"Send a command."`
+ Read *Read `arg:"subcommand:read" help:"Read from a variable."`
+ Write *Write `arg:"subcommand:write" help:"Write to a variable."`
+ Listen *Listen `arg:"subcommand:listen" help:"Continuously listen for messages from the RP2040."`
+ Message *CMDMessage `arg:"subcommand:msg" help:"Send a message to the RP2040."`
+ BaudRate int `arg:"--baud-rate" help:"Baud rate for UART communication."`
+ goconfig.ConfigArgs
+ logging.LogArgs
+}
+
+type CMDMessage struct {
+ ID int `arg:"--id,required" help:"The ID of the message to send."`
+ Type string `arg:"--type,required" help:"The type of message to send."`
+ Payload string `arg:"--payload,required" help:"The payload of the message to send."`
+}
+
+type Command struct {
+ Command string `arg:"--command,required" help:"The command to run."`
+}
+
+type Read struct {
+ Variable string `arg:"--variable,required" help:"The variable to read from."`
+}
+
+type Write struct {
+ Variable string `arg:"--variable,required" help:"The variable to write to."`
+ Value string `arg:"--value,required" help:"The value to write."`
+}
+
+type Listen struct{}
+
+var defaultArgs = Args{
+ BaudRate: 9600,
+}
+
+func sendMessage(msg comms.Message, port *serialhelper.SerialPort) (*comms.Message, error) {
+ line := msg.ToUARTLine()
+ log.Println("Sending:", strings.TrimSpace(line))
+
+ if err := port.Write([]byte(line)); err != nil {
+ return nil, err
+ }
+
+ select {
+ case line, ok := <-port.Lines:
+ if !ok {
+ return nil, fmt.Errorf("serial port closed while waiting for response")
+ }
+ return comms.ParseLine(line)
+ case <-time.After(5 * time.Second):
+ return nil, fmt.Errorf("timeout waiting for response")
+ }
+}
+
+func procArgs(input []string) (Args, error) {
+ args := defaultArgs
+
+ parser, err := arg.NewParser(arg.Config{}, &args)
+ if err != nil {
+ return Args{}, err
+ }
+ err = parser.Parse(input)
+ if errors.Is(err, arg.ErrHelp) {
+ parser.WriteHelp(os.Stdout)
+ os.Exit(0)
+ }
+ if errors.Is(err, arg.ErrVersion) {
+ fmt.Println(version)
+ os.Exit(0)
+ }
+ return args, err
+}
+
+func Run(inputArgs []string, ver string) error {
+ version = ver
+ args, err := procArgs(inputArgs)
+ if err != nil {
+ return fmt.Errorf("failed to parse args: %v", err)
+ }
+ log = logging.NewLogger(args.LogLevel)
+ log.Printf("Running version: %s", version)
+
+ port, err := serialhelper.OpenSerial(gpio.High, gpio.Low, args.BaudRate)
+ if err != nil {
+ return fmt.Errorf("failed to open serial port: %v", err)
+ }
+ defer port.Close()
+
+ switch {
+ case args.Listen != nil:
+ fmt.Println("Listening for messages from RP2040 (Ctrl+C to stop)...")
+ for line := range port.Lines {
+ msg, err := comms.ParseLine(line)
+ if err != nil {
+ fmt.Printf("raw: %s\n", line)
+ log.Warnf("Failed to parse incoming message %q: %v", line, err)
+ continue
+ }
+ log.Println("Received:", msg)
+ }
+ return nil
+
+ case args.Command != nil:
+ data, err := json.Marshal(map[string]string{"command": args.Command.Command})
+ if err != nil {
+ return err
+ }
+ return respond(sendMessage(comms.Message{Type: "command", Payload: string(data)}, port))
+
+ case args.Read != nil:
+ data, err := json.Marshal(map[string]string{"var": args.Read.Variable})
+ if err != nil {
+ return err
+ }
+ return respond(sendMessage(comms.Message{Type: "read", Payload: string(data)}, port))
+
+ case args.Write != nil:
+ data, err := json.Marshal(map[string]string{"var": args.Write.Variable, "val": args.Write.Value})
+ if err != nil {
+ return err
+ }
+ return respond(sendMessage(comms.Message{Type: "write", Payload: string(data)}, port))
+
+ case args.Message != nil:
+ message := comms.Message{ID: args.Message.ID, Type: args.Message.Type, Payload: args.Message.Payload}
+ return respond(sendMessage(message, port))
+
+ default:
+ return fmt.Errorf("no subcommand given")
+ }
+}
+
+func respond(response *comms.Message, err error) error {
+ if err != nil {
+ return err
+ }
+ if response.Type == "NACK" {
+ return fmt.Errorf("NACK response: %s", response.Payload)
+ }
+ fmt.Printf("type=%s payload=%s\n", response.Type, response.Payload)
+ return nil
+}
diff --git a/serialhelper/serialhelper.go b/serialhelper/serialhelper.go
index 176493a..85033d1 100644
--- a/serialhelper/serialhelper.go
+++ b/serialhelper/serialhelper.go
@@ -6,6 +6,7 @@ import (
"os"
"os/exec"
"strings"
+ "sync"
"syscall"
"time"
@@ -158,9 +159,7 @@ func SerialSendReceive(retries int, mul0, mul1 gpio.Level, wait time.Duration, d
if err != nil {
return nil, err
}
-
defer ReleaseSerial(serialFile)
-
c := &serial.Config{Name: "/dev/serial0", Baud: baud, ReadTimeout: time.Second * 5}
serialPort, err := serial.OpenPort(c)
if err != nil {
@@ -168,23 +167,46 @@ func SerialSendReceive(retries int, mul0, mul1 gpio.Level, wait time.Duration, d
}
defer serialPort.Close()
+ start := time.Now()
+ // add a newline at and of data if it is not there already
+ if data[len(data)-1] != '\n' {
+ data = append(data, '\n')
+ }
n, err := serialPort.Write(data)
if err != nil {
return nil, err
}
+
if n != len(data) {
return nil, fmt.Errorf("wrote %d bytes, expected %d", n, len(data))
}
- time.Sleep(time.Second)
- buf := make([]byte, 256)
- n, err = serialPort.Read(buf)
- log.Infof("Received %d bytes", n)
- log.Info("Received:", buf[:n])
- if err != nil {
- return nil, err
- }
- return buf[:n], nil
+ var response []byte
+ var responseTime time.Time
+ firstBits := true
+ buf := make([]byte, 1)
+ for {
+ n, err = serialPort.Read(buf)
+ if err != nil {
+ return nil, err
+ }
+ if n == 0 {
+ continue
+ }
+ if firstBits {
+ responseTime = time.Now()
+ firstBits = false
+ }
+ if buf[0] == '\n' {
+ break
+ }
+ response = append(response, buf[0])
+ }
+ log.Infof("Sent message at %s", start.Format("15:04:05.999"))
+ log.Infof("Received message at %s", responseTime.Format("15:04:05.999"))
+ log.Debugf("Received %d bytes", len(response))
+ log.Debugf("Response time: %s", responseTime)
+ return response, nil
}
func SerialSend(retries int, mul0, mul1 gpio.Level, wait time.Duration, data []byte, baud int) error {
@@ -222,3 +244,93 @@ func SerialSend(retries int, mul0, mul1 gpio.Level, wait time.Duration, data []b
return nil
}
+
+// SerialPort represents a persistent, open serial connection with a background line reader.
+type SerialPort struct {
+ writeMu sync.Mutex
+ port *serial.Port
+ file *os.File
+ Lines chan []byte
+ done chan struct{}
+}
+
+// OpenSerial opens the serial port persistently and starts a background line reader goroutine.
+func OpenSerial(mul0, mul1 gpio.Level, baud int) (*SerialPort, error) {
+ file, err := GetSerial(3, mul0, mul1, time.Second)
+ if err != nil {
+ return nil, err
+ }
+ c := &serial.Config{Name: "/dev/serial0", Baud: baud, ReadTimeout: 100 * time.Millisecond}
+ port, err := serial.OpenPort(c)
+ if err != nil {
+ if rerr := ReleaseSerial(file); rerr != nil {
+ log.Printf("Failed to release serial: %v", rerr)
+ }
+ return nil, err
+ }
+ sp := &SerialPort{
+ port: port,
+ file: file,
+ Lines: make(chan []byte, 16),
+ done: make(chan struct{}),
+ }
+ go sp.readLoop()
+ return sp, nil
+}
+
+// readLoop continuously reads lines from the serial port and sends them to Lines.
+// It exits when Close is called.
+func (s *SerialPort) readLoop() {
+ defer close(s.Lines)
+ buf := make([]byte, 1)
+ var line []byte
+ for {
+ select {
+ case <-s.done:
+ return
+ default:
+ }
+ n, err := s.port.Read(buf)
+ if err != nil || n == 0 {
+ continue
+ }
+ if buf[0] == '\n' {
+ if len(line) > 0 {
+ msg := make([]byte, len(line))
+ copy(msg, line)
+ select {
+ case s.Lines <- msg:
+ case <-s.done:
+ return
+ }
+ line = line[:0]
+ }
+ } else {
+ line = append(line, buf[0])
+ }
+ }
+}
+
+// Write sends data over the serial port. Appends a newline if not already present.
+func (s *SerialPort) Write(data []byte) error {
+ s.writeMu.Lock()
+ defer s.writeMu.Unlock()
+ if len(data) == 0 || data[len(data)-1] != '\n' {
+ data = append(data, '\n')
+ }
+ n, err := s.port.Write(data)
+ if err != nil {
+ return err
+ }
+ if n != len(data) {
+ return fmt.Errorf("wrote %d bytes, expected %d", n, len(data))
+ }
+ return nil
+}
+
+// Close stops the background reader and releases the serial port.
+func (s *SerialPort) Close() error {
+ close(s.done)
+ s.port.Close()
+ return ReleaseSerial(s.file)
+}