Commit 4c91b2db authored by Jacob Vosmaer's avatar Jacob Vosmaer

Use io.Copy in gitaly smarthttp

parent a3be5472
...@@ -8,23 +8,12 @@ import ( ...@@ -8,23 +8,12 @@ import (
pbhelper "gitlab.com/gitlab-org/gitaly-proto/go/helper" pbhelper "gitlab.com/gitlab-org/gitaly-proto/go/helper"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc"
) )
type SmartHTTPClient struct { type SmartHTTPClient struct {
pb.SmartHTTPClient pb.SmartHTTPClient
} }
type uploadPackWriter struct {
pb.SmartHTTP_PostUploadPackClient
}
type receivePackWriter struct {
pb.SmartHTTP_PostReceivePackClient
}
const sendChunkSize = 16384
func (client *SmartHTTPClient) InfoRefsResponseWriterTo(ctx context.Context, repo *pb.Repository, rpc string) (io.WriterTo, error) { func (client *SmartHTTPClient) InfoRefsResponseWriterTo(ctx context.Context, repo *pb.Repository, rpc string) (io.WriterTo, error) {
rpcRequest := &pb.InfoRefsRequest{Repository: repo} rpcRequest := &pb.InfoRefsRequest{Repository: repo}
var c pbhelper.InfoRefsClient var c pbhelper.InfoRefsClient
...@@ -64,21 +53,31 @@ func (client *SmartHTTPClient) ReceivePack(repo *pb.Repository, GlId string, cli ...@@ -64,21 +53,31 @@ func (client *SmartHTTPClient) ReceivePack(repo *pb.Repository, GlId string, cli
return fmt.Errorf("initial request: %v", err) return fmt.Errorf("initial request: %v", err)
} }
waitc := make(chan error, 1) numStreams := 2
errC := make(chan error, numStreams)
go receiveGitalyResponse(stream, waitc, clientResponse, func() ([]byte, error) { go func() {
response, err := stream.Recv() rr := pbhelper.NewReceiveReader(func() ([]byte, error) {
return response.GetData(), err response, err := stream.Recv()
}) return response.GetData(), err
})
_, err := io.Copy(clientResponse, rr)
errC <- err
}()
_, sendErr := io.Copy(receivePackWriter{stream}, clientRequest) go func() {
stream.CloseSend() sw := pbhelper.NewSendWriter(func(data []byte) error {
return stream.Send(&pb.PostReceivePackRequest{Data: data})
})
_, err := io.Copy(sw, clientRequest)
stream.CloseSend()
errC <- err
}()
if recvErr := <-waitc; recvErr != nil { for i := 0; i < numStreams; i++ {
return recvErr if err := <-errC; err != nil {
} return err
if sendErr != nil { }
return fmt.Errorf("send: %v", sendErr)
} }
return nil return nil
...@@ -101,60 +100,32 @@ func (client *SmartHTTPClient) UploadPack(repo *pb.Repository, clientRequest io. ...@@ -101,60 +100,32 @@ func (client *SmartHTTPClient) UploadPack(repo *pb.Repository, clientRequest io.
return fmt.Errorf("initial request: %v", err) return fmt.Errorf("initial request: %v", err)
} }
waitc := make(chan error, 1) numStreams := 2
errC := make(chan error, numStreams)
go receiveGitalyResponse(stream, waitc, clientResponse, func() ([]byte, error) {
response, err := stream.Recv()
return response.GetData(), err
})
_, sendErr := io.Copy(uploadPackWriter{stream}, clientRequest) go func() {
stream.CloseSend() rr := pbhelper.NewReceiveReader(func() ([]byte, error) {
response, err := stream.Recv()
if recvErr := <-waitc; recvErr != nil { return response.GetData(), err
return recvErr })
} _, err := io.Copy(clientResponse, rr)
if sendErr != nil { errC <- err
return fmt.Errorf("send: %v", sendErr)
}
return nil
}
func receiveGitalyResponse(cs grpc.ClientStream, waitc chan error, clientResponse io.Writer, receiver func() ([]byte, error)) {
defer func() {
close(waitc)
cs.CloseSend()
}() }()
for { go func() {
data, err := receiver() sw := pbhelper.NewSendWriter(func(data []byte) error {
if err != nil { return stream.Send(&pb.PostUploadPackRequest{Data: data})
if err != io.EOF { })
waitc <- fmt.Errorf("receive: %v", err) _, err := io.Copy(sw, clientRequest)
} stream.CloseSend()
return errC <- err
} }()
if _, err := clientResponse.Write(data); err != nil { for i := 0; i < numStreams; i++ {
waitc <- fmt.Errorf("write: %v", err) if err := <-errC; err != nil {
return return err
} }
} }
}
func (rw uploadPackWriter) Write(p []byte) (int, error) {
resp := &pb.PostUploadPackRequest{Data: p}
if err := rw.Send(resp); err != nil {
return 0, err
}
return len(p), nil
}
func (rw receivePackWriter) Write(p []byte) (int, error) { return nil
resp := &pb.PostReceivePackRequest{Data: p}
if err := rw.Send(resp); err != nil {
return 0, err
}
return len(p), nil
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment