diff --git a/cmd/ax/harness.go b/cmd/ax/harness.go index 51c53eb..71275dc 100644 --- a/cmd/ax/harness.go +++ b/cmd/ax/harness.go @@ -82,12 +82,11 @@ func NewHarnessServiceServer() *HarnessServiceServer { } // Connect implements the bidirectional gRPC streaming capability. -// It receives client inputs and responds only with "Hello world". +// It receives client inputs and responds with "hello world" unless the input message text is "go_away". +// TODO(params): Update the implementation to be a proper one. func (s *HarnessServiceServer) Connect(stream proto.HarnessService_ConnectServer) error { - // TODO: Connect will be implemented to serve the built in harnesses - // as an isolated actor. for { - _, err := stream.Recv() + req, err := stream.Recv() if err == io.EOF { return nil } @@ -95,6 +94,22 @@ func (s *HarnessServiceServer) Connect(stream proto.HarnessService_ConnectServer return err } + shouldGoAway := false + for _, m := range req.Messages { + if textBlock, ok := m.Content.Type.(*proto.Content_Text); ok { + // TODO(params): Replace this with a proper protocol for go away. + if textBlock.Text.Text == "go_away" { + shouldGoAway = true + break + } + } + } + + if shouldGoAway { + log.Println("Received 'go_away' message, closing stream...") + return nil + } + err = stream.Send(&proto.HarnessMessage{ Messages: []*proto.Message{ { @@ -102,7 +117,7 @@ func (s *HarnessServiceServer) Connect(stream proto.HarnessService_ConnectServer Content: &proto.Content{ Type: &proto.Content_Text{ Text: &proto.TextContent{ - Text: "Hello world", + Text: "hello world", }, }, }, diff --git a/cmd/ax/harnessclient.go b/cmd/ax/harnessclient.go new file mode 100644 index 0000000..e02c361 --- /dev/null +++ b/cmd/ax/harnessclient.go @@ -0,0 +1,128 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package main implements a simple client for the fake HarnessService. +package main + +import ( + "bufio" + "context" + "fmt" + "io" + "log" + "os" + + "github.com/google/ax/proto" + "github.com/spf13/cobra" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +var ( + harnessServerAddr string +) + +var harnessClientCmd = &cobra.Command{ + Use: "harnessclient", + Short: "Run the harness client to connect to the server", + Hidden: true, + RunE: runHarnessClient, +} + +func init() { + harnessClientCmd.Flags().StringVar(&harnessServerAddr, "server", "localhost:50053", "The server address for the gRPC HarnessService.") + rootCmd.AddCommand(harnessClientCmd) +} + +func runHarnessClient(cmd *cobra.Command, args []string) error { + ctx := context.Background() + + log.Printf("Connecting to HarnessService at %s...", harnessServerAddr) + conn, err := grpc.NewClient(harnessServerAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return fmt.Errorf("failed to connect to server: %v", err) + } + defer conn.Close() + + client := proto.NewHarnessServiceClient(conn) + + stream, err := client.Connect(ctx) + if err != nil { + return fmt.Errorf("Failed to open connection stream: %v", err) + } + + scanner := bufio.NewScanner(os.Stdin) + fmt.Println("Interactive client started. Type your messages below.") + fmt.Println("Type 'go_away' to close the stream and exit.") + for { + fmt.Print("\nClient > ") + if !scanner.Scan() { + break + } + text := scanner.Text() + if text == "" { + continue + } + + msg := &proto.HarnessMessage{ + Messages: []*proto.Message{ + { + Role: "user", + Content: &proto.Content{ + Type: &proto.Content_Text{ + Text: &proto.TextContent{ + Text: text, + }, + }, + }, + }, + }, + } + + if err := stream.Send(msg); err != nil { + return fmt.Errorf("Failed to send message: %v", err) + } + // TODO(params): Replace this with a proper protocol for go away. + if text == "go_away" { + log.Println("Sending 'go_away' to close the stream...") + break + } + + resp, err := stream.Recv() + if err != nil { + return fmt.Errorf("Failed to receive response: %v", err) + } + + for i, m := range resp.Messages { + var textContent string + if textBlock, ok := m.Content.Type.(*proto.Content_Text); ok { + textContent = textBlock.Text.Text + } + fmt.Printf("Server > message[%d] (%s): %s\n", i, m.Role, textContent) + } + } + + // Close send side to signal request completion + if err := stream.CloseSend(); err != nil { + return fmt.Errorf("Failed to close send side of stream: %v", err) + } + + log.Println("Waiting for final stream EOF...") + _, err = stream.Recv() + if err != io.EOF { + return fmt.Errorf("Expected EOF from server, got: %v", err) + } + log.Println("Stream closed successfully by server.") + return nil +}