Commit 77457e8c authored by Jacob Vosmaer

Push uploader control flow into objectstore package

parent 07bbad76
---
title: Push uploader control flow into objectstore package
merge_request: 608
author:
type: other
@@ -8,6 +8,7 @@ import (
     "io/ioutil"
     "os"
     "strconv"
+    "time"

     "github.com/dgrijalva/jwt-go"
@@ -98,38 +99,28 @@ func (fh *FileHandler) GitLabFinalizeFields(prefix string) (map[string]string, e
     return data, nil
 }

-// Upload represents a destination where we store an upload
-type uploadWriter interface {
-    io.WriteCloser
-    CloseWithError(error) error
-    ETag() string
+type consumer interface {
+    Consume(context.Context, io.Reader, time.Time) (int64, error)
 }
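For orientation, the whole change hinges on this new interface: instead of handing callers an io.WriteCloser plus an ETag accessor, each destination now pulls from a reader under a deadline and reports how many bytes it stored. A minimal standalone sketch of the contract, using a hypothetical discardConsumer that is not part of this change:

package main

import (
    "context"
    "fmt"
    "io"
    "io/ioutil"
    "strings"
    "time"
)

// consumer mirrors the interface introduced by this commit.
type consumer interface {
    Consume(context.Context, io.Reader, time.Time) (int64, error)
}

// discardConsumer is a hypothetical destination that discards its input.
type discardConsumer struct{}

func (discardConsumer) Consume(_ context.Context, r io.Reader, _ time.Time) (int64, error) {
    // A real destination would honor the context and deadline while uploading.
    return io.Copy(ioutil.Discard, r)
}

func main() {
    var dst consumer = discardConsumer{}
    n, err := dst.Consume(context.Background(), strings.NewReader("hello"), time.Now().Add(time.Second))
    fmt.Println(n, err) // 5 <nil>
}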
 // SaveFileFromReader persists the provided reader content to all the location specified in opts. A cleanup will be performed once ctx is Done
 // Make sure the provided context will not expire before finalizing upload with GitLab Rails.
 func SaveFileFromReader(ctx context.Context, reader io.Reader, size int64, opts *SaveFileOpts) (fh *FileHandler, err error) {
-    var uploadWriter uploadWriter
+    var uploadDestination consumer
     fh = &FileHandler{
         Name:      opts.TempFilePrefix,
         RemoteID:  opts.RemoteID,
         RemoteURL: opts.RemoteURL,
     }
     hashes := newMultiHash()
-    writers := []io.Writer{hashes.Writer}
-    defer func() {
-        for _, w := range writers {
-            if closer, ok := w.(io.WriteCloser); ok {
-                closer.Close()
-            }
-        }
-    }()
+    reader = io.TeeReader(reader, hashes.Writer)

     var clientMode string

     switch {
     case opts.IsLocal():
         clientMode = "local"
-        uploadWriter, err = fh.uploadLocalFile(ctx, opts)
+        uploadDestination, err = fh.uploadLocalFile(ctx, opts)
     case opts.UseWorkhorseClientEnabled() && opts.ObjectStorageConfig.IsGoCloud():
         clientMode = fmt.Sprintf("go_cloud:%s", opts.ObjectStorageConfig.Provider)
         p := &objectstore.GoCloudObjectParams{
@@ -137,38 +128,31 @@ func SaveFileFromReader(ctx context.Context, reader io.Reader, size int64, opts
             Mux:        opts.ObjectStorageConfig.URLMux,
             BucketURL:  opts.ObjectStorageConfig.GoCloudConfig.URL,
             ObjectName: opts.RemoteTempObjectID,
-            Deadline:   opts.Deadline,
         }
-        uploadWriter, err = objectstore.NewGoCloudObject(p)
+        uploadDestination, err = objectstore.NewGoCloudObject(p)
     case opts.UseWorkhorseClientEnabled() && opts.ObjectStorageConfig.IsAWS() && opts.ObjectStorageConfig.IsValid():
         clientMode = "s3"
-        uploadWriter, err = objectstore.NewS3Object(
-            ctx,
+        uploadDestination, err = objectstore.NewS3Object(
             opts.RemoteTempObjectID,
             opts.ObjectStorageConfig.S3Credentials,
             opts.ObjectStorageConfig.S3Config,
-            opts.Deadline,
         )
     case opts.IsMultipart():
         clientMode = "multipart"
-        uploadWriter, err = objectstore.NewMultipart(
-            ctx,
+        uploadDestination, err = objectstore.NewMultipart(
             opts.PresignedParts,
             opts.PresignedCompleteMultipart,
             opts.PresignedAbortMultipart,
             opts.PresignedDelete,
             opts.PutHeaders,
-            opts.Deadline,
             opts.PartSize,
         )
     default:
         clientMode = "http"
-        uploadWriter, err = objectstore.NewObject(
-            ctx,
+        uploadDestination, err = objectstore.NewObject(
             opts.PresignedPut,
             opts.PresignedDelete,
             opts.PutHeaders,
-            opts.Deadline,
             size,
         )
     }
@@ -177,34 +161,22 @@ func SaveFileFromReader(ctx context.Context, reader io.Reader, size int64, opts
         return nil, err
     }

-    writers = append(writers, uploadWriter)
-
-    defer func() {
-        if err != nil {
-            uploadWriter.CloseWithError(err)
-        }
-    }()
-
     if opts.MaximumSize > 0 {
         if size > opts.MaximumSize {
             return nil, SizeError(fmt.Errorf("the upload size %d is over maximum of %d bytes", size, opts.MaximumSize))
         }

-        // We allow to read an extra byte to check later if we exceed the max size
-        reader = &io.LimitedReader{R: reader, N: opts.MaximumSize + 1}
+        reader = &hardLimitReader{r: reader, n: opts.MaximumSize}
     }

-    multiWriter := io.MultiWriter(writers...)
-    fh.Size, err = io.Copy(multiWriter, reader)
+    fh.Size, err = uploadDestination.Consume(ctx, reader, opts.Deadline)
     if err != nil {
-        if err == objectstore.ErrNotEnoughParts {
-            err = ErrEntityTooLarge
-        }
         return nil, err
     }

-    if opts.MaximumSize > 0 && fh.Size > opts.MaximumSize {
-        // An extra byte was read thus exceeding the max size
-        return nil, ErrEntityTooLarge
-    }
-
     if size != -1 && size != fh.Size {
         return nil, SizeError(fmt.Errorf("expected %d bytes but got only %d", size, fh.Size))
     }
@@ -226,25 +198,11 @@ func SaveFileFromReader(ctx context.Context, reader io.Reader, size int64, opts
     }
     logger.Info("saved file")
     fh.hashes = hashes.finish()
-
-    // we need to close the writer in order to get ETag header
-    err = uploadWriter.Close()
-    if err != nil {
-        if err == objectstore.ErrNotEnoughParts {
-            return nil, ErrEntityTooLarge
-        }
-        return nil, err
-    }
-
-    etag := uploadWriter.ETag()
-    fh.hashes["etag"] = etag
-
-    return fh, err
+    return fh, nil
 }

-func (fh *FileHandler) uploadLocalFile(ctx context.Context, opts *SaveFileOpts) (uploadWriter, error) {
+func (fh *FileHandler) uploadLocalFile(ctx context.Context, opts *SaveFileOpts) (consumer, error) {
     // make sure TempFolder exists
     err := os.MkdirAll(opts.LocalTempPath, 0700)
     if err != nil {
@@ -262,13 +220,19 @@ func (fh *FileHandler) uploadLocalFile(ctx context.Context, opts *SaveFileOpts)
     }()

     fh.LocalPath = file.Name()
-    return &nopUpload{file}, nil
+    return &localUpload{file}, nil
 }

-type nopUpload struct{ io.WriteCloser }
+type localUpload struct{ io.WriteCloser }

-func (nop *nopUpload) CloseWithError(error) error { return nop.Close() }
-func (nop *nopUpload) ETag() string               { return "" }
+func (loc *localUpload) Consume(_ context.Context, r io.Reader, _ time.Time) (int64, error) {
+    n, err := io.Copy(loc.WriteCloser, r)
+    errClose := loc.Close()
+    if err == nil {
+        err = errClose
+    }
+    return n, err
+}
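Note the error precedence in localUpload.Consume: the io.Copy error wins, and the Close error is surfaced only when the copy itself succeeded, so the root cause is never masked. On Go 1.20+ the same intent could be written with errors.Join; a hedged alternative (assuming the localUpload type above and an "errors" import), not what this commit does:

// Hypothetical variant; requires Go 1.20+ and is not part of this change.
func (loc *localUpload) consumeJoined(r io.Reader) (int64, error) {
    n, copyErr := io.Copy(loc.WriteCloser, r)
    // errors.Join drops nil errors, so a clean copy and close still yields nil.
    return n, errors.Join(copyErr, loc.Close())
}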
 // SaveFileFromDisk open the local file fileName and calls SaveFileFromReader
 func SaveFileFromDisk(ctx context.Context, fileName string, opts *SaveFileOpts) (fh *FileHandler, err error) {
...
@@ -413,5 +413,4 @@ func checkFileHandlerWithFields(t *testing.T, fh *filestore.FileHandler, fields
     require.Equal(t, test.ObjectSHA1, fields[key("sha1")])
     require.Equal(t, test.ObjectSHA256, fields[key("sha256")])
     require.Equal(t, test.ObjectSHA512, fields[key("sha512")])
-    require.Contains(t, fields, key("etag"))
 }
package filestore

import "io"

type hardLimitReader struct {
    r io.Reader
    n int64
}

func (h *hardLimitReader) Read(p []byte) (int, error) {
    nRead, err := h.r.Read(p)
    h.n -= int64(nRead)
    if h.n < 0 {
        err = ErrEntityTooLarge
    }
    return nRead, err
}
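hardLimitReader replaces the old io.LimitedReader trick (read MaximumSize + 1 bytes, compare sizes afterwards): it turns the Read itself into the failure as soon as the byte budget goes negative, so an oversized upload aborts mid-stream instead of after the fact. A standalone sketch, re-declaring the type so it runs outside the package:

package main

import (
    "errors"
    "fmt"
    "io"
    "io/ioutil"
    "strings"
)

// Re-declared from the diff so this sketch runs standalone.
var ErrEntityTooLarge = errors.New("entity is too large")

type hardLimitReader struct {
    r io.Reader
    n int64
}

func (h *hardLimitReader) Read(p []byte) (int, error) {
    nRead, err := h.r.Read(p)
    h.n -= int64(nRead)
    if h.n < 0 {
        err = ErrEntityTooLarge
    }
    return nRead, err
}

func main() {
    // 11 bytes of input against a 5-byte hard limit: the Read itself fails,
    // unlike io.LimitedReader, which would silently truncate.
    _, err := ioutil.ReadAll(&hardLimitReader{r: strings.NewReader("hello world"), n: 5})
    fmt.Println(err) // entity is too large
}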
package filestore

import (
    "fmt"
    "io/ioutil"
    "strings"
    "testing"
    "testing/iotest"

    "github.com/stretchr/testify/require"
)

func TestHardLimitReader(t *testing.T) {
    const text = "hello world"
    r := iotest.OneByteReader(
        &hardLimitReader{
            r: strings.NewReader(text),
            n: int64(len(text)),
        },
    )

    out, err := ioutil.ReadAll(r)
    require.NoError(t, err)
    require.Equal(t, text, string(out))
}

func TestHardLimitReaderFail(t *testing.T) {
    const text = "hello world"

    for bufSize := len(text) / 2; bufSize < len(text)*2; bufSize++ {
        t.Run(fmt.Sprintf("bufsize:%d", bufSize), func(t *testing.T) {
            r := &hardLimitReader{
                r: iotest.DataErrReader(strings.NewReader(text)),
                n: int64(len(text)) - 1,
            }
            buf := make([]byte, bufSize)

            var err error
            for i := 0; err == nil && i < 1000; i++ {
                _, err = r.Read(buf)
            }

            require.Equal(t, ErrEntityTooLarge, err)
        })
    }
}
@@ -15,7 +15,7 @@ type GoCloudObject struct {
     mux        *blob.URLMux
     bucketURL  string
     objectName string
-    uploader
+    *uploader
 }

 type GoCloudObjectParams struct {
@@ -23,7 +23,6 @@ type GoCloudObjectParams struct {
     Mux        *blob.URLMux
     BucketURL  string
     ObjectName string
-    Deadline   time.Time
 }

 func NewGoCloudObject(p *GoCloudObjectParams) (*GoCloudObject, error) {
@@ -40,8 +39,6 @@ func NewGoCloudObject(p *GoCloudObjectParams) (*GoCloudObject, error) {
     }

     o.uploader = newUploader(o)
-    o.Execute(p.Ctx, p.Deadline)
-
     return o, nil
 }
...
@@ -3,7 +3,6 @@ package objectstore_test
 import (
     "context"
     "fmt"
-    "io"
     "strings"
     "testing"
     "time"
@@ -24,20 +23,16 @@ func TestGoCloudObjectUpload(t *testing.T) {
     objectName := "test.png"
     testURL := "azuretest://azure.example.com/test-container"
-    p := &objectstore.GoCloudObjectParams{Ctx: ctx, Mux: mux, BucketURL: testURL, ObjectName: objectName, Deadline: deadline}
+    p := &objectstore.GoCloudObjectParams{Ctx: ctx, Mux: mux, BucketURL: testURL, ObjectName: objectName}
     object, err := objectstore.NewGoCloudObject(p)
     require.NotNil(t, object)
     require.NoError(t, err)

     // copy data
-    n, err := io.Copy(object, strings.NewReader(test.ObjectContent))
+    n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
     require.NoError(t, err)
     require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch")

-    // close HTTP stream
-    err = object.Close()
-    require.NoError(t, err)
-
     bucket, err := mux.OpenBucket(ctx, testURL)
     require.NoError(t, err)
...
@@ -10,7 +10,6 @@ import (
     "io/ioutil"
     "net/http"
     "os"
-    "time"

     "gitlab.com/gitlab-org/labkit/log"
     "gitlab.com/gitlab-org/labkit/mask"
@@ -33,13 +32,13 @@ type Multipart struct {
     partSize    int64
     etag        string

-    uploader
+    *uploader
 }

 // NewMultipart provides Multipart pointer that can be used for uploading. Data written will be split buffered on disk up to size bytes
 // then uploaded with S3 Upload Part. Once Multipart is Closed a final call to CompleteMultipartUpload will be sent.
 // In case of any error a call to AbortMultipartUpload will be made to cleanup all the resources
-func NewMultipart(ctx context.Context, partURLs []string, completeURL, abortURL, deleteURL string, putHeaders map[string]string, deadline time.Time, partSize int64) (*Multipart, error) {
+func NewMultipart(partURLs []string, completeURL, abortURL, deleteURL string, putHeaders map[string]string, partSize int64) (*Multipart, error) {
     m := &Multipart{
         PartURLs:    partURLs,
         CompleteURL: completeURL,
@@ -50,8 +49,6 @@ func NewMultipart(ctx context.Context, partURLs []string, completeURL, abortURL,
     }

     m.uploader = newUploader(m)
-    m.Execute(ctx, deadline)
-
     return m, nil
 }

@@ -109,7 +106,7 @@ func (m *Multipart) readAndUploadOnePart(ctx context.Context, partURL string, pu
     n, err := io.Copy(file, src)
     if err != nil {
-        return nil, fmt.Errorf("write part %d to disk: %v", partNumber, err)
+        return nil, err
     }
     if n == 0 {
         return nil, nil
@@ -132,18 +129,15 @@ func (m *Multipart) uploadPart(ctx context.Context, url string, headers map[stri
         return "", fmt.Errorf("missing deadline")
     }

-    part, err := newObject(ctx, url, "", headers, deadline, size, false)
+    part, err := newObject(url, "", headers, size, false)
     if err != nil {
         return "", err
     }

-    _, err = io.CopyN(part, body, size)
-    if err != nil {
-        return "", err
-    }
-
-    err = part.Close()
-    if err != nil {
+    if n, err := part.Consume(ctx, io.LimitReader(body, size), deadline); err != nil || n < size {
+        if err == nil {
+            err = io.ErrUnexpectedEOF
+        }
         return "", err
     }
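uploadPart previously relied on io.CopyN failing when the body ran short; with Consume the guard is explicit: wrap the body in io.LimitReader and, if fewer than size bytes arrive without an error, report io.ErrUnexpectedEOF so the multipart upload is aborted rather than completed with a truncated part. The same check in isolation (copyExactly is hypothetical):

package main

import (
    "fmt"
    "io"
    "io/ioutil"
    "strings"
)

// copyExactly mirrors the guard in uploadPart: exactly size bytes must flow.
func copyExactly(body io.Reader, size int64) error {
    n, err := io.Copy(ioutil.Discard, io.LimitReader(body, size))
    if err == nil && n < size {
        // The body ended early; a short part would corrupt the upload.
        err = io.ErrUnexpectedEOF
    }
    return err
}

func main() {
    fmt.Println(copyExactly(strings.NewReader("abc"), 5))    // unexpected EOF
    fmt.Println(copyExactly(strings.NewReader("abcdef"), 5)) // <nil>
}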
...
@@ -48,19 +48,17 @@ func TestMultipartUploadWithUpcaseETags(t *testing.T) {
     deadline := time.Now().Add(testTimeout)

-    m, err := objectstore.NewMultipart(ctx,
+    m, err := objectstore.NewMultipart(
         []string{ts.URL},    // a single presigned part URL
         ts.URL,              // the complete multipart upload URL
         "",                  // no abort
         "",                  // no delete
         map[string]string{}, // no custom headers
-        deadline,
         test.ObjectSize) // parts size equal to the whole content. Only 1 part
     require.NoError(t, err)

-    _, err = m.Write([]byte(test.ObjectContent))
+    _, err = m.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
     require.NoError(t, err)
-    require.NoError(t, m.Close())

     require.Equal(t, 1, putCnt, "1 part expected")
     require.Equal(t, 1, postCnt, "1 complete multipart upload expected")
 }
@@ -47,17 +47,17 @@ type Object struct {
     etag    string
     metrics bool

-    uploader
+    *uploader
 }

 type StatusCodeError error

 // NewObject opens an HTTP connection to Object Store and returns an Object pointer that can be used for uploading.
-func NewObject(ctx context.Context, putURL, deleteURL string, putHeaders map[string]string, deadline time.Time, size int64) (*Object, error) {
-    return newObject(ctx, putURL, deleteURL, putHeaders, deadline, size, true)
+func NewObject(putURL, deleteURL string, putHeaders map[string]string, size int64) (*Object, error) {
+    return newObject(putURL, deleteURL, putHeaders, size, true)
 }

-func newObject(ctx context.Context, putURL, deleteURL string, putHeaders map[string]string, deadline time.Time, size int64, metrics bool) (*Object, error) {
+func newObject(putURL, deleteURL string, putHeaders map[string]string, size int64, metrics bool) (*Object, error) {
     o := &Object{
         putURL:    putURL,
         deleteURL: deleteURL,
@@ -66,9 +66,7 @@ func newObject(ctx context.Context, putURL, deleteURL string, putHeaders map[str
         metrics:   metrics,
     }

-    o.uploader = newMD5Uploader(o, metrics)
-    o.Execute(ctx, deadline)
-
+    o.uploader = newETagCheckUploader(o, metrics)
     return o, nil
 }
...
@@ -35,18 +35,14 @@ func testObjectUploadNoErrors(t *testing.T, startObjectStore osFactory, useDelet
     defer cancel()

     deadline := time.Now().Add(testTimeout)
-    object, err := objectstore.NewObject(ctx, objectURL, deleteURL, putHeaders, deadline, test.ObjectSize)
+    object, err := objectstore.NewObject(objectURL, deleteURL, putHeaders, test.ObjectSize)
     require.NoError(t, err)

     // copy data
-    n, err := io.Copy(object, strings.NewReader(test.ObjectContent))
+    n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
     require.NoError(t, err)
     require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch")

-    // close HTTP stream
-    err = object.Close()
-    require.NoError(t, err)
-
     require.Equal(t, contentType, osStub.GetHeader(test.ObjectPath, "Content-Type"))

     // Checking MD5 extraction
@@ -107,12 +103,10 @@ func TestObjectUpload404(t *testing.T) {
     deadline := time.Now().Add(testTimeout)
     objectURL := ts.URL + test.ObjectPath
-    object, err := objectstore.NewObject(ctx, objectURL, "", map[string]string{}, deadline, test.ObjectSize)
+    object, err := objectstore.NewObject(objectURL, "", map[string]string{}, test.ObjectSize)
     require.NoError(t, err)
-    _, err = io.Copy(object, strings.NewReader(test.ObjectContent))
-    require.NoError(t, err)
-    err = object.Close()
+    _, err = object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
     require.Error(t, err)

     _, isStatusCodeError := err.(objectstore.StatusCodeError)
     require.True(t, isStatusCodeError, "Should fail with StatusCodeError")
@@ -152,13 +146,10 @@ func TestObjectUploadBrokenConnection(t *testing.T) {
     deadline := time.Now().Add(testTimeout)
     objectURL := ts.URL + test.ObjectPath
-    object, err := objectstore.NewObject(ctx, objectURL, "", map[string]string{}, deadline, -1)
+    object, err := objectstore.NewObject(objectURL, "", map[string]string{}, -1)
     require.NoError(t, err)

-    _, copyErr := io.Copy(object, &endlessReader{})
+    _, copyErr := object.Consume(ctx, &endlessReader{}, deadline)
     require.Error(t, copyErr)
     require.NotEqual(t, io.ErrClosedPipe, copyErr, "We are shadowing the real error")
-
-    closeErr := object.Close()
-    require.Equal(t, copyErr, closeErr)
 }
@@ -19,10 +19,10 @@ type S3Object struct {
     objectName string
     uploaded   bool

-    uploader
+    *uploader
 }

-func NewS3Object(ctx context.Context, objectName string, s3Credentials config.S3Credentials, s3Config config.S3Config, deadline time.Time) (*S3Object, error) {
+func NewS3Object(objectName string, s3Credentials config.S3Credentials, s3Config config.S3Config) (*S3Object, error) {
     o := &S3Object{
         credentials: s3Credentials,
         config:      s3Config,
@@ -30,8 +30,6 @@ func NewS3Object(ctx context.Context, objectName string, s3Credentials config.S3
     }

     o.uploader = newUploader(o)
-    o.Execute(ctx, deadline)
-
     return o, nil
 }
...
@@ -3,7 +3,6 @@ package objectstore_test
 import (
     "context"
     "fmt"
-    "io"
     "io/ioutil"
     "os"
     "path/filepath"
@@ -44,18 +43,14 @@ func TestS3ObjectUpload(t *testing.T) {
     objectName := filepath.Join(tmpDir, "s3-test-data")
     ctx, cancel := context.WithCancel(context.Background())

-    object, err := objectstore.NewS3Object(ctx, objectName, creds, config, deadline)
+    object, err := objectstore.NewS3Object(objectName, creds, config)
     require.NoError(t, err)

     // copy data
-    n, err := io.Copy(object, strings.NewReader(test.ObjectContent))
+    n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
     require.NoError(t, err)
     require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch")

-    // close HTTP stream
-    err = object.Close()
-    require.NoError(t, err)
-
     test.S3ObjectExists(t, sess, config, objectName, test.ObjectContent)
     test.CheckS3Metadata(t, sess, config, objectName)
@@ -107,17 +102,14 @@ func TestConcurrentS3ObjectUpload(t *testing.T) {
     ctx, cancel := context.WithCancel(context.Background())
     defer cancel()

-    object, err := objectstore.NewS3Object(ctx, objectName, creds, config, deadline)
+    object, err := objectstore.NewS3Object(objectName, creds, config)
     require.NoError(t, err)

     // copy data
-    n, err := io.Copy(object, strings.NewReader(test.ObjectContent))
+    n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
     require.NoError(t, err)
     require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch")

-    // close HTTP stream
-    require.NoError(t, object.Close())
-
     test.S3ObjectExists(t, sess, config, objectName, test.ObjectContent)
     wg.Done()
 }(i)
@@ -139,7 +131,7 @@ func TestS3ObjectUploadCancel(t *testing.T) {
     objectName := filepath.Join(tmpDir, "s3-test-data")

-    object, err := objectstore.NewS3Object(ctx, objectName, creds, config, deadline)
+    object, err := objectstore.NewS3Object(objectName, creds, config)
     require.NoError(t, err)
@@ -147,6 +139,6 @@ func TestS3ObjectUploadCancel(t *testing.T) {
     // we handle this gracefully.
     cancel()

-    _, err = io.Copy(object, strings.NewReader(test.ObjectContent))
+    _, err = object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
     require.Error(t, err)
 }
@@ -8,177 +8,89 @@ import (
     "hash"
     "io"
     "strings"
-    "sync"
     "time"

     "gitlab.com/gitlab-org/labkit/log"
 )

-// uploader is an io.WriteCloser that can be used as write end of the uploading pipe.
+// uploader consumes an io.Reader and uploads it using a pluggable uploadStrategy.
 type uploader struct {
-    // etag is the object storage provided checksum
-    etag string
-
-    // md5 is an optional hasher for calculating md5 on the fly
-    md5 hash.Hash
-
-    w io.Writer
-
-    // uploadError is the last error occourred during upload
-    uploadError error
-    // ctx is the internal context bound to the upload request
-    ctx context.Context
-
-    pr       *io.PipeReader
-    pw       *io.PipeWriter
     strategy uploadStrategy
-    metrics  bool

-    // closeOnce is used to prevent multiple calls to pw.Close
-    // which may result to Close overriding the error set by CloseWithError
-    // Bug fixed in v1.14: https://github.com/golang/go/commit/f45eb9ff3c96dfd951c65d112d033ed7b5e02432
-    closeOnce sync.Once
+    // In the case of S3 uploads, we have a multipart upload which
+    // instantiates uploads for the individual parts. We don't want to
+    // increment metrics for the individual parts, so that is why we have
+    // this boolean flag.
+    metrics bool
+
+    // With S3 we compare the MD5 of the data we sent with the ETag returned
+    // by the object storage server.
+    checkETag bool
 }

-func newUploader(strategy uploadStrategy) uploader {
-    pr, pw := io.Pipe()
-    return uploader{w: pw, pr: pr, pw: pw, strategy: strategy, metrics: true}
+func newUploader(strategy uploadStrategy) *uploader {
+    return &uploader{strategy: strategy, metrics: true}
 }

-func newMD5Uploader(strategy uploadStrategy, metrics bool) uploader {
-    pr, pw := io.Pipe()
-    hasher := md5.New()
-    mw := io.MultiWriter(pw, hasher)
-    return uploader{w: mw, pr: pr, pw: pw, md5: hasher, strategy: strategy, metrics: metrics}
+func newETagCheckUploader(strategy uploadStrategy, metrics bool) *uploader {
+    return &uploader{strategy: strategy, metrics: metrics, checkETag: true}
 }

-// Close implements the standard io.Closer interface: it closes the http client request.
-// This method will also wait for the connection to terminate and return any error occurred during the upload
-func (u *uploader) Close() error {
-    var closeError error
-    u.closeOnce.Do(func() {
-        closeError = u.pw.Close()
-    })
-    if closeError != nil {
-        return closeError
-    }
-
-    <-u.ctx.Done()
-
-    if err := u.ctx.Err(); err == context.DeadlineExceeded {
-        return err
-    }
-
-    return u.uploadError
-}
-
-func (u *uploader) CloseWithError(err error) error {
-    u.closeOnce.Do(func() {
-        u.pw.CloseWithError(err)
-    })
-
-    return nil
-}
-
-func (u *uploader) Write(p []byte) (int, error) {
-    return u.w.Write(p)
-}
-
-func (u *uploader) md5Sum() string {
-    if u.md5 == nil {
-        return ""
-    }
-
-    checksum := u.md5.Sum(nil)
-    return hex.EncodeToString(checksum)
-}
-
-// ETag returns the checksum of the uploaded object returned by the ObjectStorage provider via ETag Header.
-// This method will wait until upload context is done before returning.
-func (u *uploader) ETag() string {
-    <-u.ctx.Done()
-
-    return u.etag
-}
+func hexString(h hash.Hash) string { return hex.EncodeToString(h.Sum(nil)) }

-func (u *uploader) Execute(ctx context.Context, deadline time.Time) {
+// Consume reads the reader until it reaches EOF or an error. It spawns a
+// goroutine that waits for outerCtx to be done, after which the remote
+// file is deleted. The deadline applies to the upload performed inside
+// Consume, not to outerCtx.
+func (u *uploader) Consume(outerCtx context.Context, reader io.Reader, deadline time.Time) (_ int64, err error) {
     if u.metrics {
         objectStorageUploadsOpen.Inc()
+        defer func(started time.Time) {
+            objectStorageUploadsOpen.Dec()
+            objectStorageUploadTime.Observe(time.Since(started).Seconds())
+            if err != nil {
+                objectStorageUploadRequestsRequestFailed.Inc()
+            }
+        }(time.Now())
     }
-    uploadCtx, cancelFn := context.WithDeadline(ctx, deadline)
-    u.ctx = uploadCtx
-
-    if u.metrics {
-        go u.trackUploadTime()
-    }
-    uploadDone := make(chan struct{})
-    go u.cleanup(ctx, uploadDone)
-    go func() {
-        defer cancelFn()
-        defer close(uploadDone)
-        if u.metrics {
-            defer objectStorageUploadsOpen.Dec()
-        }
-        defer func() {
-            // This will be returned as error to the next write operation on the pipe
-            u.pr.CloseWithError(u.uploadError)
-        }()

-        err := u.strategy.Upload(uploadCtx, u.pr)
+    defer func() {
+        // We do this mainly to abort S3 multipart uploads: it is not enough to
+        // "delete" them.
         if err != nil {
-            u.uploadError = err
-            if u.metrics {
-                objectStorageUploadRequestsRequestFailed.Inc()
-            }
-            return
+            u.strategy.Abort()
         }
+    }()

-        u.etag = u.strategy.ETag()
+    go func() {
+        // Once gitlab-rails is done handling the request, we are supposed to
+        // delete the upload from its temporary location.
+        <-outerCtx.Done()
+        u.strategy.Delete()
+    }()

-        if u.md5 != nil {
-            err := compareMD5(u.md5Sum(), u.etag)
-            if err != nil {
-                log.ContextLogger(ctx).WithError(err).Error("error comparing MD5 checksum")
+    uploadCtx, cancelFn := context.WithDeadline(outerCtx, deadline)
+    defer cancelFn()

-                u.uploadError = err
-                if u.metrics {
-                    objectStorageUploadRequestsRequestFailed.Inc()
-                }
-            }
-        }
-    }()
-}
+    var hasher hash.Hash
+    if u.checkETag {
+        hasher = md5.New()
+        reader = io.TeeReader(reader, hasher)
+    }

-func (u *uploader) trackUploadTime() {
-    started := time.Now()
-    <-u.ctx.Done()
+    cr := &countReader{r: reader}
+    if err := u.strategy.Upload(uploadCtx, cr); err != nil {
+        return cr.n, err
+    }

-    if u.metrics {
-        objectStorageUploadTime.Observe(time.Since(started).Seconds())
+    if u.checkETag {
+        if err := compareMD5(hexString(hasher), u.strategy.ETag()); err != nil {
+            log.ContextLogger(uploadCtx).WithError(err).Error("error comparing MD5 checksum")
+            return cr.n, err
+        }
     }
-}

-func (u *uploader) cleanup(ctx context.Context, uploadDone chan struct{}) {
-    // wait for the upload to finish
-    <-u.ctx.Done()
-
-    <-uploadDone
-    if u.uploadError != nil {
-        if u.metrics {
-            objectStorageUploadRequestsRequestFailed.Inc()
-        }
-        u.strategy.Abort()
-        return
-    }
-
-    // We have now successfully uploaded the file to object storage. Another
-    // goroutine will hand off the object to gitlab-rails.
-    <-ctx.Done()
-
-    // gitlab-rails is now done with the object so it's time to delete it.
-    u.strategy.Delete()
+    return cr.n, nil
 }

 func compareMD5(local, remote string) error {
@@ -188,3 +100,14 @@ func compareMD5(local, remote string) error {

     return nil
 }
+
+type countReader struct {
+    r io.Reader
+    n int64
+}
+
+func (cr *countReader) Read(p []byte) (int, error) {
+    nRead, err := cr.r.Read(p)
+    cr.n += int64(nRead)
+    return nRead, err
+}
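The uploadStrategy interface itself is not shown in this hunk; its shape below is inferred from the calls Consume makes (Upload, ETag, Abort, Delete) and may differ from the real declaration. A stub strategy makes the new lifecycle visible: Upload runs synchronously inside Consume under the deadline-bound context, Abort fires on error, and Delete fires once the outer context is done:

package main

import (
    "context"
    "fmt"
    "io"
    "io/ioutil"
    "strings"
)

// Inferred interface; the real one lives in the objectstore package.
type uploadStrategy interface {
    Upload(ctx context.Context, r io.Reader) error
    ETag() string
    Abort()
    Delete()
}

// logStrategy is a hypothetical strategy that records what happens to it.
type logStrategy struct{ events []string }

func (s *logStrategy) Upload(_ context.Context, r io.Reader) error {
    n, err := io.Copy(ioutil.Discard, r)
    s.events = append(s.events, fmt.Sprintf("uploaded %d bytes", n))
    return err
}
func (s *logStrategy) ETag() string { return "" }
func (s *logStrategy) Abort()       { s.events = append(s.events, "aborted") }
func (s *logStrategy) Delete()      { s.events = append(s.events, "deleted") }

func main() {
    s := &logStrategy{}
    // Success path: upload, no abort; Delete would follow outerCtx.Done().
    _ = s.Upload(context.Background(), strings.NewReader("payload"))
    s.Delete()
    fmt.Println(s.events) // [uploaded 7 bytes deleted]
}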
@@ -123,7 +123,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
     require.Equal(t, hash, r.FormValue("file."+algo), "file hash %s", algo)
 }

-    require.Len(t, r.MultipartForm.Value, 12, "multipart form values")
+    require.Len(t, r.MultipartForm.Value, 11, "multipart form values")

     w.WriteHeader(202)
     fmt.Fprint(w, "RESPONSE")
...
@@ -79,7 +79,7 @@ func uploadTestServer(t *testing.T, extraTests func(r *http.Request)) *httptest.
     require.NoError(t, r.ParseMultipartForm(100000))

-    const nValues = 11 // file name, path, remote_url, remote_id, size, md5, sha1, sha256, sha512, gitlab-workhorse-upload, etag for just the upload (no metadata because we are not POSTing a valid zip file)
+    const nValues = 10 // file name, path, remote_url, remote_id, size, md5, sha1, sha256, sha512, gitlab-workhorse-upload for just the upload (no metadata because we are not POSTing a valid zip file)
     require.Len(t, r.MultipartForm.Value, nValues)
     require.Empty(t, r.MultipartForm.File, "multipart form files")
...