Commit ba2e8f65 authored by Bryan C. Mills's avatar Bryan C. Mills

cmd/go/internal/modfetch: make Repo.Zip write to an io.Writer instead of a temporary file

This will be used to eliminate a redundant copy in CL 145178.

(It also decouples two design points that were previously coupled: the
destination of the zip output and the program logic to write that
output.)

Updates #26794

Change-Id: I6cfd5a33c162c0016a1b83a278003684560a3772
Reviewed-on: https://go-review.googlesource.com/c/151341
Run-TryBot: Bryan C. Mills <bcmills@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarRuss Cox <rsc@golang.org>
parent c124a919
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
...@@ -215,8 +216,8 @@ func (r *cachingRepo) GoMod(rev string) ([]byte, error) { ...@@ -215,8 +216,8 @@ func (r *cachingRepo) GoMod(rev string) ([]byte, error) {
return append([]byte(nil), c.text...), nil return append([]byte(nil), c.text...), nil
} }
func (r *cachingRepo) Zip(version, tmpdir string) (string, error) { func (r *cachingRepo) Zip(dst io.Writer, version string) error {
return r.r.Zip(version, tmpdir) return r.r.Zip(dst, version)
} }
// Stat is like Lookup(path).Stat(rev) but avoids the // Stat is like Lookup(path).Stat(rev) but avoids the
......
...@@ -407,25 +407,26 @@ func (r *codeRepo) modPrefix(rev string) string { ...@@ -407,25 +407,26 @@ func (r *codeRepo) modPrefix(rev string) string {
return r.modPath + "@" + rev return r.modPath + "@" + rev
} }
func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error) { func (r *codeRepo) Zip(dst io.Writer, version string) error {
rev, dir, _, err := r.findDir(version) rev, dir, _, err := r.findDir(version)
if err != nil { if err != nil {
return "", err return err
} }
dl, actualDir, err := r.code.ReadZip(rev, dir, codehost.MaxZipFile) dl, actualDir, err := r.code.ReadZip(rev, dir, codehost.MaxZipFile)
if err != nil { if err != nil {
return "", err return err
} }
defer dl.Close()
if actualDir != "" && !hasPathPrefix(dir, actualDir) { if actualDir != "" && !hasPathPrefix(dir, actualDir) {
return "", fmt.Errorf("internal error: downloading %v %v: dir=%q but actualDir=%q", r.path, rev, dir, actualDir) return fmt.Errorf("internal error: downloading %v %v: dir=%q but actualDir=%q", r.path, rev, dir, actualDir)
} }
subdir := strings.Trim(strings.TrimPrefix(dir, actualDir), "/") subdir := strings.Trim(strings.TrimPrefix(dir, actualDir), "/")
// Spool to local file. // Spool to local file.
f, err := ioutil.TempFile(tmpdir, "go-codehost-") f, err := ioutil.TempFile("", "go-codehost-")
if err != nil { if err != nil {
dl.Close() dl.Close()
return "", err return err
} }
defer os.Remove(f.Name()) defer os.Remove(f.Name())
defer f.Close() defer f.Close()
...@@ -433,35 +434,24 @@ func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error ...@@ -433,35 +434,24 @@ func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error
lr := &io.LimitedReader{R: dl, N: maxSize + 1} lr := &io.LimitedReader{R: dl, N: maxSize + 1}
if _, err := io.Copy(f, lr); err != nil { if _, err := io.Copy(f, lr); err != nil {
dl.Close() dl.Close()
return "", err return err
} }
dl.Close() dl.Close()
if lr.N <= 0 { if lr.N <= 0 {
return "", fmt.Errorf("downloaded zip file too large") return fmt.Errorf("downloaded zip file too large")
} }
size := (maxSize + 1) - lr.N size := (maxSize + 1) - lr.N
if _, err := f.Seek(0, 0); err != nil { if _, err := f.Seek(0, 0); err != nil {
return "", err return err
} }
// Translate from zip file we have to zip file we want. // Translate from zip file we have to zip file we want.
zr, err := zip.NewReader(f, size) zr, err := zip.NewReader(f, size)
if err != nil { if err != nil {
return "", err return err
}
f2, err := ioutil.TempFile(tmpdir, "go-codezip-")
if err != nil {
return "", err
} }
zw := zip.NewWriter(f2) zw := zip.NewWriter(dst)
newName := f2.Name()
defer func() {
f2.Close()
if err != nil {
os.Remove(newName)
}
}()
if subdir != "" { if subdir != "" {
subdir += "/" subdir += "/"
} }
...@@ -472,12 +462,12 @@ func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error ...@@ -472,12 +462,12 @@ func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error
if topPrefix == "" { if topPrefix == "" {
i := strings.Index(zf.Name, "/") i := strings.Index(zf.Name, "/")
if i < 0 { if i < 0 {
return "", fmt.Errorf("missing top-level directory prefix") return fmt.Errorf("missing top-level directory prefix")
} }
topPrefix = zf.Name[:i+1] topPrefix = zf.Name[:i+1]
} }
if !strings.HasPrefix(zf.Name, topPrefix) { if !strings.HasPrefix(zf.Name, topPrefix) {
return "", fmt.Errorf("zip file contains more than one top-level directory") return fmt.Errorf("zip file contains more than one top-level directory")
} }
dir, file := path.Split(zf.Name) dir, file := path.Split(zf.Name)
if file == "go.mod" { if file == "go.mod" {
...@@ -497,11 +487,12 @@ func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error ...@@ -497,11 +487,12 @@ func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error
name = dir[:len(dir)-1] name = dir[:len(dir)-1]
} }
} }
for _, zf := range zr.File { for _, zf := range zr.File {
if topPrefix == "" { if topPrefix == "" {
i := strings.Index(zf.Name, "/") i := strings.Index(zf.Name, "/")
if i < 0 { if i < 0 {
return "", fmt.Errorf("missing top-level directory prefix") return fmt.Errorf("missing top-level directory prefix")
} }
topPrefix = zf.Name[:i+1] topPrefix = zf.Name[:i+1]
} }
...@@ -509,7 +500,7 @@ func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error ...@@ -509,7 +500,7 @@ func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error
continue continue
} }
if !strings.HasPrefix(zf.Name, topPrefix) { if !strings.HasPrefix(zf.Name, topPrefix) {
return "", fmt.Errorf("zip file contains more than one top-level directory") return fmt.Errorf("zip file contains more than one top-level directory")
} }
name := strings.TrimPrefix(zf.Name, topPrefix) name := strings.TrimPrefix(zf.Name, topPrefix)
if !strings.HasPrefix(name, subdir) { if !strings.HasPrefix(name, subdir) {
...@@ -529,28 +520,28 @@ func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error ...@@ -529,28 +520,28 @@ func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error
} }
base := path.Base(name) base := path.Base(name)
if strings.ToLower(base) == "go.mod" && base != "go.mod" { if strings.ToLower(base) == "go.mod" && base != "go.mod" {
return "", fmt.Errorf("zip file contains %s, want all lower-case go.mod", zf.Name) return fmt.Errorf("zip file contains %s, want all lower-case go.mod", zf.Name)
} }
if name == "LICENSE" { if name == "LICENSE" {
haveLICENSE = true haveLICENSE = true
} }
size := int64(zf.UncompressedSize) size := int64(zf.UncompressedSize64)
if size < 0 || maxSize < size { if size < 0 || maxSize < size {
return "", fmt.Errorf("module source tree too big") return fmt.Errorf("module source tree too big")
} }
maxSize -= size maxSize -= size
rc, err := zf.Open() rc, err := zf.Open()
if err != nil { if err != nil {
return "", err return err
} }
w, err := zw.Create(r.modPrefix(version) + "/" + name) w, err := zw.Create(r.modPrefix(version) + "/" + name)
lr := &io.LimitedReader{R: rc, N: size + 1} lr := &io.LimitedReader{R: rc, N: size + 1}
if _, err := io.Copy(w, lr); err != nil { if _, err := io.Copy(w, lr); err != nil {
return "", err return err
} }
if lr.N <= 0 { if lr.N <= 0 {
return "", fmt.Errorf("individual file too large") return fmt.Errorf("individual file too large")
} }
} }
...@@ -559,21 +550,15 @@ func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error ...@@ -559,21 +550,15 @@ func (r *codeRepo) Zip(version string, tmpdir string) (tmpfile string, err error
if err == nil { if err == nil {
w, err := zw.Create(r.modPrefix(version) + "/LICENSE") w, err := zw.Create(r.modPrefix(version) + "/LICENSE")
if err != nil { if err != nil {
return "", err return err
} }
if _, err := w.Write(data); err != nil { if _, err := w.Write(data); err != nil {
return "", err return err
} }
} }
} }
if err := zw.Close(); err != nil {
return "", err
}
if err := f2.Close(); err != nil {
return "", err
}
return f2.Name(), nil return zw.Close()
} }
// hasPathPrefix reports whether the path s begins with the // hasPathPrefix reports whether the path s begins with the
......
...@@ -391,7 +391,13 @@ func TestCodeRepo(t *testing.T) { ...@@ -391,7 +391,13 @@ func TestCodeRepo(t *testing.T) {
} }
} }
if tt.zip != nil || tt.ziperr != "" { if tt.zip != nil || tt.ziperr != "" {
zipfile, err := repo.Zip(tt.version, tmpdir) f, err := ioutil.TempFile(tmpdir, tt.version+".zip.")
if err != nil {
t.Fatalf("ioutil.TempFile: %v", err)
}
zipfile := f.Name()
err = repo.Zip(f, tt.version)
f.Close()
if err != nil { if err != nil {
if tt.ziperr != "" { if tt.ziperr != "" {
if err.Error() == tt.ziperr { if err.Error() == tt.ziperr {
......
...@@ -108,41 +108,47 @@ func downloadZip(mod module.Version, target string) error { ...@@ -108,41 +108,47 @@ func downloadZip(mod module.Version, target string) error {
if err != nil { if err != nil {
return err return err
} }
tmpfile, err := repo.Zip(mod.Version, os.TempDir()) tmpfile, err := ioutil.TempFile("", "go-codezip-")
if err != nil { if err != nil {
return err return err
} }
defer os.Remove(tmpfile) defer func() {
tmpfile.Close()
os.Remove(tmpfile.Name())
}()
if err := repo.Zip(tmpfile, mod.Version); err != nil {
return err
}
// Double-check zip file looks OK. // Double-check zip file looks OK.
z, err := zip.OpenReader(tmpfile) fi, err := tmpfile.Stat()
if err != nil {
return err
}
z, err := zip.NewReader(tmpfile, fi.Size())
if err != nil { if err != nil {
return err return err
} }
prefix := mod.Path + "@" + mod.Version + "/" prefix := mod.Path + "@" + mod.Version + "/"
for _, f := range z.File { for _, f := range z.File {
if !strings.HasPrefix(f.Name, prefix) { if !strings.HasPrefix(f.Name, prefix) {
z.Close()
return fmt.Errorf("zip for %s has unexpected file %s", prefix[:len(prefix)-1], f.Name) return fmt.Errorf("zip for %s has unexpected file %s", prefix[:len(prefix)-1], f.Name)
} }
} }
z.Close()
hash, err := dirhash.HashZip(tmpfile, dirhash.DefaultHash) hash, err := dirhash.HashZip(tmpfile.Name(), dirhash.DefaultHash)
if err != nil { if err != nil {
return err return err
} }
checkOneSum(mod, hash) // check before installing the zip file checkOneSum(mod, hash) // check before installing the zip file
r, err := os.Open(tmpfile) if _, err := tmpfile.Seek(0, io.SeekStart); err != nil {
if err != nil {
return err return err
} }
defer r.Close()
w, err := os.Create(target) w, err := os.Create(target)
if err != nil { if err != nil {
return err return err
} }
if _, err := io.Copy(w, r); err != nil { if _, err := io.Copy(w, tmpfile); err != nil {
w.Close() w.Close()
return fmt.Errorf("copying: %v", err) return fmt.Errorf("copying: %v", err)
} }
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/url" "net/url"
"os" "os"
"strings" "strings"
...@@ -209,39 +208,26 @@ func (p *proxyRepo) GoMod(version string) ([]byte, error) { ...@@ -209,39 +208,26 @@ func (p *proxyRepo) GoMod(version string) ([]byte, error) {
return data, nil return data, nil
} }
func (p *proxyRepo) Zip(version string, tmpdir string) (tmpfile string, err error) { func (p *proxyRepo) Zip(dst io.Writer, version string) error {
var body io.ReadCloser var body io.ReadCloser
encVer, err := module.EncodeVersion(version) encVer, err := module.EncodeVersion(version)
if err != nil { if err != nil {
return "", err return err
} }
err = webGetBody(p.url+"/@v/"+pathEscape(encVer)+".zip", &body) err = webGetBody(p.url+"/@v/"+pathEscape(encVer)+".zip", &body)
if err != nil { if err != nil {
return "", err return err
} }
defer body.Close() defer body.Close()
// Spool to local file. lr := &io.LimitedReader{R: body, N: codehost.MaxZipFile + 1}
f, err := ioutil.TempFile(tmpdir, "go-proxy-download-") if _, err := io.Copy(dst, lr); err != nil {
if err != nil { return err
return "", err
}
defer f.Close()
maxSize := int64(codehost.MaxZipFile)
lr := &io.LimitedReader{R: body, N: maxSize + 1}
if _, err := io.Copy(f, lr); err != nil {
os.Remove(f.Name())
return "", err
} }
if lr.N <= 0 { if lr.N <= 0 {
os.Remove(f.Name()) return fmt.Errorf("downloaded zip file too large")
return "", fmt.Errorf("downloaded zip file too large")
}
if err := f.Close(); err != nil {
os.Remove(f.Name())
return "", err
} }
return f.Name(), nil return nil
} }
// pathEscape escapes s so it can be used in a path. // pathEscape escapes s so it can be used in a path.
......
...@@ -6,8 +6,10 @@ package modfetch ...@@ -6,8 +6,10 @@ package modfetch
import ( import (
"fmt" "fmt"
"io"
"os" "os"
"sort" "sort"
"strconv"
"time" "time"
"cmd/go/internal/cfg" "cmd/go/internal/cfg"
...@@ -45,11 +47,8 @@ type Repo interface { ...@@ -45,11 +47,8 @@ type Repo interface {
// GoMod returns the go.mod file for the given version. // GoMod returns the go.mod file for the given version.
GoMod(version string) (data []byte, err error) GoMod(version string) (data []byte, err error)
// Zip downloads a zip file for the given version // Zip writes a zip file for the given version to dst.
// to a new file in a given temporary directory. Zip(dst io.Writer, version string) error
// It returns the name of the new file.
// The caller should remove the file when finished with it.
Zip(version, tmpdir string) (tmpfile string, err error)
} }
// A Rev describes a single revision in a module repository. // A Rev describes a single revision in a module repository.
...@@ -357,7 +356,11 @@ func (l *loggingRepo) GoMod(version string) ([]byte, error) { ...@@ -357,7 +356,11 @@ func (l *loggingRepo) GoMod(version string) ([]byte, error) {
return l.r.GoMod(version) return l.r.GoMod(version)
} }
func (l *loggingRepo) Zip(version, tmpdir string) (string, error) { func (l *loggingRepo) Zip(dst io.Writer, version string) error {
defer logCall("Repo[%s]: Zip(%q, %q)", l.r.ModulePath(), version, tmpdir)() dstName := "_"
return l.r.Zip(version, tmpdir) if dst, ok := dst.(interface{ Name() string }); ok {
dstName = strconv.Quote(dst.Name())
}
defer logCall("Repo[%s]: Zip(%s, %q)", l.r.ModulePath(), dstName, version)()
return l.r.Zip(dst, version)
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment