diff --git a/cli/check/check.go b/cli/check/check.go index e2ee906..08c51c6 100644 --- a/cli/check/check.go +++ b/cli/check/check.go @@ -105,7 +105,7 @@ func (cmd *checkCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interfac Arch: p.PackageSpec.Arch, Ver: p.PackageSpec.Version, } - if err := install.FromRepo(ctx, pi, u, cache, rm, settings.Archs, cmd.dbOnly, downloader, db); err != nil { + if err := install.FromRepo(ctx, pi, u, cache, rm, settings.Archs, cmd.dbOnly, false, downloader, db); err != nil { logger.Errorf("Error installing %s.%s.%s: %v", pi.Name, pi.Arch, pi.Ver, err) exitCode = subcommands.ExitFailure continue diff --git a/cli/install/install.go b/cli/install/install.go index 3abd3ed..b6c53b4 100644 --- a/cli/install/install.go +++ b/cli/install/install.go @@ -43,12 +43,13 @@ type installCmd struct { dbOnly bool sources string dryRun bool + force bool } func (*installCmd) Name() string { return "install" } func (*installCmd) Synopsis() string { return "download and install a package and its dependencies" } func (*installCmd) Usage() string { - return fmt.Sprintf("%s install [-reinstall] [-sources repo1,repo2...] [-dry_run] ...\n", filepath.Base(os.Args[0])) + return fmt.Sprintf("%s install [-reinstall] [-sources repo1,repo2...] [-dry_run] [-force] ...\n", filepath.Base(os.Args[0])) } func (cmd *installCmd) SetFlags(f *flag.FlagSet) { @@ -57,6 +58,7 @@ func (cmd *installCmd) SetFlags(f *flag.FlagSet) { f.BoolVar(&cmd.dbOnly, "db_only", false, "only make changes to DB, don't perform install system actions") f.StringVar(&cmd.sources, "sources", "", "comma separated list of sources, setting this overrides local .repo files") f.BoolVar(&cmd.dryRun, "dry_run", false, "show what would be installed but do not install") + f.BoolVar(&cmd.force, "force", false, "force overwrite of conflicting files") } func (cmd *installCmd) Execute(ctx context.Context, flags *flag.FlagSet, _ ...any) subcommands.ExitStatus { @@ -91,6 +93,7 @@ func (cmd *installCmd) Execute(ctx context.Context, flags *flag.FlagSet, _ ...an confirm: settings.Confirm, downloader: downloader, dryRun: cmd.dryRun, + force: cmd.force, } // We only need to build sources and download indexes if there are any @@ -151,6 +154,7 @@ type installer struct { redownload bool // ignore cached downloads when reinstalling confirm bool // prompt before changes dryRun bool // show what would be done + force bool // force overwrite resulting from conflicts } // installFromFile installs a package from the specified file path. @@ -164,7 +168,7 @@ func (i *installer) installFromFile(path string) error { fmt.Printf("Not installing %s...\n", base) return nil } - if err := install.FromDisk(path, i.cache, i.dbOnly, i.shouldReinstall, i.db); err != nil { + if err := install.FromDisk(path, i.cache, i.dbOnly, i.force, i.shouldReinstall, i.db); err != nil { return fmt.Errorf("installing %s: %v", path, err) } return nil @@ -229,7 +233,7 @@ func (i *installer) installFromRepo(ctx context.Context, name string, archs []st fmt.Println("canceling install...") return nil } - if err := install.FromRepo(ctx, pi, r, i.cache, i.repoMap, archs, i.dbOnly, i.downloader, i.db); err != nil { + if err := install.FromRepo(ctx, pi, r, i.cache, i.repoMap, archs, i.dbOnly, i.force, i.downloader, i.db); err != nil { return fmt.Errorf("installing %s.%s.%s: %v", pi.Name, pi.Arch, pi.Ver, err) } @@ -251,7 +255,7 @@ func (i *installer) reinstall(ctx context.Context, pi goolib.PackageInfo, ps cli return nil } } - if err := install.Reinstall(ctx, ps, i.redownload, i.downloader); err != nil { + if err := install.Reinstall(ctx, ps, i.redownload, i.force, i.downloader, i.db); err != nil { return fmt.Errorf("error reinstalling %s, %v", pi.Name, err) } return nil diff --git a/cli/update/update.go b/cli/update/update.go index 23b9138..1883e4f 100644 --- a/cli/update/update.go +++ b/cli/update/update.go @@ -40,18 +40,20 @@ type updateCmd struct { dbOnly bool sources string dryRun bool + force bool } func (*updateCmd) Name() string { return "update" } func (*updateCmd) Synopsis() string { return "update all packages to the latest version available" } func (*updateCmd) Usage() string { - return fmt.Sprintf("%s update [-sources repo1,repo2...] [-dry_run]\n", filepath.Base(os.Args[0])) + return fmt.Sprintf("%s update [-sources repo1,repo2...] [-dry_run] [-force]\n", filepath.Base(os.Args[0])) } func (cmd *updateCmd) SetFlags(f *flag.FlagSet) { f.BoolVar(&cmd.dbOnly, "db_only", false, "only make changes to DB, don't perform install system actions") f.StringVar(&cmd.sources, "sources", "", "comma separated list of sources, setting this overrides local .repo files") f.BoolVar(&cmd.dryRun, "dry_run", false, "check for updates and print them, but do not prompt to install") + f.BoolVar(&cmd.force, "force", false, "force overwrite of conflicting files") } func (cmd *updateCmd) Execute(ctx context.Context, _ *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { @@ -114,7 +116,7 @@ func (cmd *updateCmd) Execute(ctx context.Context, _ *flag.FlagSet, _ ...interfa if err != nil { logger.Errorf("Error finding repo: %v.", err) } - if err := install.FromRepo(ctx, pi, r, cache, rm, settings.Archs, cmd.dbOnly, downloader, db); err != nil { + if err := install.FromRepo(ctx, pi, r, cache, rm, settings.Archs, cmd.dbOnly, cmd.force, downloader, db); err != nil { logger.Errorf("Error updating %s %s %s: %v", pi.Arch, pi.Name, pi.Ver, err) exitCode = subcommands.ExitFailure continue diff --git a/cli/verify/verify.go b/cli/verify/verify.go index 5364b05..97b813f 100644 --- a/cli/verify/verify.go +++ b/cli/verify/verify.go @@ -97,7 +97,7 @@ func (cmd *verifyCmd) Execute(ctx context.Context, flags *flag.FlagSet, _ ...int msg := fmt.Sprintf("Verification failed for %s, reinstalling...", pkg) logger.Info(msg) fmt.Println(msg) - if err := install.Reinstall(ctx, ps, false, downloader); err != nil { + if err := install.Reinstall(ctx, ps, false, false, downloader, db); err != nil { logger.Errorf("Error reinstalling %s, %v", pi.Name, err) } } else if !v { diff --git a/install/install.go b/install/install.go index 04e9133..98c02ef 100644 --- a/install/install.go +++ b/install/install.go @@ -129,7 +129,7 @@ func resolveReplacements(ctx context.Context, ps *goolib.PkgSpec, dbOnly bool, d return nil } -func installDeps(ctx context.Context, ps *goolib.PkgSpec, cache string, rm client.RepoMap, archs []string, dbOnly bool, downloader *client.Downloader, db *googetdb.GooDB) error { +func installDeps(ctx context.Context, ps *goolib.PkgSpec, cache string, rm client.RepoMap, archs []string, dbOnly, force bool, downloader *client.Downloader, db *googetdb.GooDB) error { logger.Infof("Resolving conflicts and dependencies for %s %s version %s", ps.Arch, ps.Name, ps.Version) if err := resolveConflicts(ps, db); err != nil { return err @@ -151,7 +151,7 @@ func installDeps(ctx context.Context, ps *goolib.PkgSpec, cache string, rm clien } logger.Infof("Dependency found: %s.%s %s (provides %s) is available", spec.Name, spec.Arch, spec.Version, pi.Name) - if err := FromRepo(ctx, goolib.PackageInfo{Name: spec.Name, Arch: spec.Arch, Ver: spec.Version}, repo, cache, rm, archs, dbOnly, downloader, db); err != nil { + if err := FromRepo(ctx, goolib.PackageInfo{Name: spec.Name, Arch: spec.Arch, Ver: spec.Version}, repo, cache, rm, archs, dbOnly, force, downloader, db); err != nil { return err } } @@ -159,7 +159,7 @@ func installDeps(ctx context.Context, ps *goolib.PkgSpec, cache string, rm clien } // FromRepo installs a package and all dependencies from a repository. -func FromRepo(ctx context.Context, pi goolib.PackageInfo, repo, cache string, rm client.RepoMap, archs []string, dbOnly bool, downloader *client.Downloader, db *googetdb.GooDB) error { +func FromRepo(ctx context.Context, pi goolib.PackageInfo, repo, cache string, rm client.RepoMap, archs []string, dbOnly, force bool, downloader *client.Downloader, db *googetdb.GooDB) error { logger.Infof("Starting install of %s.%s.%s", pi.Name, pi.Arch, pi.Ver) fmt.Printf("Installing %s.%s.%s and dependencies...\n", pi.Name, pi.Arch, pi.Ver) // When a specific version is requested, look for an exact match in the repository. @@ -182,7 +182,7 @@ func FromRepo(ctx context.Context, pi goolib.PackageInfo, repo, cache string, rm if err != nil { return err } - if err := installDeps(ctx, rs.PackageSpec, cache, rm, archs, dbOnly, downloader, db); err != nil { + if err := installDeps(ctx, rs.PackageSpec, cache, rm, archs, dbOnly, force, downloader, db); err != nil { return err } @@ -191,7 +191,7 @@ func FromRepo(ctx context.Context, pi goolib.PackageInfo, repo, cache string, rm return err } - insFiles, err := installPkg(dst, rs.PackageSpec, dbOnly) + insFiles, err := installPkg(dst, rs.PackageSpec, dbOnly, force, db) if err != nil { return err } @@ -216,7 +216,7 @@ func FromRepo(ctx context.Context, pi goolib.PackageInfo, repo, cache string, rm } // FromDisk installs a local .goo file. -func FromDisk(pkgPath, cache string, dbOnly, shouldReinstall bool, db *googetdb.GooDB) error { +func FromDisk(pkgPath, cache string, dbOnly, force, shouldReinstall bool, db *googetdb.GooDB) error { if _, err := oswrap.Stat(pkgPath); err != nil { return err } @@ -267,7 +267,7 @@ func FromDisk(pkgPath, cache string, dbOnly, shouldReinstall bool, db *googetdb. return err } - insFiles, err := installPkg(dst, zs, dbOnly) + insFiles, err := installPkg(dst, zs, dbOnly, force, db) if err != nil { return err } @@ -287,7 +287,7 @@ func FromDisk(pkgPath, cache string, dbOnly, shouldReinstall bool, db *googetdb. } // Reinstall reinstalls and optionally redownloads, a package. -func Reinstall(ctx context.Context, ps client.PackageState, rd bool, downloader *client.Downloader) error { +func Reinstall(ctx context.Context, ps client.PackageState, rd, force bool, downloader *client.Downloader, db *googetdb.GooDB) error { spec := ps.PackageSpec logger.Infof("Starting reinstall of %s.%s, version %s", spec.Name, spec.Arch, spec.Version) fmt.Printf("Reinstalling %s.%s %s and dependencies...\n", spec.Name, spec.Arch, spec.Version) @@ -326,7 +326,7 @@ func Reinstall(ctx context.Context, ps client.PackageState, rd bool, downloader } } - if _, err := installPkg(ps.LocalPath, ps.PackageSpec, false); err != nil { + if _, err := installPkg(ps.LocalPath, ps.PackageSpec, false, force, db); err != nil { return fmt.Errorf("error reinstalling package: %v", err) } @@ -394,12 +394,21 @@ func extractSpec(pkgPath string) (*goolib.PkgSpec, error) { return goolib.ExtractPkgSpec(f) } -func makeInstallFunction(src, dst string, insFiles map[string]string, dbOnly bool) func(string, os.FileInfo, error) error { +func makeInstallFunction(src, dst string, insFiles map[string]string, dbOnly, force bool, conflictMap map[string]string) func(string, os.FileInfo, error) error { return func(path string, fi os.FileInfo, err error) (outerr error) { if err != nil { return err } outPath := filepath.Join(dst, strings.TrimPrefix(path, src)) + + if owner, ok := conflictMap[outPath]; ok && !fi.IsDir() { + if !force { + return fmt.Errorf("file conflict: %s is already owned by package %s", outPath, owner) + } + logger.Infof("Warning: file conflict: %s is already owned by package %s, overwriting due to force flag", outPath, owner) + fmt.Printf("Warning: file conflict: %s is already owned by package %s, overwriting...\n", outPath, owner) + } + if dbOnly { if !fi.IsDir() { f, err := oswrap.Open(path) @@ -508,7 +517,26 @@ func cleanOldFiles(oldState client.PackageState, insFiles map[string]string) { } } -func installPkg(pkg string, ps *goolib.PkgSpec, dbOnly bool) (map[string]string, error) { +func buildConflictMap(db *googetdb.GooDB, currentPkg string) (map[string]string, error) { + conflictMap := make(map[string]string) + pkgs, err := db.FetchPkgs("") + if err != nil { + return nil, err + } + for _, p := range pkgs { + if p.PackageSpec == nil || p.PackageSpec.Name == currentPkg { + continue + } + for f, hash := range p.InstalledFiles { + if hash != "" { + conflictMap[f] = p.PackageSpec.Name + } + } + } + return conflictMap, nil +} + +func installPkg(pkg string, ps *goolib.PkgSpec, dbOnly, force bool, db *googetdb.GooDB) (map[string]string, error) { dir, err := download.ExtractPkg(pkg) if err != nil { return nil, err @@ -524,11 +552,16 @@ func installPkg(pkg string, ps *goolib.PkgSpec, dbOnly bool) (map[string]string, } }() + conflictMap, err := buildConflictMap(db, ps.Name) + if err != nil { + return nil, err + } + insFiles := make(map[string]string) for src, dst := range ps.Files { dst = resolveDst(dst) src = filepath.Join(dir, src) - if err := oswrap.Walk(src, makeInstallFunction(src, dst, insFiles, dbOnly)); err != nil { + if err := oswrap.Walk(src, makeInstallFunction(src, dst, insFiles, dbOnly, force, conflictMap)); err != nil { return nil, err } } diff --git a/install/install_test.go b/install/install_test.go index b9d6f67..b79d97d 100644 --- a/install/install_test.go +++ b/install/install_test.go @@ -142,6 +142,13 @@ func TestInstallPkg(t *testing.T) { } defer oswrap.RemoveAll(src) + settings.Initialize(t.TempDir(), false) + db, err := googetdb.NewDB(settings.DBFile()) + if err != nil { + t.Fatalf("googetdb.NewDB: %v", err) + } + defer db.Close() + dst, err := ioutil.TempDir("", "") if err != nil { t.Fatalf("Failed to create temp directory: %v", err) @@ -201,7 +208,7 @@ func TestInstallPkg(t *testing.T) { } ps := goolib.PkgSpec{Files: map[string]string{"./": dst}} - got, err := installPkg(f.Name(), &ps, false) + got, err := installPkg(f.Name(), &ps, false, false, db) if err != nil { t.Fatalf("Error running installPkg: %v", err) } @@ -437,7 +444,7 @@ func TestFromRepo_SatisfiedByProvider(t *testing.T) { // We pass empty repo map and downloader because we expect it NOT to try downloading deps // since they are satisfied. - err = installDeps(t.Context(), ps, "", nil, nil, false, nil, db) + err = installDeps(t.Context(), ps, "", nil, nil, false, false, nil, db) if err != nil { t.Errorf("installDeps failed: %v", err) } @@ -479,7 +486,7 @@ func TestFromRepo_SatisfiedByUninstalledProvider(t *testing.T) { // Verify that dependency resolution succeeds (finding provider_pkg); the download // is expected to fail due to an invalid repository URL. downloader, _ := client.NewDownloader("") - err = installDeps(t.Context(), ps, "", rm, []string{"noarch"}, false, downloader, db) + err = installDeps(t.Context(), ps, "", rm, []string{"noarch"}, false, false, downloader, db) // We expect an error because download will fail (invalid URL/Source). if err == nil { @@ -495,3 +502,87 @@ func TestFromRepo_SatisfiedByUninstalledProvider(t *testing.T) { t.Logf("Got expected error (confirming resolution success): %v", err) } } + +func TestBuildConflictMap(t *testing.T) { + settings.Initialize(t.TempDir(), false) + state := []client.PackageState{ + { + PackageSpec: &goolib.PkgSpec{Name: "pkgA", Version: "1.0.0@1", Arch: "noarch"}, + InstalledFiles: map[string]string{ + "/path/to/file1": "chksum1", + "/path/to/dir1": "", + }, + }, + { + PackageSpec: &goolib.PkgSpec{Name: "pkgB", Version: "1.0.0@1", Arch: "noarch"}, + InstalledFiles: map[string]string{ + "/path/to/file2": "chksum2", + }, + }, + } + db, err := googetdb.NewDB(settings.DBFile()) + if err != nil { + t.Fatalf("googetdb.NewDB: %v", err) + } + defer db.Close() + if err := db.WriteStateToDB(state); err != nil { + t.Fatalf("WriteStateToDB: %v", err) + } + + cm, err := buildConflictMap(db, "pkgB") + if err != nil { + t.Fatalf("buildConflictMap: %v", err) + } + + if _, ok := cm["/path/to/dir1"]; ok { + t.Errorf("buildConflictMap included directory /path/to/dir1") + } + if owner, ok := cm["/path/to/file1"]; !ok || owner != "pkgA" { + t.Errorf("expected /path/to/file1 to map to pkgA, got %v", owner) + } + if _, ok := cm["/path/to/file2"]; ok { + t.Errorf("buildConflictMap included file from excluded package pkgB") + } +} + +func TestMakeInstallFunction(t *testing.T) { + dstDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + defer oswrap.RemoveAll(dstDir) + + srcDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + srcDir = filepath.Join(srcDir, "foo") // append subdirectory to properly test TrimPrefix + oswrap.MkdirAll(srcDir, 0755) + defer oswrap.RemoveAll(srcDir) + + conflictPath := filepath.Join(dstDir, "conflicting_file") + cm := map[string]string{ + conflictPath: "pkgOwner", + } + + f, err := oswrap.Create(filepath.Join(srcDir, "conflicting_file")) + if err != nil { + t.Fatal(err) + } + fi, _ := f.Stat() + f.Close() + + // Test 1: Conflict without force -> Error + fnBlock := makeInstallFunction(srcDir, dstDir, make(map[string]string), false, false, cm) + errBlock := fnBlock(filepath.Join(srcDir, "conflicting_file"), fi, nil) + if errBlock == nil { + t.Errorf("expected conflict error, got nil") + } + + // Test 2: Conflict with force -> Success + fnForce := makeInstallFunction(srcDir, dstDir, make(map[string]string), false, true, cm) + errForce := fnForce(filepath.Join(srcDir, "conflicting_file"), fi, nil) + if errForce != nil { + t.Errorf("expected no error with force, got %v", errForce) + } +}