diff --git a/core/app/service/script_library.go b/core/app/service/script_library.go index 1ae24607f950..78b6bb308ebb 100644 --- a/core/app/service/script_library.go +++ b/core/app/service/script_library.go @@ -233,7 +233,7 @@ func (u *ScriptService) Sync(req dto.OperateByTaskID) error { _ = os.MkdirAll(tmpDir, 0755) } scriptsUrl := fmt.Sprintf("%s/scripts/scripts.tar.gz", global.CONF.RemoteURL.ResourceURL) - err = files.DownloadFileWithProxy(scriptsUrl, tmpDir+"/scripts.tar.gz") + err = files.DownloadFileWithProxyStream(scriptsUrl, tmpDir+"/scripts.tar.gz") syncTask.LogWithStatus(i18n.GetMsgByKey("DownloadPackage"), err) if err != nil { return fmt.Errorf("download scripts.tar.gz failed, err: %v", err) diff --git a/core/app/service/upgrade.go b/core/app/service/upgrade.go index f625adaf6fa4..446268f9a18d 100644 --- a/core/app/service/upgrade.go +++ b/core/app/service/upgrade.go @@ -167,7 +167,7 @@ func (u *UpgradeService) Upgrade(req dto.Upgrade) error { _ = settingRepo.Update("SystemStatus", "Upgrading") go func() { oldLang := common.LoadParams("LANGUAGE") - if err := files.DownloadFileWithProxy(downloadPath+"/"+fileName, downloadDir+"/"+fileName); err != nil { + if err := files.DownloadFileWithProxyStream(downloadPath+"/"+fileName, downloadDir+"/"+fileName); err != nil { global.LOG.Errorf("download service file failed, err: %v", err) _ = settingRepo.Update("SystemStatus", "Free") return diff --git a/core/utils/files/files.go b/core/utils/files/files.go index eb3058c611c2..4f9363f89f08 100644 --- a/core/utils/files/files.go +++ b/core/utils/files/files.go @@ -1,7 +1,6 @@ package files import ( - "bytes" "crypto/md5" "encoding/hex" "errors" @@ -189,22 +188,44 @@ func DownloadFile(url, dst string) error { return nil } -func DownloadFileWithProxy(url, dst string) error { - _, resp, err := req_helper.HandleRequestWithProxy(url, http.MethodGet, constant.TimeOut5m) +func DownloadFileWithProxyStream(url, dst string) error { + resp, err := req_helper.HandleGetWithProxy(url) if err != nil { return err } + defer resp.Body.Close() + if resp.StatusCode >= http.StatusBadRequest { + return fmt.Errorf("download file [%s] failed, status code: %d", url, resp.StatusCode) + } - out, err := os.Create(dst) + tmpDst := dst + ".part" + _ = os.Remove(tmpDst) + out, err := os.Create(tmpDst) if err != nil { return fmt.Errorf("create download file [%s] error, err %s", dst, err.Error()) } - defer out.Close() + success := false + defer func() { + _ = out.Close() + if !success { + _ = os.Remove(tmpDst) + } + }() - reader := bytes.NewReader(resp) - if _, err = io.Copy(out, reader); err != nil { + n, err := io.Copy(out, resp.Body) + if err != nil { return fmt.Errorf("save download file [%s] error, err %s", dst, err.Error()) } + if resp.ContentLength > 0 && n != resp.ContentLength { + return fmt.Errorf("save download file [%s] error, content-length mismatch: expected %d, actual %d", dst, resp.ContentLength, n) + } + if err = out.Sync(); err != nil { + return fmt.Errorf("sync download file [%s] error, err %s", dst, err.Error()) + } + if err = os.Rename(tmpDst, dst); err != nil { + return fmt.Errorf("rename download file [%s] error, err %s", dst, err.Error()) + } + success = true return nil }