diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go
index 1b29b3785..a71484f2f 100644
--- a/flow/activities/flowable.go
+++ b/flow/activities/flowable.go
@@ -2141,3 +2141,72 @@ func (a *FlowableActivity) MigratePostgresTableOIDs(
 
 	return nil
 }
+
+func (a *FlowableActivity) PeerDBPGAutomatedSchemaDump(ctx context.Context, env map[string]string) (bool, error) {
+	return internal.PeerDBPGAutomatedSchemaDump(ctx, env)
+}
+
+func (a *FlowableActivity) RunPgDumpSchema(
+	ctx context.Context,
+	input *protos.RunPgDumpSchemaInput,
+) (bool, error) {
+	ctx = context.WithValue(ctx, shared.FlowNameKey, input.FlowName)
+	logger := internal.LoggerFromCtx(ctx)
+
+	srcPeer, err := connectors.LoadPeer(ctx, a.CatalogPool, input.SourceName)
+	if err != nil {
+		return false, a.Alerter.LogFlowError(ctx, input.FlowName, fmt.Errorf("failed to load source peer: %w", err))
+	}
+
+	dstPeer, err := connectors.LoadPeer(ctx, a.CatalogPool, input.DestinationName)
+	if err != nil {
+		return false, a.Alerter.LogFlowError(ctx, input.FlowName, fmt.Errorf("failed to load destination peer: %w", err))
+	}
+
+	srcPgConfig, ok := srcPeer.Config.(*protos.Peer_PostgresConfig)
+	if !ok {
+		return false, a.Alerter.LogFlowError(ctx, input.FlowName, fmt.Errorf("source peer %s is not a PostgreSQL peer", input.SourceName))
+	}
+
+	dstPgConfig, ok := dstPeer.Config.(*protos.Peer_PostgresConfig)
+	if !ok {
+		return false, a.Alerter.LogFlowError(ctx, input.FlowName,
+			fmt.Errorf("destination peer %s is not a PostgreSQL peer", input.DestinationName))
+	}
+
+	// skip schema migration for peers using SSH tunnels
+	if srcPgConfig.PostgresConfig.SshConfig != nil {
+		logger.Info("skipping pg_dump schema migration: source peer uses SSH tunnel")
+		return false, nil
+	}
+	if dstPgConfig.PostgresConfig.SshConfig != nil {
+		logger.Info("skipping pg_dump schema migration: destination peer uses SSH tunnel")
+		return false, nil
+	}
+
+	// skip schema migration for non-password auth (e.g. IAM)
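+	// (credentials are handed to pg_dump/psql via PGPASSWORD, which only
+	// carries a static password, so other auth types cannot be supported here)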
+	if srcPgConfig.PostgresConfig.AuthType != protos.PostgresAuthType_POSTGRES_PASSWORD {
+		logger.Info("skipping pg_dump schema migration: source peer uses non-password auth")
+		return false, nil
+	}
+	if dstPgConfig.PostgresConfig.AuthType != protos.PostgresAuthType_POSTGRES_PASSWORD {
+		logger.Info("skipping pg_dump schema migration: destination peer uses non-password auth")
+		return false, nil
+	}
+
+	logger.Info("running pg_dump schema migration from source to destination",
+		slog.String("source", input.SourceName), slog.String("destination", input.DestinationName))
+	a.Alerter.LogFlowInfo(ctx, input.FlowName,
+		fmt.Sprintf("starting pg_dump schema migration from %s to %s", input.SourceName, input.DestinationName))
+
+	start := time.Now()
+	if err := connpostgres.RunPgDumpSchema(ctx, srcPgConfig.PostgresConfig, dstPgConfig.PostgresConfig); err != nil {
+		return false, a.Alerter.LogFlowError(ctx, input.FlowName, fmt.Errorf("pg_dump schema migration failed: %w", err))
+	}
+
+	elapsed := time.Since(start).Round(time.Millisecond)
+	logger.Info("pg_dump schema migration completed successfully", slog.Duration("elapsed", elapsed))
+	a.Alerter.LogFlowInfo(ctx, input.FlowName,
+		fmt.Sprintf("pg_dump schema migration completed successfully in %s", elapsed))
+	return true, nil
+}
diff --git a/flow/connectors/postgres/pgdump_schema.go b/flow/connectors/postgres/pgdump_schema.go
new file mode 100644
index 000000000..7c77560e2
--- /dev/null
+++ b/flow/connectors/postgres/pgdump_schema.go
@@ -0,0 +1,302 @@
+package connpostgres
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"log/slog"
+	"os"
+	"os/exec"
+	"regexp"
+	"strconv"
+
+	"github.com/PeerDB-io/peerdb/flow/generated/protos"
+)
+
+// pg_dump from newer Postgres versions emits statements that older
+// destinations don't recognize:
+//   - SET transaction_timeout = 0; (PG17+ session GUC)
+//   - \restrict / \unrestrict (pg_dump 17.6+ psql meta-commands
+//     that gate replay against an unrelated psql session; older psql treats
+//     them as unknown backslash commands and aborts under ON_ERROR_STOP)
+//
+// These are session/replay housekeeping and safe to drop on the wire, so we
+// keep ON_ERROR_STOP=1 for genuine DDL failures while remaining cross-version.
+var incompatibleLineRE = regexp.MustCompile(`^(SET\s+transaction_timeout\s*=|\\(?:un)?restrict(\s|$))`)
+
+// RunPgDumpSchema streams a schema-only pg_dump from source directly into psql
+// on the destination, piping stdout into stdin without intermediate files.
+func RunPgDumpSchema(ctx context.Context, srcConfig *protos.PostgresConfig, dstConfig *protos.PostgresConfig) error {
+	return pipeCommand(ctx, srcConfig, dstConfig, "pg_dump", buildPgDumpArgs(srcConfig))
+}
+
+// pipeCommand runs srcBinary with the given args, piping its stdout into psql on the destination.
+func pipeCommand(
+	ctx context.Context,
+	srcConfig *protos.PostgresConfig,
+	dstConfig *protos.PostgresConfig,
+	srcBinary string,
+	srcArgs []string,
+) error {
+	psqlArgs := buildPsqlArgs(dstConfig)
+
+	srcCmd := exec.CommandContext(ctx, srcBinary, srcArgs...)
+	psqlCmd := exec.CommandContext(ctx, "psql", psqlArgs...)
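+	// exec.CommandContext kills either child (via os.Process.Kill) once ctx is
+	// done, so activity timeout/cancellation tears down the whole pipeline
+	// rather than leaving orphaned pg_dump/psql processes.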
+
+	// set PGPASSWORD for each command via separate env slices
+	srcCmd.Env = append(os.Environ(), "PGPASSWORD="+srcConfig.Password)
+	psqlCmd.Env = append(os.Environ(), "PGPASSWORD="+dstConfig.Password)
+
+	// handle TLS env vars
+	appendTLSEnv(ctx, srcCmd, srcConfig)
+	appendTLSEnv(ctx, psqlCmd, dstConfig)
+
+	return runPipeline(ctx, srcCmd, psqlCmd, srcBinary, "psql", filterIncompatibleLines)
+}
+
+// filterIncompatibleLines copies r->w line by line, dropping statements that
+// are valid in newer pg_dump output but rejected by older psql/destinations.
+func filterIncompatibleLines(ctx context.Context, r io.Reader, w io.Writer) error {
+	br := bufio.NewReaderSize(r, 64*1024)
+	for {
+		line, err := br.ReadBytes('\n')
+		if len(line) > 0 {
+			if !incompatibleLineRE.Match(line) {
+				if _, werr := w.Write(line); werr != nil {
+					return werr
+				}
+			} else {
+				slog.DebugContext(ctx, "dropping incompatible line from pg_dump stream",
+					slog.String("line", string(bytes.TrimRight(line, "\n"))))
+			}
+		}
+		if err != nil {
+			if err == io.EOF {
+				return nil
+			}
+			return err
+		}
+	}
+}
+
+// runPipeline wires srcCmd's stdout into dstCmd's stdin (optionally through a
+// filter goroutine) and waits for both processes.
+//
+// Pipe topology:
+//
+//	without filter: src.stdout -> srcW |--pipe--| srcR -> dst.stdin
+//	with filter:    src.stdout -> srcW |--pipe--| srcR -> filter -> dstW |--pipe--| dstR -> dst.stdin
+//
+// File descriptor ownership matters here -- if the parent keeps a write end
+// open after the child consumer dies, the producer can hang forever on a
+// blocked write. We close each fd as soon as the child or filter goroutine
+// owns it.
+func runPipeline(
+	ctx context.Context,
+	srcCmd, dstCmd *exec.Cmd,
+	srcName, dstName string,
+	filter func(context.Context, io.Reader, io.Writer) error,
+) error {
+	srcR, srcW, err := os.Pipe()
+	if err != nil {
+		return fmt.Errorf("create src pipe: %w", err)
+	}
+	srcCmd.Stdout = srcW
+
+	var (
+		dstR, dstW *os.File
+		filterDone chan error
+	)
+	if filter == nil {
+		dstCmd.Stdin = srcR
+	} else {
+		dstR, dstW, err = os.Pipe()
+		if err != nil {
+			srcR.Close()
+			srcW.Close()
+			return fmt.Errorf("create dst pipe: %w", err)
+		}
+		dstCmd.Stdin = dstR
+		filterDone = make(chan error, 1)
+	}
+
+	var srcStderr, dstStderr bytes.Buffer
+	srcCmd.Stderr = &srcStderr
+	dstCmd.Stderr = &dstStderr
+
+	// Start dst first so it's ready to read.
+	if err := dstCmd.Start(); err != nil {
+		srcR.Close()
+		srcW.Close()
+		if dstW != nil {
+			dstR.Close()
+			dstW.Close()
+		}
+		return fmt.Errorf("start %s: %w", dstName, err)
+	}
+	// dst owns its stdin fd in its child; close our copy.
+	if filter == nil {
+		srcR.Close()
+	} else {
+		dstR.Close()
+	}
+
+	if err := srcCmd.Start(); err != nil {
+		srcW.Close()
+		if dstW != nil {
+			// filter never started; close the write end so dst sees EOF,
+			// and release the read end we still hold.
+			dstW.Close()
+			srcR.Close()
+		}
+		_ = dstCmd.Process.Kill()
+		_ = dstCmd.Wait()
+		return fmt.Errorf("start %s: %w", srcName, err)
+	}
+	// src owns its stdout fd in its child; close our copy.
+	srcW.Close()
+
+	// Run the filter goroutine if configured. It bridges srcR -> dstW.
+	if filter != nil {
+		go func() {
+			err := filter(ctx, srcR, dstW)
+			// Always close both ends so the producer/consumer unblock.
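+			// (if psql exits first, writes to dstW fail with EPIPE and the
+			// error surfaces via filterDone; closing srcR likewise unblocks
+			// pg_dump instead of letting it stall on a full pipe)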
+			srcR.Close()
+			dstW.Close()
+			filterDone <- err
+		}()
+	}
+
+	srcDone := make(chan error, 1)
+	dstDone := make(chan error, 1)
+	go func() { srcDone <- srcCmd.Wait() }()
+	go func() { dstDone <- dstCmd.Wait() }()
+
+	// Track which side has exited via our own channels; reading
+	// Cmd.ProcessState here would race with the Wait goroutines above.
+	var (
+		srcErr, dstErr       error
+		srcExited, dstExited bool
+		srcKilled, dstKilled bool
+	)
+	for range 2 {
+		select {
+		case err := <-srcDone:
+			srcErr, srcExited = err, true
+			if err != nil && !dstExited {
+				_ = dstCmd.Process.Kill()
+				dstKilled = true
+			}
+		case err := <-dstDone:
+			dstErr, dstExited = err, true
+			if !srcExited {
+				// dst exited (success or failure) while src is still running;
+				// kill src so it doesn't block on a pipe with no reader.
+				_ = srcCmd.Process.Kill()
+				srcKilled = true
+			}
+		}
+	}
+
+	// Wait for the filter to finish so we surface any I/O error and so the
+	// goroutine doesn't outlive this function.
+	var filterErr error
+	if filterDone != nil {
+		filterErr = <-filterDone
+	}
+
+	// Report the original cause, not the side we killed in response.
+	if dstErr != nil && !dstKilled {
+		return fmt.Errorf("%s failed: %w\nstderr:\n%s", dstName, dstErr, dstStderr.String())
+	}
+	if srcErr != nil && !srcKilled {
+		return fmt.Errorf("%s failed: %w\nstderr:\n%s", srcName, srcErr, srcStderr.String())
+	}
+	if filterErr != nil {
+		return fmt.Errorf("filter failed: %w", filterErr)
+	}
+	// Fallback: both sides killed (e.g. ctx cancel) -- surface whichever error we have.
+	if srcErr != nil {
+		return fmt.Errorf("%s failed: %w\nstderr:\n%s", srcName, srcErr, srcStderr.String())
+	}
+	if dstErr != nil {
+		return fmt.Errorf("%s failed: %w\nstderr:\n%s", dstName, dstErr, dstStderr.String())
+	}
+	return nil
+}
+
+func buildPgDumpArgs(config *protos.PostgresConfig) []string {
+	port := config.Port
+	if port == 0 {
+		port = 5432
+	}
+
+	args := []string{
+		"--schema-only",
+		"--no-owner",
+		"--no-privileges",
+		"-h", config.Host,
+		"-p", strconv.FormatUint(uint64(port), 10),
+		"-d", config.Database,
+	}
+	if config.User != "" {
+		args = append(args, "-U", config.User)
+	}
+	return args
+}
+
+func buildPsqlArgs(config *protos.PostgresConfig) []string {
+	port := config.Port
+	if port == 0 {
+		port = 5432
+	}
+
+	args := []string{
+		"-h", config.Host,
+		"-p", strconv.FormatUint(uint64(port), 10),
+		"-d", config.Database,
+		// Wrap the entire dump in a single transaction so partial failures
+		// roll back cleanly (makes the activity safely retryable) and avoid
+		// per-statement autocommit overhead on high-latency links.
+		"--single-transaction",
+		// Without this, psql logs errors to stderr but exits 0, so a half-
+		// applied schema would be reported as success. ON_ERROR_STOP=1 makes
+		// psql exit non-zero on the first failed statement.
+		"-v", "ON_ERROR_STOP=1",
+		// Quiet informational chatter; errors still go to stderr.
+ "--quiet", + } + if config.User != "" { + args = append(args, "-U", config.User) + } + return args +} + +func appendTLSEnv(ctx context.Context, cmd *exec.Cmd, config *protos.PostgresConfig) { + if config.RequireTls { + cmd.Env = append(cmd.Env, "PGSSLMODE=require") + + if config.RootCa != nil && *config.RootCa != "" { + // write root CA to a temp file + tmpFile, err := os.CreateTemp("", "peerdb-root-ca-*.pem") + if err != nil { + slog.WarnContext(ctx, "failed to create temp file for root CA, skipping sslrootcert", slog.Any("error", err)) + return + } + if _, err := tmpFile.WriteString(*config.RootCa); err != nil { + slog.WarnContext(ctx, "failed to write root CA to temp file", slog.Any("error", err)) + tmpFile.Close() + os.Remove(tmpFile.Name()) + return + } + tmpFile.Close() + cmd.Env = append(cmd.Env, "PGSSLROOTCERT="+tmpFile.Name()) + // note: temp file is cleaned up when the process exits + } + } +} diff --git a/flow/connectors/postgres/pgdump_schema_test.go b/flow/connectors/postgres/pgdump_schema_test.go new file mode 100644 index 000000000..86d4bd5af --- /dev/null +++ b/flow/connectors/postgres/pgdump_schema_test.go @@ -0,0 +1,270 @@ +package connpostgres + +import ( + "bytes" + "context" + "errors" + "os/exec" + "runtime" + "strings" + "testing" + "time" +) + +// requireUnix skips the test on platforms without the shell utilities used here. +func requireUnix(t *testing.T) { + t.Helper() + if runtime.GOOS == "windows" { + t.Skip("requires unix shell utilities") + } +} + +func TestRunPipeline_HappyPath(t *testing.T) { + requireUnix(t) + ctx := t.Context() + + src := exec.CommandContext(ctx, "sh", "-c", "printf 'hello world'") + var dstOut bytes.Buffer + dst := exec.CommandContext(ctx, "cat") + dst.Stdout = &dstOut + + if err := runPipeline(ctx, src, dst, "src", "dst", nil); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := dstOut.String(); got != "hello world" { + t.Fatalf("dst stdout = %q, want %q", got, "hello world") + } +} + +func TestRunPipeline_SrcStartFails(t *testing.T) { + ctx := t.Context() + + src := exec.CommandContext(ctx, "/nonexistent/peerdb-test-binary") + dst := exec.CommandContext(ctx, "cat") + + err := runPipeline(ctx, src, dst, "src", "dst", nil) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "start src") { + t.Fatalf("error %q does not mention src start failure", err) + } + // dst should have been killed and reaped; ProcessState should be set. + if dst.ProcessState == nil { + t.Fatal("dst was not reaped after src start failure") + } +} + +func TestRunPipeline_DstStartFails(t *testing.T) { + ctx := t.Context() + + src := exec.CommandContext(ctx, "echo", "hi") + dst := exec.CommandContext(ctx, "/nonexistent/peerdb-test-binary") + + err := runPipeline(ctx, src, dst, "src", "dst", nil) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "start dst") { + t.Fatalf("error %q does not mention dst start failure", err) + } + // src must not have been started. 
+	if src.ProcessState != nil {
+		t.Fatal("src should not have been started when dst failed to start")
+	}
+}
+
+func TestRunPipeline_SrcExitsNonZero(t *testing.T) {
+	requireUnix(t)
+	ctx := t.Context()
+
+	// write some output then exit with error
+	src := exec.CommandContext(ctx, "sh", "-c", "echo partial; exit 7")
+	dst := exec.CommandContext(ctx, "cat")
+	dst.Stdout = &bytes.Buffer{}
+
+	err := runPipeline(ctx, src, dst, "src", "dst", nil)
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	if !strings.Contains(err.Error(), "src failed") {
+		t.Fatalf("error %q does not mention src failure", err)
+	}
+}
+
+func TestRunPipeline_DstExitsNonZero(t *testing.T) {
+	requireUnix(t)
+	ctx := t.Context()
+
+	src := exec.CommandContext(ctx, "sh", "-c", "echo hi")
+	// exit 3 immediately, ignoring stdin
+	dst := exec.CommandContext(ctx, "sh", "-c", "exit 3")
+
+	err := runPipeline(ctx, src, dst, "src", "dst", nil)
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+	// src succeeded, so the error must come from dst
+	if !strings.Contains(err.Error(), "dst failed") {
+		t.Fatalf("error %q does not mention dst failure", err)
+	}
+}
+
+// TestRunPipeline_SrcFailsWhileDstSlow verifies the deadlock-prevention fix:
+// if src exits non-zero while dst is still reading slowly, dst is killed so
+// runPipeline returns promptly instead of waiting for dst to finish its work.
+func TestRunPipeline_SrcFailsWhileDstSlow(t *testing.T) {
+	requireUnix(t)
+	ctx := t.Context()
+
+	// src writes a small amount (fits in pipe buffer, no blocking) then exits non-zero.
+	src := exec.CommandContext(ctx, "sh", "-c", "echo hi; exit 9")
+	// dst is a single process (no shell-spawned children) that doesn't read stdin
+	// and won't exit on its own. We expect runPipeline to kill it after src fails.
+	// Note: we deliberately avoid `sh -c "sleep 30; cat"` here -- when sh forks a
+	// child, that child inherits sh's stderr fd, and Go's exec.Wait blocks
+	// draining stderr until the inherited fd is closed (i.e. for the full sleep).
+	// psql doesn't fork children, so this matches real behavior.
+	dst := exec.CommandContext(ctx, "sleep", "30")
+
+	start := time.Now()
+	done := make(chan error, 1)
+	go func() { done <- runPipeline(ctx, src, dst, "src", "dst", nil) }()
+
+	select {
+	case err := <-done:
+		if err == nil {
+			t.Fatal("expected error from src failure")
+		}
+		if !strings.Contains(err.Error(), "src failed") {
+			t.Fatalf("expected src failure, got %v", err)
+		}
+		if elapsed := time.Since(start); elapsed > 5*time.Second {
+			t.Fatalf("runPipeline took %v -- dst was not killed promptly after src failure", elapsed)
+		}
+	case <-time.After(10 * time.Second):
+		t.Fatal("runPipeline hung -- dst was not killed after src failure")
+	}
+}
+
+// TestRunPipeline_DstExitsWhileSrcWriting verifies the inverse: if dst exits
+// early while src is producing lots of data, src is killed so it doesn't hang
+// forever blocked on a write to a closed pipe (would normally get SIGPIPE,
+// but we explicitly kill to be safe / to surface the failure quickly).
+func TestRunPipeline_DstExitsWhileSrcWriting(t *testing.T) {
+	requireUnix(t)
+	ctx := t.Context()
+
+	// src tries to stream a lot of data
+	src := exec.CommandContext(ctx, "sh", "-c", "yes peerdb | head -c 10000000")
+	// dst exits immediately without reading
+	dst := exec.CommandContext(ctx, "sh", "-c", "exit 2")
+
+	start := time.Now()
+	done := make(chan error, 1)
+	go func() { done <- runPipeline(ctx, src, dst, "src", "dst", nil) }()
+
+	select {
+	case err := <-done:
+		if err == nil {
+			t.Fatal("expected error from dst failure")
+		}
+		// We prefer dst's error since src's failure is just a downstream symptom.
+		if !strings.Contains(err.Error(), "dst failed") {
+			t.Fatalf("expected dst failure, got %v", err)
+		}
+		if elapsed := time.Since(start); elapsed > 5*time.Second {
+			t.Fatalf("runPipeline took %v -- src was not killed promptly after dst exit", elapsed)
+		}
+	case <-time.After(10 * time.Second):
+		t.Fatal("runPipeline hung -- src was not killed after dst exited")
+	}
+}
+
+// TestRunPipeline_LargeStream verifies that streaming more than the kernel
+// pipe buffer (typically 64KB on Linux) works without deadlock.
+func TestRunPipeline_LargeStream(t *testing.T) {
+	requireUnix(t)
+	ctx := t.Context()
+
+	const size = 2 * 1024 * 1024 // 2 MiB
+	// #nosec G204 -- test-only, constant arguments
+	src := exec.CommandContext(ctx, "sh", "-c", "yes a | head -c 2097152")
+	var out bytes.Buffer
+	dst := exec.CommandContext(ctx, "cat")
+	dst.Stdout = &out
+
+	if err := runPipeline(ctx, src, dst, "src", "dst", nil); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if out.Len() != size {
+		t.Fatalf("dst received %d bytes, want %d", out.Len(), size)
+	}
+}
+
+func TestRunPipeline_ContextCancel(t *testing.T) {
+	requireUnix(t)
+	ctx, cancel := context.WithCancel(t.Context())
+
+	// Use exec'd binaries directly (not `sh -c "..."`). When sh is run with
+	// a single argument, many shells fork a child for the command rather than
+	// exec-replacing themselves. That child inherits sh's stderr fd, and Go's
+	// exec.Wait blocks draining stderr until every fd holder closes it -- so
+	// CommandContext killing sh isn't enough; the child keeps stderr open and
+	// Wait hangs. Using a single-process command avoids the inheritance.
+	src := exec.CommandContext(ctx, "sleep", "30")
+	dst := exec.CommandContext(ctx, "cat")
+	dst.Stdout = &bytes.Buffer{}
+
+	done := make(chan error, 1)
+	go func() { done <- runPipeline(ctx, src, dst, "src", "dst", nil) }()
+
+	// give them a moment to start
+	time.Sleep(100 * time.Millisecond)
+	cancel()
+
+	select {
+	case err := <-done:
+		if err == nil {
+			t.Fatal("expected error after context cancel")
+		}
+		// CommandContext kills the process; just ensure we got back.
+		var exitErr *exec.ExitError
+		if !errors.As(err, &exitErr) && !strings.Contains(err.Error(), "killed") &&
+			!strings.Contains(err.Error(), "signal") {
+			// any non-nil error is acceptable here; we're mostly checking we don't hang
+			t.Logf("got error after cancel: %v", err)
+		}
+	case <-time.After(10 * time.Second):
+		t.Fatal("runPipeline did not return after context cancel")
+	}
+}
+
+// TestRunPipeline_FilterStripsLines verifies the filter goroutine drops
+// matching lines and forwards the rest. Covers SET transaction_timeout (PG17+)
+// and \restrict / \unrestrict psql meta-commands (pg_dump 17.6+).
+func TestRunPipeline_FilterStripsLines(t *testing.T) {
+	requireUnix(t)
+	ctx := t.Context()
+
+	input := "SELECT 1;\n" +
+		"SET transaction_timeout = 0;\n" +
+		"\\restrict abc123\n" +
+		"CREATE TABLE t(id int);\n" +
+		"\\unrestrict abc123\n" +
+		"SELECT 2;\n"
+	src := exec.CommandContext(ctx, "printf", "%s", input)
+	var out bytes.Buffer
+	dst := exec.CommandContext(ctx, "cat")
+	dst.Stdout = &out
+
+	if err := runPipeline(ctx, src, dst, "src", "dst", filterIncompatibleLines); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	got := out.String()
+	want := "SELECT 1;\nCREATE TABLE t(id int);\nSELECT 2;\n"
+	if got != want {
+		t.Fatalf("filtered output = %q, want %q", got, want)
+	}
+}
diff --git a/flow/connectors/postgres/postgres_destination.go b/flow/connectors/postgres/postgres_destination.go
index 2b87fc0aa..9314d798d 100644
--- a/flow/connectors/postgres/postgres_destination.go
+++ b/flow/connectors/postgres/postgres_destination.go
@@ -329,7 +329,6 @@ func (c *PostgresConnector) normalizeBatch(
 			if _, err := tx.Exec(ctx, setSessionReplicaRoleSQL); err != nil {
 				return 0, fmt.Errorf("failed to set session_replication_role to replica: %w", err)
 			}
-			c.logger.Info("set session_replication_role to replica for PG type system normalize")
 			break
 		}
 	}
diff --git a/flow/connectors/postgres/validate.go b/flow/connectors/postgres/validate.go
index 4d942747a..dcec932bb 100644
--- a/flow/connectors/postgres/validate.go
+++ b/flow/connectors/postgres/validate.go
@@ -278,6 +278,10 @@ func (c *PostgresConnector) ValidateMirrorDestination(
 		return nil // no need to validate schema for resync, as we will create or replace the tables
 	}
 
+	if cfg.System == protos.TypeSystem_PG && cfg.Env["PEERDB_PG_AUTOMATED_SCHEMA_DUMP"] == "true" {
+		return nil // pg_dump will create the schema and tables on the destination
+	}
+
 	// Validate that all source columns exist in destination tables
 	checkedSchemas := make(map[string]struct{})
 	for _, tableMapping := range cfg.TableMappings {
diff --git a/flow/e2e/pg_schema_dump_test.go b/flow/e2e/pg_schema_dump_test.go
new file mode 100644
index 000000000..bcc587d0f
--- /dev/null
+++ b/flow/e2e/pg_schema_dump_test.go
@@ -0,0 +1,475 @@
+package e2e
+
+import (
+	"context"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/jackc/pgx/v5"
+	"github.com/stretchr/testify/require"
+
+	"github.com/PeerDB-io/peerdb/flow/generated/protos"
+	"github.com/PeerDB-io/peerdb/flow/internal"
+)
+
+// setupDedicatedPgDumpSource creates a fresh database on the primary PG instance
+// to serve as the source for schema-dump tests. pg_dump dumps a whole database,
+// so sharing the source DB with other parallel tests caused the dump to include
+// every concurrent test's e2e_test_<suffix> schema, blowing past the workflow
+// SETUP timeout. A dedicated source DB keeps the dump scoped to this test.
+//
+// Returns a connection to the new DB, the registered source peer name, and the
+// schema name to use within it. All resources are cleaned up via t.Cleanup.
+func setupDedicatedPgDumpSource(t *testing.T, suffix string) (*pgx.Conn, string, string) {
+	t.Helper()
+
+	srcCfg := internal.GetAncillaryPostgresConfigFromEnv()
+	srcDBName := "e2e_pgdump_src_" + suffix
+	srcSchema := "e2e_test_" + suffix
+
+	bootstrapStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s",
+		srcCfg.Host, srcCfg.Port, srcCfg.User, srcCfg.Password, srcCfg.Database)
+	bootstrap, err := pgx.Connect(t.Context(), bootstrapStr)
+	require.NoError(t, err)
+	_, err = bootstrap.Exec(t.Context(), "CREATE DATABASE "+srcDBName)
+	require.NoError(t, err)
+	bootstrap.Close(t.Context())
+
+	srcConnStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s",
+		srcCfg.Host, srcCfg.Port, srcCfg.User, srcCfg.Password, srcDBName)
+	srcConn, err := pgx.Connect(t.Context(), srcConnStr)
+	require.NoError(t, err)
+	_, err = srcConn.Exec(t.Context(), "CREATE SCHEMA "+srcSchema)
+	require.NoError(t, err)
+
+	srcPeerCfg := internal.GetAncillaryPostgresConfigFromEnv()
+	srcPeerCfg.Database = srcDBName
+	srcPeerName := "pgdump_src_" + suffix
+	CreatePeer(t, &protos.Peer{
+		Name:   srcPeerName,
+		Type:   protos.DBType_POSTGRES,
+		Config: &protos.Peer_PostgresConfig{PostgresConfig: srcPeerCfg},
+	})
+
+	t.Cleanup(func() {
+		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+		defer cancel()
+		srcConn.Close(ctx)
+		dropConn, err := pgx.Connect(ctx, bootstrapStr)
+		if err != nil {
+			t.Logf("failed to connect for source DB cleanup: %v", err)
+			return
+		}
+		defer dropConn.Close(ctx)
+		if _, err := dropConn.Exec(ctx, "DROP DATABASE IF EXISTS "+srcDBName+" WITH (FORCE)"); err != nil {
+			t.Logf("failed to drop source database %s: %v", srcDBName, err)
+		}
+	})
+
+	return srcConn, srcPeerName, srcSchema
+}
+
+func (s PeerFlowE2ETestSuitePG) Test_PG_Schema_Dump_And_CDC() {
+	srcConn, srcPeerName, srcSchema := setupDedicatedPgDumpSource(s.t, s.suffix)
+
+	dstDBName := "e2e_pgdump_" + s.suffix
+
+	// create destination database on the same PG instance
+	_, err := s.Conn().Exec(s.t.Context(), "CREATE DATABASE "+dstDBName)
+	require.NoError(s.t, err)
+	s.t.Cleanup(func() {
+		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+		defer cancel()
+		cfg := internal.GetAncillaryPostgresConfigFromEnv()
+		connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=postgres",
+			cfg.Host, cfg.Port, cfg.User, cfg.Password)
+		dropConn, err := pgx.Connect(ctx, connStr)
+		if err != nil {
+			s.t.Logf("failed to connect for cleanup: %v", err)
+			return
+		}
+		defer dropConn.Close(ctx)
+		if _, err := dropConn.Exec(ctx, "DROP DATABASE IF EXISTS "+dstDBName+" WITH (FORCE)"); err != nil {
+			s.t.Logf("failed to drop destination database %s: %v", dstDBName, err)
+		}
+	})
+
+	// connect to destination database for verification later
+	dstCfg := internal.GetAncillaryPostgresConfigFromEnv()
+	dstConnStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s",
+		dstCfg.Host, dstCfg.Port, dstCfg.User, dstCfg.Password, dstDBName)
+	dstConn, err := pgx.Connect(s.t.Context(), dstConnStr)
+	require.NoError(s.t, err)
+	s.t.Cleanup(func() { dstConn.Close(s.t.Context()) })
+
+	// create a destination peer pointing to the new database
+	dstPeerCfg := internal.GetAncillaryPostgresConfigFromEnv()
+	dstPeerCfg.Database = dstDBName
+	dstPeerName := "pgdump_dst_" + s.suffix
+	dstPeer := &protos.Peer{
+		Name: dstPeerName,
+		Type: protos.DBType_POSTGRES,
+		Config: &protos.Peer_PostgresConfig{
+			PostgresConfig: dstPeerCfg,
+		},
+	}
+	CreatePeer(s.t, dstPeer)
+
+	// --- set up rich schema on source ---
+
+	// create custom enum type in the source schema
+	_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf(
+		"CREATE TYPE %s.color AS ENUM ('red', 'green', 'blue')", srcSchema))
+	require.NoError(s.t, err)
+
+	// create parent table with various column types
+	parentTable := srcSchema + ".parent_tbl"
+	_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf(`
+		CREATE TABLE %s (
+			id SERIAL PRIMARY KEY,
+			name TEXT NOT NULL,
+			color %s.color NOT NULL DEFAULT 'red',
+			score NUMERIC(10,2),
+			metadata JSONB,
+			created_at TIMESTAMPTZ NOT NULL DEFAULT now()
+		)`, parentTable, srcSchema))
+	require.NoError(s.t, err)
+
+	// create unique index on parent
+	_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf(
+		"CREATE UNIQUE INDEX idx_parent_name ON %s (name)", parentTable))
+	require.NoError(s.t, err)
+
+	// create child table with foreign key referencing parent
+	childTable := srcSchema + ".child_tbl"
+	_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf(`
+		CREATE TABLE %s (
+			id SERIAL PRIMARY KEY,
+			parent_id INT NOT NULL REFERENCES %s(id),
+			value TEXT,
+			tags TEXT[]
+		)`, childTable, parentTable))
+	require.NoError(s.t, err)
+
+	// create btree index on child
+	_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf(
+		"CREATE INDEX idx_child_parent ON %s (parent_id)", childTable))
+	require.NoError(s.t, err)
+
+	// insert initial data for snapshot
+	for i := 1; i <= 5; i++ {
+		_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf(
+			"INSERT INTO %s (name, color, score, metadata) VALUES ($1, $2, $3, $4)",
+			parentTable),
+			fmt.Sprintf("item_%d", i),
+			[]string{"red", "green", "blue"}[i%3],
+			float64(i)*10.5,
+			fmt.Sprintf(`{"key": "val_%d"}`, i),
+		)
+		require.NoError(s.t, err)
+	}
+
+	for i := 1; i <= 10; i++ {
+		_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf(
+			"INSERT INTO %s (parent_id, value, tags) VALUES ($1, $2, $3)",
+			childTable),
+			(i%5)+1,
+			fmt.Sprintf("child_val_%d", i),
+			fmt.Sprintf("{tag_%d,common}", i),
+		)
+		require.NoError(s.t, err)
+	}
+
+	// source and dest table identifiers -- use the same schema-qualified names
+	// since pg_dump recreates the schema on the destination
+	srcParent := parentTable
+	dstParent := parentTable
+	srcChild := childTable
+	dstChild := childTable
+
+	config := &protos.FlowConnectionConfigs{
+		FlowJobName:     s.attachSuffix("test_pgdump"),
+		DestinationName: dstPeerName,
+		TableMappings: []*protos.TableMapping{
+			{
+				SourceTableIdentifier:      srcParent,
+				DestinationTableIdentifier: dstParent,
+			},
+			{
+				SourceTableIdentifier:      srcChild,
+				DestinationTableIdentifier: dstChild,
+			},
+		},
+		SourceName:        srcPeerName,
+		MaxBatchSize:      100,
+		DoInitialSnapshot: true,
+		System:            protos.TypeSystem_PG,
+		SoftDeleteColName: "",
+		SyncedAtColName:   "",
+		Env: map[string]string{
+			"PEERDB_PG_AUTOMATED_SCHEMA_DUMP": "true",
+		},
+	}
+
+	tc := NewTemporalClient(s.t)
+	env := ExecutePeerflow(s.t, tc, config)
+	SetupCDCFlowStatusQuery(s.t, env, config)
+
+	// wait for initial snapshot to complete
+	EnvWaitFor(s.t, env, 3*time.Minute, "initial load parent", func() bool {
+		var count int64
+		err := dstConn.QueryRow(s.t.Context(),
+			"SELECT COUNT(*) FROM "+dstParent).Scan(&count)
+		return err == nil && count == 5
+	})
+	EnvWaitFor(s.t, env, 3*time.Minute, "initial load child", func() bool {
+		var count int64
+		err := dstConn.QueryRow(s.t.Context(),
+			"SELECT COUNT(*) FROM "+dstChild).Scan(&count)
+		return err == nil && count == 10
+	})
+
+	// --- verify schema objects on destination ---
+
+	// verify enum type exists on destination
+	var enumExists bool
+	err = dstConn.QueryRow(s.t.Context(), `
+		SELECT EXISTS (
+			SELECT 1 FROM pg_type t
+			JOIN pg_namespace n ON n.oid = t.typnamespace
+			WHERE t.typname = 'color' AND n.nspname = $1
+		)`, srcSchema).Scan(&enumExists)
+	require.NoError(s.t, err)
+	require.True(s.t, enumExists, "enum type 'color' should exist on destination")
+
+	// verify unique index on parent table
+	var idxParentExists bool
+	err = dstConn.QueryRow(s.t.Context(), `
+		SELECT EXISTS (
+			SELECT 1 FROM pg_indexes
+			WHERE schemaname = $1 AND tablename = 'parent_tbl' AND indexname = 'idx_parent_name'
+		)`, srcSchema).Scan(&idxParentExists)
+	require.NoError(s.t, err)
+	require.True(s.t, idxParentExists, "unique index idx_parent_name should exist on destination")
+
+	// verify btree index on child table
+	var idxChildExists bool
+	err = dstConn.QueryRow(s.t.Context(), `
+		SELECT EXISTS (
+			SELECT 1 FROM pg_indexes
+			WHERE schemaname = $1 AND tablename = 'child_tbl' AND indexname = 'idx_child_parent'
+		)`, srcSchema).Scan(&idxChildExists)
+	require.NoError(s.t, err)
+	require.True(s.t, idxChildExists, "btree index idx_child_parent should exist on destination")
+
+	// verify foreign key constraint on child table
+	var fkExists bool
+	err = dstConn.QueryRow(s.t.Context(), `
+		SELECT EXISTS (
+			SELECT 1 FROM information_schema.table_constraints
+			WHERE constraint_type = 'FOREIGN KEY'
+				AND table_schema = $1
+				AND table_name = 'child_tbl'
+		)`, srcSchema).Scan(&fkExists)
+	require.NoError(s.t, err)
+	require.True(s.t, fkExists, "foreign key on child_tbl should exist on destination")
+
+	// --- CDC test: insert more rows and verify replication ---
+
+	// insert more parent rows
+	for i := 6; i <= 8; i++ {
+		_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf(
+			"INSERT INTO %s (name, color, score, metadata) VALUES ($1, $2, $3, $4)",
+			parentTable),
+			fmt.Sprintf("item_%d", i),
+			[]string{"red", "green", "blue"}[i%3],
+			float64(i)*10.5,
+			fmt.Sprintf(`{"key": "val_%d"}`, i),
+		)
+		EnvNoError(s.t, env, err)
+	}
+
+	// insert more child rows
+	for i := 11; i <= 15; i++ {
+		_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf(
+			"INSERT INTO %s (parent_id, value, tags) VALUES ($1, $2, $3)",
+			childTable),
+			(i%5)+1,
+			fmt.Sprintf("child_val_%d", i),
+			fmt.Sprintf("{tag_%d,common}", i),
+		)
+		EnvNoError(s.t, env, err)
+	}
+
+	// wait for CDC to replicate the new rows
+	EnvWaitFor(s.t, env, 3*time.Minute, "cdc parent rows", func() bool {
+		var count int64
+		err := dstConn.QueryRow(s.t.Context(),
+			"SELECT COUNT(*) FROM "+dstParent).Scan(&count)
+		return err == nil && count == 8
+	})
+	EnvWaitFor(s.t, env, 3*time.Minute, "cdc child rows", func() bool {
+		var count int64
+		err := dstConn.QueryRow(s.t.Context(),
+			"SELECT COUNT(*) FROM "+dstChild).Scan(&count)
+		return err == nil && count == 15
+	})
+
+	// verify data integrity: compare row counts between source and destination
+	var srcParentCount, dstParentCount int64
+	err = srcConn.QueryRow(s.t.Context(), "SELECT COUNT(*) FROM "+srcParent).Scan(&srcParentCount)
+	require.NoError(s.t, err)
+	err = dstConn.QueryRow(s.t.Context(), "SELECT COUNT(*) FROM "+dstParent).Scan(&dstParentCount)
+	require.NoError(s.t, err)
+	require.Equal(s.t, srcParentCount, dstParentCount, "parent table row counts should match")
+
+	var srcChildCount, dstChildCount int64
+	err = srcConn.QueryRow(s.t.Context(), "SELECT COUNT(*) FROM "+srcChild).Scan(&srcChildCount)
+	require.NoError(s.t, err)
+	err = dstConn.QueryRow(s.t.Context(), "SELECT COUNT(*) FROM "+dstChild).Scan(&dstChildCount)
+	require.NoError(s.t, err)
+	require.Equal(s.t, srcChildCount, dstChildCount, "child table row counts should match")
+
+	env.Cancel(s.t.Context())
+	RequireEnvCanceled(s.t, env)
+}
+
+// Test_PG_Schema_Dump_No_Owner_No_Privileges verifies that the schema dump does
+// not emit owner or grant statements that reference roles. We create a role on
+// the source, give it ownership and grants on a table, then dump into a
+// secondary cluster where that role does NOT exist. With --no-owner and
+// --no-privileges the dump must succeed; without them it would fail on
+// ALTER TABLE ... OWNER TO <role> / GRANT ... TO <role>.
+//
+// Also verifies initial load + CDC into the dumped table on the destination
+// cluster, so we know the table is usable end-to-end.
+func (s PeerFlowE2ETestSuitePG) Test_PG_Schema_Dump_No_Owner_No_Privileges() {
+	srcConn, srcPeerName, srcSchema := setupDedicatedPgDumpSource(s.t, s.suffix)
+
+	dstCfg := internal.GetSecondaryPostgresConfigFromEnv()
+
+	roleName := "peerdb_owner_role_" + s.suffix
+	tableName := "owned_tbl"
+	qualified := fmt.Sprintf("%s.%s", srcSchema, tableName)
+
+	// destination connection (for assertions + role-absence sanity)
+	dstConnStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s",
+		dstCfg.Host, dstCfg.Port, dstCfg.User, dstCfg.Password, dstCfg.Database)
+	dstConn, err := pgx.Connect(s.t.Context(), dstConnStr)
+	require.NoError(s.t, err, "failed to connect to secondary postgres on %s:%d (postgres2 tilt resource running?)",
+		dstCfg.Host, dstCfg.Port)
+	s.t.Cleanup(func() { dstConn.Close(s.t.Context()) })
+
+	// sanity: role must not exist on destination
+	var roleExistsOnDst bool
+	require.NoError(s.t, dstConn.QueryRow(s.t.Context(),
+		"SELECT EXISTS(SELECT 1 FROM pg_roles WHERE rolname=$1)", roleName).Scan(&roleExistsOnDst))
+	require.False(s.t, roleExistsOnDst, "role %s unexpectedly exists on destination", roleName)
+
+	// create role + owned/granted table on source with seed rows.
+	// the role is cluster-wide; the table lives in the dedicated source DB
+	// and will go away when that DB is dropped during cleanup.
+	_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf("CREATE ROLE %s LOGIN PASSWORD 'pw'", roleName))
+	require.NoError(s.t, err)
+	_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf(
+		"CREATE TABLE %s (id SERIAL PRIMARY KEY, val TEXT)", qualified))
+	require.NoError(s.t, err)
+	_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf("ALTER TABLE %s OWNER TO %s", qualified, roleName))
+	require.NoError(s.t, err)
+	_, err = srcConn.Exec(s.t.Context(), fmt.Sprintf("GRANT SELECT, INSERT ON %s TO %s", qualified, roleName))
+	require.NoError(s.t, err)
+
+	for i := 1; i <= 5; i++ {
+		_, err = srcConn.Exec(s.t.Context(),
+			fmt.Sprintf("INSERT INTO %s (val) VALUES ($1)", qualified),
+			fmt.Sprintf("snap_%d", i))
+		require.NoError(s.t, err)
+	}
+
+	// the role is cluster-wide, so it outlives the dedicated source DB; drop it
+	// via the shared source connection after the source DB is gone.
+	s.t.Cleanup(func() {
+		cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 30*time.Second)
+		defer cleanupCancel()
+		_, _ = s.Conn().Exec(cleanupCtx, "DROP ROLE IF EXISTS "+roleName)
+
+		dropConn, err := pgx.Connect(cleanupCtx, dstConnStr)
+		if err != nil {
+			s.t.Logf("failed to connect to destination for cleanup: %v", err)
+			return
+		}
+		defer dropConn.Close(cleanupCtx)
+		_, _ = dropConn.Exec(cleanupCtx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", srcSchema))
+	})
+
+	// register destination peer pointing at postgres2
+	dstPeerName := "pgdump_noowner_dst_" + s.suffix
+	CreatePeer(s.t, &protos.Peer{
+		Name:   dstPeerName,
+		Type:   protos.DBType_POSTGRES,
+		Config: &protos.Peer_PostgresConfig{PostgresConfig: dstCfg},
+	})
+
+	config := &protos.FlowConnectionConfigs{
+		FlowJobName:     s.attachSuffix("test_pgdump_noowner"),
+		DestinationName: dstPeerName,
+		TableMappings: []*protos.TableMapping{{
+			SourceTableIdentifier:      qualified,
+			DestinationTableIdentifier: qualified,
+		}},
+		SourceName:        srcPeerName,
+		MaxBatchSize:      100,
+		DoInitialSnapshot: true,
+		System:            protos.TypeSystem_PG,
+		Env: map[string]string{
+			"PEERDB_PG_AUTOMATED_SCHEMA_DUMP": "true",
+		},
+	}
+
+	tc := NewTemporalClient(s.t)
+	env := ExecutePeerflow(s.t, tc, config)
+	SetupCDCFlowStatusQuery(s.t, env, config)
+
+	// initial load: pg_dump must succeed despite the missing owner/grantee role,
+	// then the snapshot copies the 5 seed rows.
+	EnvWaitFor(s.t, env, 3*time.Minute, "initial load owned_tbl", func() bool {
+		var count int64
+		err := dstConn.QueryRow(s.t.Context(),
+			"SELECT COUNT(*) FROM "+qualified).Scan(&count)
+		return err == nil && count == 5
+	})
+
+	// CDC: insert more rows on source and wait for them on dst
+	for i := 6; i <= 10; i++ {
+		_, err = srcConn.Exec(s.t.Context(),
+			fmt.Sprintf("INSERT INTO %s (val) VALUES ($1)", qualified),
+			fmt.Sprintf("cdc_%d", i))
+		EnvNoError(s.t, env, err)
+	}
+
+	EnvWaitFor(s.t, env, 3*time.Minute, "cdc owned_tbl", func() bool {
+		var count int64
+		err := dstConn.QueryRow(s.t.Context(),
+			"SELECT COUNT(*) FROM "+qualified).Scan(&count)
+		return err == nil && count == 10
+	})
+
+	// owner on dst should be the connecting user, not the (missing) source role
+	var dstOwner string
+	require.NoError(s.t, dstConn.QueryRow(s.t.Context(),
+		"SELECT tableowner FROM pg_tables WHERE schemaname=$1 AND tablename=$2",
+		srcSchema, tableName).Scan(&dstOwner))
+	require.NotEqual(s.t, roleName, dstOwner, "destination table should not be owned by the source-only role")
+
+	// no grants should reference the missing role on dst
+	var grantCount int
+	require.NoError(s.t, dstConn.QueryRow(s.t.Context(),
+		"SELECT COUNT(*) FROM information_schema.table_privileges WHERE table_schema=$1 AND table_name=$2 AND grantee=$3",
+		srcSchema, tableName, roleName).Scan(&grantCount))
+	require.Zero(s.t, grantCount, "no privileges should be granted to the source-only role on destination")
+
+	env.Cancel(s.t.Context())
+	RequireEnvCanceled(s.t, env)
+}
diff --git a/flow/internal/dynamicconf.go b/flow/internal/dynamicconf.go
index 351f83457..0702f4ddb 100644
--- a/flow/internal/dynamicconf.go
+++ b/flow/internal/dynamicconf.go
@@ -443,6 +443,15 @@ var DynamicSettings = [...]*protos.DynamicSetting{
 		ApplyMode:        protos.DynconfApplyMode_APPLY_MODE_AFTER_RESUME,
 		TargetForSetting: protos.DynconfTarget_ALL,
 	},
+	{
+		Name: "PEERDB_PG_AUTOMATED_SCHEMA_DUMP",
+		Description: "For PG-to-PG mirrors, run pg_dump --schema-only from source into psql on destination " +
+			"during setup so destination schema/tables/indexes match the source.",
match the source.", + DefaultValue: "false", + ValueType: protos.DynconfValueType_BOOL, + ApplyMode: protos.DynconfApplyMode_APPLY_MODE_AFTER_RESUME, + TargetForSetting: protos.DynconfTarget_POSTGRES, + }, } var DynamicIndex = func() map[string]int { @@ -799,3 +808,7 @@ func PeerDBMetricsRecordAggregatesEnabled(ctx context.Context, env map[string]st func PeerDBPostgresApplyCtidBlockPartitioning(ctx context.Context, env map[string]string) (bool, error) { return dynamicConfBool(ctx, env, "PEERDB_POSTGRES_APPLY_CTID_BLOCK_PARTITIONING_OVERRIDE") } + +func PeerDBPGAutomatedSchemaDump(ctx context.Context, env map[string]string) (bool, error) { + return dynamicConfBool(ctx, env, "PEERDB_PG_AUTOMATED_SCHEMA_DUMP") +} diff --git a/flow/internal/test_env.go b/flow/internal/test_env.go index 4055295f6..e4cac6e55 100644 --- a/flow/internal/test_env.go +++ b/flow/internal/test_env.go @@ -30,6 +30,16 @@ func GetAncillaryPostgresConfigFromEnv() *protos.PostgresConfig { } } +func GetSecondaryPostgresConfigFromEnv() *protos.PostgresConfig { + return &protos.PostgresConfig{ + Host: GetEnvString("PG2_HOST", "localhost"), + Port: uint32(getEnvUint[uint16]("PG2_PORT", 5437)), + User: GetEnvString("PG2_USER", "postgres"), + Password: GetEnvString("PG2_PASSWORD", "postgres"), + Database: GetEnvString("PG2_DATABASE", "postgres"), + } +} + func PostgresToxiproxyUpstreamHostWithFallback(fallback string) string { return GetEnvString("TOXIPROXY_POSTGRES_HOST", fallback) } diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go index 0f4aa64ec..31d164a14 100644 --- a/flow/workflows/setup_flow.go +++ b/flow/workflows/setup_flow.go @@ -5,6 +5,7 @@ import ( "log/slog" "maps" "slices" + "strings" "time" "go.temporal.io/sdk/log" @@ -239,6 +240,66 @@ func (s *SetupFlowExecution) createNormalizedTables( return nil } +// runPgDumpSchema runs pg_dump --schema-only on the source and pipes the output +// into psql on the destination, streaming the schema directly. +// This is only used for PG type system (PG-to-PG mirrors). +// Returns true only if the dump activity actually ran (it skips for SSH tunnel +// or non-password auth peers); callers must use this to decide whether the +// destination tables were created. +func (s *SetupFlowExecution) runPgDumpSchema( + ctx workflow.Context, + config *protos.FlowConnectionConfigsCore, +) (bool, error) { + s.Info("running pg_dump schema migration from source to destination") + + ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + StartToCloseTimeout: 1 * time.Hour, + RetryPolicy: &temporal.RetryPolicy{ + InitialInterval: 1 * time.Minute, + }, + }) + + input := &protos.RunPgDumpSchemaInput{ + SourceName: config.SourceName, + DestinationName: config.DestinationName, + FlowName: config.FlowJobName, + Env: config.Env, + } + + var ran bool + if err := workflow.ExecuteActivity(ctx, flowable.RunPgDumpSchema, input).Get(ctx, &ran); err != nil { + return false, fmt.Errorf("failed to run pg_dump schema migration: %w", err) + } + + return ran, nil +} + +// isTableAdditionChild reports whether this SetupFlow was launched as part of a +// table-addition child CDC flow. Such workflows are spawned with a parent +// workflow ID prefixed by "additional-cdc-flow-". 
+func isTableAdditionChild(ctx workflow.Context) bool {
+	parent := workflow.GetInfo(ctx).ParentWorkflowExecution
+	if parent == nil {
+		return false
+	}
+	return strings.HasPrefix(parent.ID, "additional-cdc-flow-")
+}
+
+// getPGAutomatedSchemaDump checks the PEERDB_PG_AUTOMATED_SCHEMA_DUMP env flag via an activity.
+func (s *SetupFlowExecution) getPGAutomatedSchemaDump(ctx workflow.Context, env map[string]string) bool {
+	checkCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
+		StartToCloseTimeout: time.Minute,
+	})
+
+	var enabled bool
+	future := workflow.ExecuteActivity(checkCtx, flowable.PeerDBPGAutomatedSchemaDump, env)
+	if err := future.Get(checkCtx, &enabled); err != nil {
+		s.Warn("failed to check PEERDB_PG_AUTOMATED_SCHEMA_DUMP, defaulting to false", slog.Any("error", err))
+		return false
+	}
+	return enabled
+}
+
 // executeSetupFlow executes the setup flow.
 func (s *SetupFlowExecution) executeSetupFlow(
 	ctx workflow.Context,
@@ -268,7 +329,23 @@ func (s *SetupFlowExecution) executeSetupFlow(
 		return nil, fmt.Errorf("failed to fetch table schema: %w", err)
 	}
 
-	if err := s.createNormalizedTables(ctx, config); err != nil {
+	// pg_dump silently no-ops for SSH tunnel / non-password-auth peers, so we
+	// only skip CreateNormalizedTable when the activity reports it actually ran.
+	// Skip pg_dump for resync (tables get a _resync suffix and are swapped) and for
+	// table-addition child workflows (parent workflow ID prefix "additional-cdc-flow-").
+	skipCreateTables := false
+	if config.System == protos.TypeSystem_PG && !config.Resync && !isTableAdditionChild(ctx) &&
+		s.getPGAutomatedSchemaDump(ctx, config.Env) {
+		ran, err := s.runPgDumpSchema(ctx, config)
+		if err != nil {
+			return nil, fmt.Errorf("failed to run pg_dump schema migration: %w", err)
+		}
+		skipCreateTables = ran
+	}
+
+	if skipCreateTables {
+		s.Info("skipping normalized table creation, pg_dump already created the tables")
+	} else if err := s.createNormalizedTables(ctx, config); err != nil {
 		return nil, fmt.Errorf("failed to create normalized tables: %w", err)
 	}
 
diff --git a/protos/flow.proto b/protos/flow.proto
index ff2124510..006a588a4 100644
--- a/protos/flow.proto
+++ b/protos/flow.proto
@@ -583,6 +583,7 @@ enum DynconfTarget {
   SNOWFLAKE = 2;
   CLICKHOUSE = 3;
   QUEUES = 4;
+  POSTGRES = 5;
 }
 
 message DropFlowActivityInput {
@@ -671,3 +672,10 @@ message GetFlowInfoToCancelFromCatalogOutput {
   string workflow_id = 2;
   peerdb_peers.DBType source_peer_type = 3;
 }
+
+message RunPgDumpSchemaInput {
+  string source_name = 1;
+  string destination_name = 2;
+  string flow_name = 3;
+  map<string, string> env = 4;
+}