Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "WIP: use less memory by downloading to sparse file" #10

Merged
merged 1 commit into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,7 @@ jobs:
go-version-file: go.mod
- run: script/build
- uses: ncipollo/release-action@v1
if: github.ref_type=='tag' && !contains(github.ref_name, '-')
if: ${{ startsWith(github.ref, 'refs/tags') }}
with:
artifacts: "pget"
- uses: ncipollo/release-action@v1
if: github.ref_type=='tag' && contains(github.ref_name, '-')
with:
artifacts: "pget"
prerelease: true

75 changes: 19 additions & 56 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package main

import (
"archive/tar"
"bytes"
"flag"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -32,26 +34,17 @@ func getRemoteFileSize(url string) (int64, error) {
return fileSize, nil
}

func downloadFile(url string, destFile *os.File, concurrency int) error {
func downloadFileToBuffer(url string, concurrency int) (*bytes.Buffer, error) {
fileSize, err := getRemoteFileSize(url)
if err != nil {
return err
}

if err != nil {
fmt.Printf("Error creating file: %v\n", err)
os.Exit(1)
}

err = destFile.Truncate(fileSize)
if err != nil {
return err
return nil, err
}

chunkSize := fileSize / int64(concurrency)
var wg sync.WaitGroup
wg.Add(concurrency)

data := make([]byte, fileSize)
errc := make(chan error, concurrency)
startTime := time.Now()

Expand All @@ -65,11 +58,6 @@ func downloadFile(url string, destFile *os.File, concurrency int) error {

go func(start, end int64) {
defer wg.Done()
fh, err := os.OpenFile(destFile.Name(), os.O_RDWR, 0644)
if err != nil {
errc <- fmt.Errorf("Failed to reopen file: %v", err)
}
defer fh.Close()

retries := 5
for retries > 0 {
Expand Down Expand Up @@ -97,22 +85,14 @@ func downloadFile(url string, destFile *os.File, concurrency int) error {
}
defer resp.Body.Close()

_, err = fh.Seek(start, 0)
if err != nil {
fmt.Printf("Error seeking in file: %v\n", err)
retries--
time.Sleep(time.Millisecond * 100) // wait 100 milliseconds before retrying
continue
}

n, err := io.CopyN(fh, resp.Body, end-start+1)
n, err := io.ReadFull(resp.Body, data[start:end+1])
if err != nil && err != io.EOF {
fmt.Printf("Error reading response: %v\n", err)
retries--
time.Sleep(time.Millisecond * 100) // wait 100 milliseconds before retrying
continue
}
if n != end-start+1 {
if n != int(end-start+1) {
fmt.Printf("Downloaded %d bytes instead of %d\n", n, end-start+1)
retries--
time.Sleep(time.Millisecond * 100) // wait 100 milliseconds before retrying
Expand All @@ -132,19 +112,20 @@ func downloadFile(url string, destFile *os.File, concurrency int) error {
close(errc) // close the error channel
for err := range errc {
if err != nil {
return err // return the first error we encounter
return nil, err // return the first error we encounter
}
}
elapsed := time.Since(startTime).Seconds()
througput := humanize.Bytes(uint64(float64(fileSize) / elapsed))
fmt.Printf("Downloaded %s bytes in %.3fs (%s/s)\n", humanize.Bytes(uint64(fileSize)), elapsed, througput)

return nil
buffer := bytes.NewBuffer(data)
return buffer, nil
}

func extractTarFile(input io.Reader, destDir string) error {
func extractTarFile(buffer *bytes.Buffer, destDir string) error {
startTime := time.Now()
tarReader := tar.NewReader(input)
tarReader := tar.NewReader(buffer)

for {
header, err := tarReader.Next()
Expand Down Expand Up @@ -201,7 +182,7 @@ func main() {
// check required positional arguments
args := flag.Args()
if len(args) < 2 {
fmt.Println("Usage: pcurl [-c concurrency] [-x] <url> <dest>")
fmt.Println("Usage: pcurl <url> <dest> [-c concurrency] [-x]")
os.Exit(1)
}

Expand All @@ -214,44 +195,26 @@ func main() {
os.Exit(1)
}

// create tempfile for downloading to
cwd, err := os.Getwd()
if err != nil {
fmt.Printf("Error getting cwd: %v\n", err)
os.Exit(1)
}
destTemp, err := os.CreateTemp(cwd, dest+".partial")
if err != nil {
fmt.Printf("Failed to create temp file: %v\n", err)
os.Exit(1)
}

err = downloadFile(url, destTemp, *concurrency)
buffer, err := downloadFileToBuffer(url, *concurrency)
if err != nil {
fmt.Printf("Error downloading file: %v\n", err)
os.Exit(1)
}

// extract the tar file if the -x flag was provided
if *extract {
_, err = destTemp.Seek(0, 0)
err = extractTarFile(buffer, dest)
if err != nil {
fmt.Printf("Error extracting tar file: %v\n", err)
os.Exit(1)
}
err = extractTarFile(destTemp, dest)
if err != nil {
fmt.Printf("Error extracting tar file: %v\n", err)
os.Exit(1)
}
destTemp.Close()
os.Remove(destTemp.Name())
} else {
// move destTemp to dest
err = os.Rename(destTemp.Name(), dest)
// if -x flag is not set, save the buffer to a file
err = ioutil.WriteFile(dest, buffer.Bytes(), 0644)
if err != nil {
fmt.Printf("Error moving downloaded file to correct location: %v\n", err)
fmt.Printf("Error writing file: %v\n", err)
os.Exit(1)
}
}

}