Commit b8fa6188 authored by Petar Maymounkov's avatar Petar Maymounkov Committed by Russ Cox

http: introduce Header type, implement with net/textproto

textproto: introduce Header type
websocket: use new interface to access Header

R=rsc, mattn
CC=golang-dev
https://golang.org/cl/4185053
parent 07cc8b9a
...@@ -10,6 +10,7 @@ GOFILES=\ ...@@ -10,6 +10,7 @@ GOFILES=\
client.go\ client.go\
dump.go\ dump.go\
fs.go\ fs.go\
header.go\
lex.go\ lex.go\
persist.go\ persist.go\
request.go\ request.go\
......
...@@ -85,9 +85,9 @@ func send(req *Request) (resp *Response, err os.Error) { ...@@ -85,9 +85,9 @@ func send(req *Request) (resp *Response, err os.Error) {
encoded := make([]byte, enc.EncodedLen(len(info))) encoded := make([]byte, enc.EncodedLen(len(info)))
enc.Encode(encoded, []byte(info)) enc.Encode(encoded, []byte(info))
if req.Header == nil { if req.Header == nil {
req.Header = make(map[string]string) req.Header = make(Header)
} }
req.Header["Authorization"] = "Basic " + string(encoded) req.Header.Set("Authorization", "Basic "+string(encoded))
} }
var proxyURL *URL var proxyURL *URL
...@@ -130,7 +130,7 @@ func send(req *Request) (resp *Response, err os.Error) { ...@@ -130,7 +130,7 @@ func send(req *Request) (resp *Response, err os.Error) {
if req.URL.Scheme == "http" { if req.URL.Scheme == "http" {
// Include proxy http header if needed. // Include proxy http header if needed.
if proxyAuth != "" { if proxyAuth != "" {
req.Header["Proxy-Authorization"] = proxyAuth req.Header.Set("Proxy-Authorization", proxyAuth)
} }
} else { // https } else { // https
if proxyURL != nil { if proxyURL != nil {
...@@ -241,7 +241,7 @@ func Get(url string) (r *Response, finalURL string, err os.Error) { ...@@ -241,7 +241,7 @@ func Get(url string) (r *Response, finalURL string, err os.Error) {
} }
if shouldRedirect(r.StatusCode) { if shouldRedirect(r.StatusCode) {
r.Body.Close() r.Body.Close()
if url = r.GetHeader("Location"); url == "" { if url = r.Header.Get("Location"); url == "" {
err = os.ErrorString(fmt.Sprintf("%d response missing Location header", r.StatusCode)) err = os.ErrorString(fmt.Sprintf("%d response missing Location header", r.StatusCode))
break break
} }
...@@ -266,8 +266,8 @@ func Post(url string, bodyType string, body io.Reader) (r *Response, err os.Erro ...@@ -266,8 +266,8 @@ func Post(url string, bodyType string, body io.Reader) (r *Response, err os.Erro
req.ProtoMinor = 1 req.ProtoMinor = 1
req.Close = true req.Close = true
req.Body = nopCloser{body} req.Body = nopCloser{body}
req.Header = map[string]string{ req.Header = Header{
"Content-Type": bodyType, "Content-Type": {bodyType},
} }
req.TransferEncoding = []string{"chunked"} req.TransferEncoding = []string{"chunked"}
...@@ -291,9 +291,9 @@ func PostForm(url string, data map[string]string) (r *Response, err os.Error) { ...@@ -291,9 +291,9 @@ func PostForm(url string, data map[string]string) (r *Response, err os.Error) {
req.Close = true req.Close = true
body := urlencode(data) body := urlencode(data)
req.Body = nopCloser{body} req.Body = nopCloser{body}
req.Header = map[string]string{ req.Header = Header{
"Content-Type": "application/x-www-form-urlencoded", "Content-Type": {"application/x-www-form-urlencoded"},
"Content-Length": strconv.Itoa(body.Len()), "Content-Length": {strconv.Itoa(body.Len())},
} }
req.ContentLength = int64(body.Len()) req.ContentLength = int64(body.Len())
......
...@@ -104,7 +104,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { ...@@ -104,7 +104,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) {
} }
} }
if t, _ := time.Parse(TimeFormat, r.Header["If-Modified-Since"]); t != nil && d.Mtime_ns/1e9 <= t.Seconds() { if t, _ := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); t != nil && d.Mtime_ns/1e9 <= t.Seconds() {
w.WriteHeader(StatusNotModified) w.WriteHeader(StatusNotModified)
return return
} }
...@@ -153,7 +153,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { ...@@ -153,7 +153,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) {
// handle Content-Range header. // handle Content-Range header.
// TODO(adg): handle multiple ranges // TODO(adg): handle multiple ranges
ranges, err := parseRange(r.Header["Range"], size) ranges, err := parseRange(r.Header.Get("Range"), size)
if err != nil || len(ranges) > 1 { if err != nil || len(ranges) > 1 {
Error(w, err.String(), StatusRequestedRangeNotSatisfiable) Error(w, err.String(), StatusRequestedRangeNotSatisfiable)
return return
......
...@@ -109,7 +109,7 @@ func TestServeFile(t *testing.T) { ...@@ -109,7 +109,7 @@ func TestServeFile(t *testing.T) {
// set up the Request (re-used for all tests) // set up the Request (re-used for all tests)
var req Request var req Request
req.Header = make(map[string]string) req.Header = make(Header)
if req.URL, err = ParseURL("http://" + serverAddr + "/ServeFile"); err != nil { if req.URL, err = ParseURL("http://" + serverAddr + "/ServeFile"); err != nil {
t.Fatal("ParseURL:", err) t.Fatal("ParseURL:", err)
} }
...@@ -123,9 +123,9 @@ func TestServeFile(t *testing.T) { ...@@ -123,9 +123,9 @@ func TestServeFile(t *testing.T) {
// Range tests // Range tests
for _, rt := range ServeFileRangeTests { for _, rt := range ServeFileRangeTests {
req.Header["Range"] = "bytes=" + rt.r req.Header.Set("Range", "bytes="+rt.r)
if rt.r == "" { if rt.r == "" {
req.Header["Range"] = "" req.Header["Range"] = nil
} }
r, body := getBody(t, req) r, body := getBody(t, req)
if r.StatusCode != rt.code { if r.StatusCode != rt.code {
...@@ -138,8 +138,9 @@ func TestServeFile(t *testing.T) { ...@@ -138,8 +138,9 @@ func TestServeFile(t *testing.T) {
if rt.r == "" { if rt.r == "" {
h = "" h = ""
} }
if r.Header["Content-Range"] != h { cr := r.Header.Get("Content-Range")
t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, r.Header["Content-Range"], h) if cr != h {
t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, cr, h)
} }
if !equal(body, file[rt.start:rt.end]) { if !equal(body, file[rt.start:rt.end]) {
t.Errorf("body mismatch: range=%q: got %q, want %q", rt.r, body, file[rt.start:rt.end]) t.Errorf("body mismatch: range=%q: got %q, want %q", rt.r, body, file[rt.start:rt.end])
......
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
import "net/textproto"
// A Header represents the key-value pairs in an HTTP header.
type Header map[string][]string
// Add adds the key, value pair to the header.
// It appends to any existing values associated with key.
func (h Header) Add(key, value string) {
textproto.MIMEHeader(h).Add(key, value)
}
// Set sets the header entries associated with key to
// the single element value. It replaces any existing
// values associated with key.
func (h Header) Set(key, value string) {
textproto.MIMEHeader(h).Set(key, value)
}
// Get gets the first value associated with the given key.
// If there are no values associated with the key, Get returns "".
// Get is a convenience method. For more complex queries,
// access the map directly.
func (h Header) Get(key string) string {
return textproto.MIMEHeader(h).Get(key)
}
// Del deletes the values associated with key.
func (h Header) Del(key string) {
textproto.MIMEHeader(h).Del(key)
}
// CanonicalHeaderKey returns the canonical format of the
// header key s. The canonicalization converts the first
// letter and any letter following a hyphen to upper case;
// the rest are converted to lowercase. For example, the
// canonical key for "accept-encoding" is "Accept-Encoding".
func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) }
...@@ -50,14 +50,14 @@ var reqTests = []reqTest{ ...@@ -50,14 +50,14 @@ var reqTests = []reqTest{
Proto: "HTTP/1.1", Proto: "HTTP/1.1",
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: map[string]string{ Header: Header{
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
"Accept-Language": "en-us,en;q=0.5", "Accept-Language": {"en-us,en;q=0.5"},
"Accept-Encoding": "gzip,deflate", "Accept-Encoding": {"gzip,deflate"},
"Accept-Charset": "ISO-8859-1,utf-8;q=0.7,*;q=0.7", "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"},
"Keep-Alive": "300", "Keep-Alive": {"300"},
"Proxy-Connection": "keep-alive", "Proxy-Connection": {"keep-alive"},
"Content-Length": "7", "Content-Length": {"7"},
}, },
Close: false, Close: false,
ContentLength: 7, ContentLength: 7,
...@@ -93,7 +93,7 @@ var reqTests = []reqTest{ ...@@ -93,7 +93,7 @@ var reqTests = []reqTest{
Proto: "HTTP/1.1", Proto: "HTTP/1.1",
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: map[string]string{}, Header: map[string][]string{},
Close: false, Close: false,
ContentLength: -1, ContentLength: -1,
Host: "test", Host: "test",
......
...@@ -11,13 +11,13 @@ package http ...@@ -11,13 +11,13 @@ package http
import ( import (
"bufio" "bufio"
"bytes"
"container/vector" "container/vector"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"mime" "mime"
"mime/multipart" "mime/multipart"
"net/textproto"
"os" "os"
"strconv" "strconv"
"strings" "strings"
...@@ -90,7 +90,7 @@ type Request struct { ...@@ -90,7 +90,7 @@ type Request struct {
// The request parser implements this by canonicalizing the // The request parser implements this by canonicalizing the
// name, making the first character and any characters // name, making the first character and any characters
// following a hyphen uppercase and the rest lowercase. // following a hyphen uppercase and the rest lowercase.
Header map[string]string Header Header
// The message body. // The message body.
Body io.ReadCloser Body io.ReadCloser
...@@ -133,7 +133,7 @@ type Request struct { ...@@ -133,7 +133,7 @@ type Request struct {
// Trailer maps trailer keys to values. Like for Header, if the // Trailer maps trailer keys to values. Like for Header, if the
// response has multiple trailer lines with the same key, they will be // response has multiple trailer lines with the same key, they will be
// concatenated, delimited by commas. // concatenated, delimited by commas.
Trailer map[string]string Trailer Header
} }
// ProtoAtLeast returns whether the HTTP protocol used // ProtoAtLeast returns whether the HTTP protocol used
...@@ -146,8 +146,8 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { ...@@ -146,8 +146,8 @@ func (r *Request) ProtoAtLeast(major, minor int) bool {
// MultipartReader returns a MIME multipart reader if this is a // MultipartReader returns a MIME multipart reader if this is a
// multipart/form-data POST request, else returns nil and an error. // multipart/form-data POST request, else returns nil and an error.
func (r *Request) MultipartReader() (multipart.Reader, os.Error) { func (r *Request) MultipartReader() (multipart.Reader, os.Error) {
v, ok := r.Header["Content-Type"] v := r.Header.Get("Content-Type")
if !ok { if v == "" {
return nil, ErrNotMultipart return nil, ErrNotMultipart
} }
d, params := mime.ParseMediaType(v) d, params := mime.ParseMediaType(v)
...@@ -297,78 +297,6 @@ func readLine(b *bufio.Reader) (s string, err os.Error) { ...@@ -297,78 +297,6 @@ func readLine(b *bufio.Reader) (s string, err os.Error) {
return string(p), nil return string(p), nil
} }
var colon = []byte{':'}
// Read a key/value pair from b.
// A key/value has the form Key: Value\r\n
// and the Value can continue on multiple lines if each continuation line
// starts with a space.
func readKeyValue(b *bufio.Reader) (key, value string, err os.Error) {
line, e := readLineBytes(b)
if e != nil {
return "", "", e
}
if len(line) == 0 {
return "", "", nil
}
// Scan first line for colon.
i := bytes.Index(line, colon)
if i < 0 {
goto Malformed
}
key = string(line[0:i])
if strings.Contains(key, " ") {
// Key field has space - no good.
goto Malformed
}
// Skip initial space before value.
for i++; i < len(line); i++ {
if line[i] != ' ' {
break
}
}
value = string(line[i:])
// Look for extension lines, which must begin with space.
for {
c, e := b.ReadByte()
if c != ' ' {
if e != os.EOF {
b.UnreadByte()
}
break
}
// Eat leading space.
for c == ' ' {
if c, e = b.ReadByte(); e != nil {
if e == os.EOF {
e = io.ErrUnexpectedEOF
}
return "", "", e
}
}
b.UnreadByte()
// Read the rest of the line and add to value.
if line, e = readLineBytes(b); e != nil {
return "", "", e
}
value += " " + string(line)
if len(value) >= maxValueLength {
return "", "", &badStringError{"value too long for key", key}
}
}
return key, value, nil
Malformed:
return "", "", &badStringError{"malformed header line", string(line)}
}
// Convert decimal at s[i:len(s)] to integer, // Convert decimal at s[i:len(s)] to integer,
// returning value, string position where the digits stopped, // returning value, string position where the digits stopped,
// and whether there was a valid number (digits, not too big). // and whether there was a valid number (digits, not too big).
...@@ -404,43 +332,6 @@ func parseHTTPVersion(vers string) (int, int, bool) { ...@@ -404,43 +332,6 @@ func parseHTTPVersion(vers string) (int, int, bool) {
return major, minor, true return major, minor, true
} }
// CanonicalHeaderKey returns the canonical format of the
// HTTP header key s. The canonicalization converts the first
// letter and any letter following a hyphen to upper case;
// the rest are converted to lowercase. For example, the
// canonical key for "accept-encoding" is "Accept-Encoding".
func CanonicalHeaderKey(s string) string {
// canonicalize: first letter upper case
// and upper case after each dash.
// (Host, User-Agent, If-Modified-Since).
// HTTP headers are ASCII only, so no Unicode issues.
var a []byte
upper := true
for i := 0; i < len(s); i++ {
v := s[i]
if upper && 'a' <= v && v <= 'z' {
if a == nil {
a = []byte(s)
}
a[i] = v + 'A' - 'a'
}
if !upper && 'A' <= v && v <= 'Z' {
if a == nil {
a = []byte(s)
}
a[i] = v + 'a' - 'A'
}
upper = false
if v == '-' {
upper = true
}
}
if a != nil {
return string(a)
}
return s
}
type chunkedReader struct { type chunkedReader struct {
r *bufio.Reader r *bufio.Reader
n uint64 // unread bytes in chunk n uint64 // unread bytes in chunk
...@@ -506,11 +397,16 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err os.Error) { ...@@ -506,11 +397,16 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err os.Error) {
// ReadRequest reads and parses a request from b. // ReadRequest reads and parses a request from b.
func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) {
tp := textproto.NewReader(b)
req = new(Request) req = new(Request)
// First line: GET /index.html HTTP/1.0 // First line: GET /index.html HTTP/1.0
var s string var s string
if s, err = readLine(b); err != nil { if s, err = tp.ReadLine(); err != nil {
if err == os.EOF {
err = io.ErrUnexpectedEOF
}
return nil, err return nil, err
} }
...@@ -529,32 +425,11 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { ...@@ -529,32 +425,11 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) {
} }
// Subsequent lines: Key: value. // Subsequent lines: Key: value.
nheader := 0 mimeHeader, err := tp.ReadMIMEHeader()
req.Header = make(map[string]string) if err != nil {
for {
var key, value string
if key, value, err = readKeyValue(b); err != nil {
return nil, err return nil, err
} }
if key == "" { req.Header = Header(mimeHeader)
break
}
if nheader++; nheader >= maxHeaderLines {
return nil, ErrHeaderTooLong
}
key = CanonicalHeaderKey(key)
// RFC 2616 says that if you send the same header key
// multiple times, it has to be semantically equivalent
// to concatenating the values separated by commas.
oldvalue, present := req.Header[key]
if present {
req.Header[key] = oldvalue + "," + value
} else {
req.Header[key] = value
}
}
// RFC2616: Must treat // RFC2616: Must treat
// GET /index.html HTTP/1.1 // GET /index.html HTTP/1.1
...@@ -565,18 +440,18 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { ...@@ -565,18 +440,18 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) {
// the same. In the second case, any Host line is ignored. // the same. In the second case, any Host line is ignored.
req.Host = req.URL.Host req.Host = req.URL.Host
if req.Host == "" { if req.Host == "" {
req.Host = req.Header["Host"] req.Host = req.Header.Get("Host")
} }
req.Header["Host"] = "", false req.Header.Del("Host")
fixPragmaCacheControl(req.Header) fixPragmaCacheControl(req.Header)
// Pull out useful fields as a convenience to clients. // Pull out useful fields as a convenience to clients.
req.Referer = req.Header["Referer"] req.Referer = req.Header.Get("Referer")
req.Header["Referer"] = "", false req.Header.Del("Referer")
req.UserAgent = req.Header["User-Agent"] req.UserAgent = req.Header.Get("User-Agent")
req.Header["User-Agent"] = "", false req.Header.Del("User-Agent")
// TODO: Parse specific header values: // TODO: Parse specific header values:
// Accept // Accept
...@@ -662,7 +537,7 @@ func (r *Request) ParseForm() (err os.Error) { ...@@ -662,7 +537,7 @@ func (r *Request) ParseForm() (err os.Error) {
if r.Body == nil { if r.Body == nil {
return os.ErrorString("missing form body") return os.ErrorString("missing form body")
} }
ct := r.Header["Content-Type"] ct := r.Header.Get("Content-Type")
switch strings.Split(ct, ";", 2)[0] { switch strings.Split(ct, ";", 2)[0] {
case "text/plain", "application/x-www-form-urlencoded", "": case "text/plain", "application/x-www-form-urlencoded", "":
b, e := ioutil.ReadAll(r.Body) b, e := ioutil.ReadAll(r.Body)
...@@ -697,17 +572,12 @@ func (r *Request) FormValue(key string) string { ...@@ -697,17 +572,12 @@ func (r *Request) FormValue(key string) string {
} }
func (r *Request) expectsContinue() bool { func (r *Request) expectsContinue() bool {
expectation, ok := r.Header["Expect"] return strings.ToLower(r.Header.Get("Expect")) == "100-continue"
return ok && strings.ToLower(expectation) == "100-continue"
} }
func (r *Request) wantsHttp10KeepAlive() bool { func (r *Request) wantsHttp10KeepAlive() bool {
if r.ProtoMajor != 1 || r.ProtoMinor != 0 { if r.ProtoMajor != 1 || r.ProtoMinor != 0 {
return false return false
} }
value, exists := r.Header["Connection"] return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "keep-alive")
if !exists {
return false
}
return strings.Contains(strings.ToLower(value), "keep-alive")
} }
...@@ -74,7 +74,9 @@ func TestQuery(t *testing.T) { ...@@ -74,7 +74,9 @@ func TestQuery(t *testing.T) {
func TestPostQuery(t *testing.T) { func TestPostQuery(t *testing.T) {
req := &Request{Method: "POST"} req := &Request{Method: "POST"}
req.URL, _ = ParseURL("http://www.google.com/search?q=foo&q=bar&both=x") req.URL, _ = ParseURL("http://www.google.com/search?q=foo&q=bar&both=x")
req.Header = map[string]string{"Content-Type": "application/x-www-form-urlencoded; boo!"} req.Header = Header{
"Content-Type": {"application/x-www-form-urlencoded; boo!"},
}
req.Body = nopCloser{strings.NewReader("z=post&both=y")} req.Body = nopCloser{strings.NewReader("z=post&both=y")}
if q := req.FormValue("q"); q != "foo" { if q := req.FormValue("q"); q != "foo" {
t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) t.Errorf(`req.FormValue("q") = %q, want "foo"`, q)
...@@ -87,18 +89,18 @@ func TestPostQuery(t *testing.T) { ...@@ -87,18 +89,18 @@ func TestPostQuery(t *testing.T) {
} }
} }
type stringMap map[string]string type stringMap map[string][]string
type parseContentTypeTest struct { type parseContentTypeTest struct {
contentType stringMap contentType stringMap
error bool error bool
} }
var parseContentTypeTests = []parseContentTypeTest{ var parseContentTypeTests = []parseContentTypeTest{
{contentType: stringMap{"Content-Type": "text/plain"}}, {contentType: stringMap{"Content-Type": {"text/plain"}}},
{contentType: stringMap{"Content-Type": ""}}, {contentType: stringMap{}}, // Non-existent keys are not placed. The value nil is illegal.
{contentType: stringMap{"Content-Type": "text/plain; boundary="}}, {contentType: stringMap{"Content-Type": {"text/plain; boundary="}}},
{ {
contentType: stringMap{"Content-Type": "application/unknown"}, contentType: stringMap{"Content-Type": {"application/unknown"}},
error: true, error: true,
}, },
} }
...@@ -107,7 +109,7 @@ func TestPostContentTypeParsing(t *testing.T) { ...@@ -107,7 +109,7 @@ func TestPostContentTypeParsing(t *testing.T) {
for i, test := range parseContentTypeTests { for i, test := range parseContentTypeTests {
req := &Request{ req := &Request{
Method: "POST", Method: "POST",
Header: test.contentType, Header: Header(test.contentType),
Body: nopCloser{bytes.NewBufferString("body")}, Body: nopCloser{bytes.NewBufferString("body")},
} }
err := req.ParseForm() err := req.ParseForm()
...@@ -123,7 +125,7 @@ func TestPostContentTypeParsing(t *testing.T) { ...@@ -123,7 +125,7 @@ func TestPostContentTypeParsing(t *testing.T) {
func TestMultipartReader(t *testing.T) { func TestMultipartReader(t *testing.T) {
req := &Request{ req := &Request{
Method: "POST", Method: "POST",
Header: stringMap{"Content-Type": `multipart/form-data; boundary="foo123"`}, Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}},
Body: nopCloser{new(bytes.Buffer)}, Body: nopCloser{new(bytes.Buffer)},
} }
multipart, err := req.MultipartReader() multipart, err := req.MultipartReader()
...@@ -131,7 +133,7 @@ func TestMultipartReader(t *testing.T) { ...@@ -131,7 +133,7 @@ func TestMultipartReader(t *testing.T) {
t.Errorf("expected multipart; error: %v", err) t.Errorf("expected multipart; error: %v", err)
} }
req.Header = stringMap{"Content-Type": "text/plain"} req.Header = Header{"Content-Type": {"text/plain"}}
multipart, err = req.MultipartReader() multipart, err = req.MultipartReader()
if multipart != nil { if multipart != nil {
t.Errorf("unexpected multipart for text/plain") t.Errorf("unexpected multipart for text/plain")
......
...@@ -34,13 +34,13 @@ var reqWriteTests = []reqWriteTest{ ...@@ -34,13 +34,13 @@ var reqWriteTests = []reqWriteTest{
Proto: "HTTP/1.1", Proto: "HTTP/1.1",
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: map[string]string{ Header: Header{
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
"Accept-Charset": "ISO-8859-1,utf-8;q=0.7,*;q=0.7", "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"},
"Accept-Encoding": "gzip,deflate", "Accept-Encoding": {"gzip,deflate"},
"Accept-Language": "en-us,en;q=0.5", "Accept-Language": {"en-us,en;q=0.5"},
"Keep-Alive": "300", "Keep-Alive": {"300"},
"Proxy-Connection": "keep-alive", "Proxy-Connection": {"keep-alive"},
}, },
Body: nil, Body: nil,
Close: false, Close: false,
...@@ -53,10 +53,10 @@ var reqWriteTests = []reqWriteTest{ ...@@ -53,10 +53,10 @@ var reqWriteTests = []reqWriteTest{
"GET http://www.techcrunch.com/ HTTP/1.1\r\n" + "GET http://www.techcrunch.com/ HTTP/1.1\r\n" +
"Host: www.techcrunch.com\r\n" + "Host: www.techcrunch.com\r\n" +
"User-Agent: Fake\r\n" + "User-Agent: Fake\r\n" +
"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
"Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" + "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" +
"Accept-Encoding: gzip,deflate\r\n" + "Accept-Encoding: gzip,deflate\r\n" +
"Accept-Language: en-us,en;q=0.5\r\n" + "Accept-Language: en-us,en;q=0.5\r\n" +
"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
"Keep-Alive: 300\r\n" + "Keep-Alive: 300\r\n" +
"Proxy-Connection: keep-alive\r\n\r\n", "Proxy-Connection: keep-alive\r\n\r\n",
}, },
...@@ -71,7 +71,7 @@ var reqWriteTests = []reqWriteTest{ ...@@ -71,7 +71,7 @@ var reqWriteTests = []reqWriteTest{
}, },
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: map[string]string{}, Header: map[string][]string{},
Body: nopCloser{bytes.NewBufferString("abcdef")}, Body: nopCloser{bytes.NewBufferString("abcdef")},
TransferEncoding: []string{"chunked"}, TransferEncoding: []string{"chunked"},
}, },
...@@ -93,7 +93,7 @@ var reqWriteTests = []reqWriteTest{ ...@@ -93,7 +93,7 @@ var reqWriteTests = []reqWriteTest{
}, },
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: map[string]string{}, Header: map[string][]string{},
Close: true, Close: true,
Body: nopCloser{bytes.NewBufferString("abcdef")}, Body: nopCloser{bytes.NewBufferString("abcdef")},
TransferEncoding: []string{"chunked"}, TransferEncoding: []string{"chunked"},
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"io" "io"
"net/textproto"
"os" "os"
"sort" "sort"
"strconv" "strconv"
...@@ -43,7 +44,7 @@ type Response struct { ...@@ -43,7 +44,7 @@ type Response struct {
// omitted from Header. // omitted from Header.
// //
// Keys in the map are canonicalized (see CanonicalHeaderKey). // Keys in the map are canonicalized (see CanonicalHeaderKey).
Header map[string]string Header Header
// Body represents the response body. // Body represents the response body.
Body io.ReadCloser Body io.ReadCloser
...@@ -66,7 +67,7 @@ type Response struct { ...@@ -66,7 +67,7 @@ type Response struct {
// Trailer maps trailer keys to values. Like for Header, if the // Trailer maps trailer keys to values. Like for Header, if the
// response has multiple trailer lines with the same key, they will be // response has multiple trailer lines with the same key, they will be
// concatenated, delimited by commas. // concatenated, delimited by commas.
Trailer map[string]string Trailer map[string][]string
} }
// ReadResponse reads and returns an HTTP response from r. The RequestMethod // ReadResponse reads and returns an HTTP response from r. The RequestMethod
...@@ -76,13 +77,17 @@ type Response struct { ...@@ -76,13 +77,17 @@ type Response struct {
// key/value pairs included in the response trailer. // key/value pairs included in the response trailer.
func ReadResponse(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) { func ReadResponse(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) {
tp := textproto.NewReader(r)
resp = new(Response) resp = new(Response)
resp.RequestMethod = strings.ToUpper(requestMethod) resp.RequestMethod = strings.ToUpper(requestMethod)
// Parse the first line of the response. // Parse the first line of the response.
line, err := readLine(r) line, err := tp.ReadLine()
if err != nil { if err != nil {
if err == os.EOF {
err = io.ErrUnexpectedEOF
}
return nil, err return nil, err
} }
f := strings.Split(line, " ", 3) f := strings.Split(line, " ", 3)
...@@ -106,21 +111,11 @@ func ReadResponse(r *bufio.Reader, requestMethod string) (resp *Response, err os ...@@ -106,21 +111,11 @@ func ReadResponse(r *bufio.Reader, requestMethod string) (resp *Response, err os
} }
// Parse the response headers. // Parse the response headers.
nheader := 0 mimeHeader, err := tp.ReadMIMEHeader()
resp.Header = make(map[string]string)
for {
key, value, err := readKeyValue(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if key == "" { resp.Header = Header(mimeHeader)
break // end of response header
}
if nheader++; nheader >= maxHeaderLines {
return nil, ErrHeaderTooLong
}
resp.AddHeader(key, value)
}
fixPragmaCacheControl(resp.Header) fixPragmaCacheControl(resp.Header)
...@@ -136,34 +131,14 @@ func ReadResponse(r *bufio.Reader, requestMethod string) (resp *Response, err os ...@@ -136,34 +131,14 @@ func ReadResponse(r *bufio.Reader, requestMethod string) (resp *Response, err os
// Pragma: no-cache // Pragma: no-cache
// like // like
// Cache-Control: no-cache // Cache-Control: no-cache
func fixPragmaCacheControl(header map[string]string) { func fixPragmaCacheControl(header Header) {
if header["Pragma"] == "no-cache" { if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" {
if _, presentcc := header["Cache-Control"]; !presentcc { if _, presentcc := header["Cache-Control"]; !presentcc {
header["Cache-Control"] = "no-cache" header["Cache-Control"] = []string{"no-cache"}
} }
} }
} }
// AddHeader adds a value under the given key. Keys are not case sensitive.
func (r *Response) AddHeader(key, value string) {
key = CanonicalHeaderKey(key)
oldValues, oldValuesPresent := r.Header[key]
if oldValuesPresent {
r.Header[key] = oldValues + "," + value
} else {
r.Header[key] = value
}
}
// GetHeader returns the value of the response header with the given key.
// If there were multiple headers with this key, their values are concatenated,
// with a comma delimiter. If there were no response headers with the given
// key, GetHeader returns an empty string. Keys are not case sensitive.
func (r *Response) GetHeader(key string) (value string) {
return r.Header[CanonicalHeaderKey(key)]
}
// ProtoAtLeast returns whether the HTTP protocol used // ProtoAtLeast returns whether the HTTP protocol used
// in the response is at least major.minor. // in the response is at least major.minor.
func (r *Response) ProtoAtLeast(major, minor int) bool { func (r *Response) ProtoAtLeast(major, minor int) bool {
...@@ -231,21 +206,20 @@ func (resp *Response) Write(w io.Writer) os.Error { ...@@ -231,21 +206,20 @@ func (resp *Response) Write(w io.Writer) os.Error {
return nil return nil
} }
func writeSortedKeyValue(w io.Writer, kvm map[string]string, exclude map[string]bool) os.Error { func writeSortedKeyValue(w io.Writer, kvm map[string][]string, exclude map[string]bool) os.Error {
kva := make([]string, len(kvm)) keys := make([]string, 0, len(kvm))
i := 0 for k := range kvm {
for k, v := range kvm {
if !exclude[k] { if !exclude[k] {
kva[i] = fmt.Sprint(k + ": " + v + "\r\n") keys = append(keys, k)
i++
} }
} }
kva = kva[0:i] sort.SortStrings(keys)
sort.SortStrings(kva) for _, k := range keys {
for _, l := range kva { for _, v := range kvm[k] {
if _, err := io.WriteString(w, l); err != nil { if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil {
return err return err
} }
} }
}
return nil return nil
} }
...@@ -34,8 +34,8 @@ var respTests = []respTest{ ...@@ -34,8 +34,8 @@ var respTests = []respTest{
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 0, ProtoMinor: 0,
RequestMethod: "GET", RequestMethod: "GET",
Header: map[string]string{ Header: Header{
"Connection": "close", // TODO(rsc): Delete? "Connection": {"close"}, // TODO(rsc): Delete?
}, },
Close: true, Close: true,
ContentLength: -1, ContentLength: -1,
...@@ -100,9 +100,9 @@ var respTests = []respTest{ ...@@ -100,9 +100,9 @@ var respTests = []respTest{
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 0, ProtoMinor: 0,
RequestMethod: "GET", RequestMethod: "GET",
Header: map[string]string{ Header: Header{
"Connection": "close", // TODO(rsc): Delete? "Connection": {"close"}, // TODO(rsc): Delete?
"Content-Length": "10", // TODO(rsc): Delete? "Content-Length": {"10"}, // TODO(rsc): Delete?
}, },
Close: true, Close: true,
ContentLength: 10, ContentLength: 10,
...@@ -128,7 +128,7 @@ var respTests = []respTest{ ...@@ -128,7 +128,7 @@ var respTests = []respTest{
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 0, ProtoMinor: 0,
RequestMethod: "GET", RequestMethod: "GET",
Header: map[string]string{}, Header: Header{},
Close: true, Close: true,
ContentLength: -1, ContentLength: -1,
TransferEncoding: []string{"chunked"}, TransferEncoding: []string{"chunked"},
...@@ -155,7 +155,7 @@ var respTests = []respTest{ ...@@ -155,7 +155,7 @@ var respTests = []respTest{
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 0, ProtoMinor: 0,
RequestMethod: "GET", RequestMethod: "GET",
Header: map[string]string{}, Header: Header{},
Close: true, Close: true,
ContentLength: -1, // TODO(rsc): Fix? ContentLength: -1, // TODO(rsc): Fix?
TransferEncoding: []string{"chunked"}, TransferEncoding: []string{"chunked"},
...@@ -175,7 +175,7 @@ var respTests = []respTest{ ...@@ -175,7 +175,7 @@ var respTests = []respTest{
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 0, ProtoMinor: 0,
RequestMethod: "GET", RequestMethod: "GET",
Header: map[string]string{}, Header: Header{},
Close: true, Close: true,
ContentLength: -1, ContentLength: -1,
}, },
...@@ -194,7 +194,7 @@ var respTests = []respTest{ ...@@ -194,7 +194,7 @@ var respTests = []respTest{
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 0, ProtoMinor: 0,
RequestMethod: "GET", RequestMethod: "GET",
Header: map[string]string{}, Header: Header{},
Close: true, Close: true,
ContentLength: -1, ContentLength: -1,
}, },
......
...@@ -22,7 +22,7 @@ var respWriteTests = []respWriteTest{ ...@@ -22,7 +22,7 @@ var respWriteTests = []respWriteTest{
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 0, ProtoMinor: 0,
RequestMethod: "GET", RequestMethod: "GET",
Header: map[string]string{}, Header: map[string][]string{},
Body: nopCloser{bytes.NewBufferString("abcdef")}, Body: nopCloser{bytes.NewBufferString("abcdef")},
ContentLength: 6, ContentLength: 6,
}, },
...@@ -38,7 +38,7 @@ var respWriteTests = []respWriteTest{ ...@@ -38,7 +38,7 @@ var respWriteTests = []respWriteTest{
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 0, ProtoMinor: 0,
RequestMethod: "GET", RequestMethod: "GET",
Header: map[string]string{}, Header: map[string][]string{},
Body: nopCloser{bytes.NewBufferString("abcdef")}, Body: nopCloser{bytes.NewBufferString("abcdef")},
ContentLength: -1, ContentLength: -1,
}, },
...@@ -53,7 +53,7 @@ var respWriteTests = []respWriteTest{ ...@@ -53,7 +53,7 @@ var respWriteTests = []respWriteTest{
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
RequestMethod: "GET", RequestMethod: "GET",
Header: map[string]string{}, Header: map[string][]string{},
Body: nopCloser{bytes.NewBufferString("abcdef")}, Body: nopCloser{bytes.NewBufferString("abcdef")},
ContentLength: 6, ContentLength: 6,
TransferEncoding: []string{"chunked"}, TransferEncoding: []string{"chunked"},
......
...@@ -197,7 +197,7 @@ func TestHostHandlers(t *testing.T) { ...@@ -197,7 +197,7 @@ func TestHostHandlers(t *testing.T) {
t.Errorf("reading response: %v", err) t.Errorf("reading response: %v", err)
continue continue
} }
s := r.Header["Result"] s := r.Header.Get("Result")
if s != vt.expected { if s != vt.expected {
t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected) t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
} }
......
...@@ -21,7 +21,7 @@ type transferWriter struct { ...@@ -21,7 +21,7 @@ type transferWriter struct {
ContentLength int64 ContentLength int64
Close bool Close bool
TransferEncoding []string TransferEncoding []string
Trailer map[string]string Trailer Header
} }
func newTransferWriter(r interface{}) (t *transferWriter, err os.Error) { func newTransferWriter(r interface{}) (t *transferWriter, err os.Error) {
...@@ -159,7 +159,7 @@ func (t *transferWriter) WriteBody(w io.Writer) (err os.Error) { ...@@ -159,7 +159,7 @@ func (t *transferWriter) WriteBody(w io.Writer) (err os.Error) {
type transferReader struct { type transferReader struct {
// Input // Input
Header map[string]string Header Header
StatusCode int StatusCode int
RequestMethod string RequestMethod string
ProtoMajor int ProtoMajor int
...@@ -169,7 +169,7 @@ type transferReader struct { ...@@ -169,7 +169,7 @@ type transferReader struct {
ContentLength int64 ContentLength int64
TransferEncoding []string TransferEncoding []string
Close bool Close bool
Trailer map[string]string Trailer Header
} }
// bodyAllowedForStatus returns whether a given response status code // bodyAllowedForStatus returns whether a given response status code
...@@ -289,14 +289,14 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) { ...@@ -289,14 +289,14 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err os.Error) {
func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
// Sanitize transfer encoding // Sanitize transfer encoding
func fixTransferEncoding(header map[string]string) ([]string, os.Error) { func fixTransferEncoding(header Header) ([]string, os.Error) {
raw, present := header["Transfer-Encoding"] raw, present := header["Transfer-Encoding"]
if !present { if !present {
return nil, nil return nil, nil
} }
header["Transfer-Encoding"] = "", false header["Transfer-Encoding"] = nil, false
encodings := strings.Split(raw, ",", -1) encodings := strings.Split(raw[0], ",", -1)
te := make([]string, 0, len(encodings)) te := make([]string, 0, len(encodings))
// TODO: Even though we only support "identity" and "chunked" // TODO: Even though we only support "identity" and "chunked"
// encodings, the loop below is designed with foresight. One // encodings, the loop below is designed with foresight. One
...@@ -321,7 +321,7 @@ func fixTransferEncoding(header map[string]string) ([]string, os.Error) { ...@@ -321,7 +321,7 @@ func fixTransferEncoding(header map[string]string) ([]string, os.Error) {
// Chunked encoding trumps Content-Length. See RFC 2616 // Chunked encoding trumps Content-Length. See RFC 2616
// Section 4.4. Currently len(te) > 0 implies chunked // Section 4.4. Currently len(te) > 0 implies chunked
// encoding. // encoding.
header["Content-Length"] = "", false header["Content-Length"] = nil, false
return te, nil return te, nil
} }
...@@ -331,7 +331,7 @@ func fixTransferEncoding(header map[string]string) ([]string, os.Error) { ...@@ -331,7 +331,7 @@ func fixTransferEncoding(header map[string]string) ([]string, os.Error) {
// Determine the expected body length, using RFC 2616 Section 4.4. This // Determine the expected body length, using RFC 2616 Section 4.4. This
// function is not a method, because ultimately it should be shared by // function is not a method, because ultimately it should be shared by
// ReadResponse and ReadRequest. // ReadResponse and ReadRequest.
func fixLength(status int, requestMethod string, header map[string]string, te []string) (int64, os.Error) { func fixLength(status int, requestMethod string, header Header, te []string) (int64, os.Error) {
// Logic based on response type or status // Logic based on response type or status
if noBodyExpected(requestMethod) { if noBodyExpected(requestMethod) {
...@@ -351,8 +351,7 @@ func fixLength(status int, requestMethod string, header map[string]string, te [] ...@@ -351,8 +351,7 @@ func fixLength(status int, requestMethod string, header map[string]string, te []
} }
// Logic based on Content-Length // Logic based on Content-Length
if cl, present := header["Content-Length"]; present { cl := strings.TrimSpace(header.Get("Content-Length"))
cl = strings.TrimSpace(cl)
if cl != "" { if cl != "" {
n, err := strconv.Atoi64(cl) n, err := strconv.Atoi64(cl)
if err != nil || n < 0 { if err != nil || n < 0 {
...@@ -360,14 +359,13 @@ func fixLength(status int, requestMethod string, header map[string]string, te [] ...@@ -360,14 +359,13 @@ func fixLength(status int, requestMethod string, header map[string]string, te []
} }
return n, nil return n, nil
} else { } else {
header["Content-Length"] = "", false header.Del("Content-Length")
}
} }
// Logic based on media type. The purpose of the following code is just // Logic based on media type. The purpose of the following code is just
// to detect whether the unsupported "multipart/byteranges" is being // to detect whether the unsupported "multipart/byteranges" is being
// used. A proper Content-Type parser is needed in the future. // used. A proper Content-Type parser is needed in the future.
if strings.Contains(strings.ToLower(header["Content-Type"]), "multipart/byteranges") { if strings.Contains(strings.ToLower(header.Get("Content-Type")), "multipart/byteranges") {
return -1, ErrNotSupported return -1, ErrNotSupported
} }
...@@ -378,24 +376,19 @@ func fixLength(status int, requestMethod string, header map[string]string, te [] ...@@ -378,24 +376,19 @@ func fixLength(status int, requestMethod string, header map[string]string, te []
// Determine whether to hang up after sending a request and body, or // Determine whether to hang up after sending a request and body, or
// receiving a response and body // receiving a response and body
// 'header' is the request headers // 'header' is the request headers
func shouldClose(major, minor int, header map[string]string) bool { func shouldClose(major, minor int, header Header) bool {
if major < 1 { if major < 1 {
return true return true
} else if major == 1 && minor == 0 { } else if major == 1 && minor == 0 {
v, present := header["Connection"] if !strings.Contains(strings.ToLower(header.Get("Connection")), "keep-alive") {
if !present {
return true
}
v = strings.ToLower(v)
if !strings.Contains(v, "keep-alive") {
return true return true
} }
return false return false
} else if v, present := header["Connection"]; present { } else {
// TODO: Should split on commas, toss surrounding white space, // TODO: Should split on commas, toss surrounding white space,
// and check each field. // and check each field.
if v == "close" { if strings.ToLower(header.Get("Connection")) == "close" {
header["Connection"] = "", false header.Del("Connection")
return true return true
} }
} }
...@@ -403,14 +396,14 @@ func shouldClose(major, minor int, header map[string]string) bool { ...@@ -403,14 +396,14 @@ func shouldClose(major, minor int, header map[string]string) bool {
} }
// Parse the trailer header // Parse the trailer header
func fixTrailer(header map[string]string, te []string) (map[string]string, os.Error) { func fixTrailer(header Header, te []string) (Header, os.Error) {
raw, present := header["Trailer"] raw := header.Get("Trailer")
if !present { if raw == "" {
return nil, nil return nil, nil
} }
header["Trailer"] = "", false header.Del("Trailer")
trailer := make(map[string]string) trailer := make(Header)
keys := strings.Split(raw, ",", -1) keys := strings.Split(raw, ",", -1)
for _, key := range keys { for _, key := range keys {
key = CanonicalHeaderKey(strings.TrimSpace(key)) key = CanonicalHeaderKey(strings.TrimSpace(key))
...@@ -418,7 +411,7 @@ func fixTrailer(header map[string]string, te []string) (map[string]string, os.Er ...@@ -418,7 +411,7 @@ func fixTrailer(header map[string]string, te []string) (map[string]string, os.Er
case "Transfer-Encoding", "Trailer", "Content-Length": case "Transfer-Encoding", "Trailer", "Content-Length":
return nil, &badStringError{"bad trailer key", key} return nil, &badStringError{"bad trailer key", key}
} }
trailer[key] = "" trailer.Del(key)
} }
if len(trailer) == 0 { if len(trailer) == 0 {
return nil, nil return nil, nil
......
...@@ -6,6 +6,7 @@ include ../../../Make.inc ...@@ -6,6 +6,7 @@ include ../../../Make.inc
TARG=net/textproto TARG=net/textproto
GOFILES=\ GOFILES=\
header.go\
pipeline.go\ pipeline.go\
reader.go\ reader.go\
textproto.go\ textproto.go\
......
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package textproto
// A MIMEHeader represents a MIME-style header mapping
// keys to sets of values.
type MIMEHeader map[string][]string
// Add adds the key, value pair to the header.
// It appends to any existing values associated with key.
func (h MIMEHeader) Add(key, value string) {
key = CanonicalMIMEHeaderKey(key)
h[key] = append(h[key], value)
}
// Set sets the header entries associated with key to
// the single element value. It replaces any existing
// values associated with key.
func (h MIMEHeader) Set(key, value string) {
h[CanonicalMIMEHeaderKey(key)] = []string{value}
}
// Get gets the first value associated with the given key.
// If there are no values associated with the key, Get returns "".
// Get is a convenience method. For more complex queries,
// access the map directly.
func (h MIMEHeader) Get(key string) string {
if h == nil {
return ""
}
v := h[CanonicalMIMEHeaderKey(key)]
if len(v) == 0 {
return ""
}
return v[0]
}
// Del deletes the values associated with key.
func (h MIMEHeader) Del(key string) {
h[CanonicalMIMEHeaderKey(key)] = nil, false
}
...@@ -402,7 +402,7 @@ func (r *Reader) ReadDotLines() ([]string, os.Error) { ...@@ -402,7 +402,7 @@ func (r *Reader) ReadDotLines() ([]string, os.Error) {
// ReadMIMEHeader reads a MIME-style header from r. // ReadMIMEHeader reads a MIME-style header from r.
// The header is a sequence of possibly continued Key: Value lines // The header is a sequence of possibly continued Key: Value lines
// ending in a blank line. // ending in a blank line.
// The returned map m maps CanonicalHeaderKey(key) to a // The returned map m maps CanonicalMIMEHeaderKey(key) to a
// sequence of values in the same order encountered in the input. // sequence of values in the same order encountered in the input.
// //
// For example, consider this input: // For example, consider this input:
...@@ -415,12 +415,12 @@ func (r *Reader) ReadDotLines() ([]string, os.Error) { ...@@ -415,12 +415,12 @@ func (r *Reader) ReadDotLines() ([]string, os.Error) {
// Given that input, ReadMIMEHeader returns the map: // Given that input, ReadMIMEHeader returns the map:
// //
// map[string][]string{ // map[string][]string{
// "My-Key": []string{"Value 1", "Value 2"}, // "My-Key": {"Value 1", "Value 2"},
// "Long-Key": []string{"Even Longer Value"}, // "Long-Key": {"Even Longer Value"},
// } // }
// //
func (r *Reader) ReadMIMEHeader() (map[string][]string, os.Error) { func (r *Reader) ReadMIMEHeader() (MIMEHeader, os.Error) {
m := make(map[string][]string) m := make(MIMEHeader)
for { for {
kv, err := r.ReadContinuedLineBytes() kv, err := r.ReadContinuedLineBytes()
if len(kv) == 0 { if len(kv) == 0 {
...@@ -432,7 +432,7 @@ func (r *Reader) ReadMIMEHeader() (map[string][]string, os.Error) { ...@@ -432,7 +432,7 @@ func (r *Reader) ReadMIMEHeader() (map[string][]string, os.Error) {
if i < 0 || bytes.IndexByte(kv[0:i], ' ') >= 0 { if i < 0 || bytes.IndexByte(kv[0:i], ' ') >= 0 {
return m, ProtocolError("malformed MIME header line: " + string(kv)) return m, ProtocolError("malformed MIME header line: " + string(kv))
} }
key := CanonicalHeaderKey(string(kv[0:i])) key := CanonicalMIMEHeaderKey(string(kv[0:i]))
// Skip initial spaces in value. // Skip initial spaces in value.
i++ // skip colon i++ // skip colon
...@@ -452,12 +452,12 @@ func (r *Reader) ReadMIMEHeader() (map[string][]string, os.Error) { ...@@ -452,12 +452,12 @@ func (r *Reader) ReadMIMEHeader() (map[string][]string, os.Error) {
panic("unreachable") panic("unreachable")
} }
// CanonicalHeaderKey returns the canonical format of the // CanonicalMIMEHeaderKey returns the canonical format of the
// MIME header key s. The canonicalization converts the first // MIME header key s. The canonicalization converts the first
// letter and any letter following a hyphen to upper case; // letter and any letter following a hyphen to upper case;
// the rest are converted to lowercase. For example, the // the rest are converted to lowercase. For example, the
// canonical key for "accept-encoding" is "Accept-Encoding". // canonical key for "accept-encoding" is "Accept-Encoding".
func CanonicalHeaderKey(s string) string { func CanonicalMIMEHeaderKey(s string) string {
// Quick check for canonical encoding. // Quick check for canonical encoding.
needUpper := true needUpper := true
for i := 0; i < len(s); i++ { for i := 0; i < len(s); i++ {
......
...@@ -26,10 +26,10 @@ var canonicalHeaderKeyTests = []canonicalHeaderKeyTest{ ...@@ -26,10 +26,10 @@ var canonicalHeaderKeyTests = []canonicalHeaderKeyTest{
{"USER-AGENT", "User-Agent"}, {"USER-AGENT", "User-Agent"},
} }
func TestCanonicalHeaderKey(t *testing.T) { func TestCanonicalMIMEHeaderKey(t *testing.T) {
for _, tt := range canonicalHeaderKeyTests { for _, tt := range canonicalHeaderKeyTests {
if s := CanonicalHeaderKey(tt.in); s != tt.out { if s := CanonicalMIMEHeaderKey(tt.in); s != tt.out {
t.Errorf("CanonicalHeaderKey(%q) = %q, want %q", tt.in, s, tt.out) t.Errorf("CanonicalMIMEHeaderKey(%q) = %q, want %q", tt.in, s, tt.out)
} }
} }
} }
...@@ -130,7 +130,7 @@ func TestReadDotBytes(t *testing.T) { ...@@ -130,7 +130,7 @@ func TestReadDotBytes(t *testing.T) {
func TestReadMIMEHeader(t *testing.T) { func TestReadMIMEHeader(t *testing.T) {
r := reader("my-key: Value 1 \r\nLong-key: Even \n Longer Value\r\nmy-Key: Value 2\r\n\n") r := reader("my-key: Value 1 \r\nLong-key: Even \n Longer Value\r\nmy-Key: Value 2\r\n\n")
m, err := r.ReadMIMEHeader() m, err := r.ReadMIMEHeader()
want := map[string][]string{ want := MIMEHeader{
"My-Key": {"Value 1", "Value 2"}, "My-Key": {"Value 1", "Value 2"},
"Long-Key": {"Even Longer Value"}, "Long-Key": {"Even Longer Value"},
} }
......
...@@ -245,20 +245,20 @@ func handshake(resourceName, host, origin, location, protocol string, br *bufio. ...@@ -245,20 +245,20 @@ func handshake(resourceName, host, origin, location, protocol string, br *bufio.
} }
// Step 41. check websocket headers. // Step 41. check websocket headers.
if resp.Header["Upgrade"] != "WebSocket" || if resp.Header.Get("Upgrade") != "WebSocket" ||
strings.ToLower(resp.Header["Connection"]) != "upgrade" { strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
return ErrBadUpgrade return ErrBadUpgrade
} }
if resp.Header["Sec-Websocket-Origin"] != origin { if resp.Header.Get("Sec-Websocket-Origin") != origin {
return ErrBadWebSocketOrigin return ErrBadWebSocketOrigin
} }
if resp.Header["Sec-Websocket-Location"] != location { if resp.Header.Get("Sec-Websocket-Location") != location {
return ErrBadWebSocketLocation return ErrBadWebSocketLocation
} }
if protocol != "" && resp.Header["Sec-Websocket-Protocol"] != protocol { if protocol != "" && resp.Header.Get("Sec-Websocket-Protocol") != protocol {
return ErrBadWebSocketProtocol return ErrBadWebSocketProtocol
} }
...@@ -304,17 +304,17 @@ func draft75handshake(resourceName, host, origin, location, protocol string, br ...@@ -304,17 +304,17 @@ func draft75handshake(resourceName, host, origin, location, protocol string, br
if resp.Status != "101 Web Socket Protocol Handshake" { if resp.Status != "101 Web Socket Protocol Handshake" {
return ErrBadStatus return ErrBadStatus
} }
if resp.Header["Upgrade"] != "WebSocket" || if resp.Header.Get("Upgrade") != "WebSocket" ||
resp.Header["Connection"] != "Upgrade" { resp.Header.Get("Connection") != "Upgrade" {
return ErrBadUpgrade return ErrBadUpgrade
} }
if resp.Header["Websocket-Origin"] != origin { if resp.Header.Get("Websocket-Origin") != origin {
return ErrBadWebSocketOrigin return ErrBadWebSocketOrigin
} }
if resp.Header["Websocket-Location"] != location { if resp.Header.Get("Websocket-Location") != location {
return ErrBadWebSocketLocation return ErrBadWebSocketLocation
} }
if protocol != "" && resp.Header["Websocket-Protocol"] != protocol { if protocol != "" && resp.Header.Get("Websocket-Protocol") != protocol {
return ErrBadWebSocketProtocol return ErrBadWebSocketProtocol
} }
return return
......
...@@ -73,23 +73,23 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { ...@@ -73,23 +73,23 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
} }
// HTTP version can be safely ignored. // HTTP version can be safely ignored.
if strings.ToLower(req.Header["Upgrade"]) != "websocket" || if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
strings.ToLower(req.Header["Connection"]) != "upgrade" { strings.ToLower(req.Header.Get("Connection")) != "upgrade" {
return return
} }
// TODO(ukai): check Host // TODO(ukai): check Host
origin, found := req.Header["Origin"] origin := req.Header.Get("Origin")
if !found { if origin == "" {
return return
} }
key1, found := req.Header["Sec-Websocket-Key1"] key1 := req.Header.Get("Sec-Websocket-Key1")
if !found { if key1 == "" {
return return
} }
key2, found := req.Header["Sec-Websocket-Key2"] key2 := req.Header.Get("Sec-Websocket-Key2")
if !found { if key2 == "" {
return return
} }
key3 := make([]byte, 8) key3 := make([]byte, 8)
...@@ -138,8 +138,8 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { ...@@ -138,8 +138,8 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
buf.WriteString("Connection: Upgrade\r\n") buf.WriteString("Connection: Upgrade\r\n")
buf.WriteString("Sec-WebSocket-Location: " + location + "\r\n") buf.WriteString("Sec-WebSocket-Location: " + location + "\r\n")
buf.WriteString("Sec-WebSocket-Origin: " + origin + "\r\n") buf.WriteString("Sec-WebSocket-Origin: " + origin + "\r\n")
protocol, found := req.Header["Sec-Websocket-Protocol"] protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
if found { if protocol != "" {
buf.WriteString("Sec-WebSocket-Protocol: " + protocol + "\r\n") buf.WriteString("Sec-WebSocket-Protocol: " + protocol + "\r\n")
} }
// Step 12. send CRLF. // Step 12. send CRLF.
...@@ -167,18 +167,18 @@ func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { ...@@ -167,18 +167,18 @@ func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
io.WriteString(w, "Unexpected request") io.WriteString(w, "Unexpected request")
return return
} }
if req.Header["Upgrade"] != "WebSocket" { if req.Header.Get("Upgrade") != "WebSocket" {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, "missing Upgrade: WebSocket header") io.WriteString(w, "missing Upgrade: WebSocket header")
return return
} }
if req.Header["Connection"] != "Upgrade" { if req.Header.Get("Connection") != "Upgrade" {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, "missing Connection: Upgrade header") io.WriteString(w, "missing Connection: Upgrade header")
return return
} }
origin, found := req.Header["Origin"] origin := strings.TrimSpace(req.Header.Get("Origin"))
if !found { if origin == "" {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, "missing Origin header") io.WriteString(w, "missing Origin header")
return return
...@@ -205,9 +205,9 @@ func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { ...@@ -205,9 +205,9 @@ func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
buf.WriteString("Connection: Upgrade\r\n") buf.WriteString("Connection: Upgrade\r\n")
buf.WriteString("WebSocket-Origin: " + origin + "\r\n") buf.WriteString("WebSocket-Origin: " + origin + "\r\n")
buf.WriteString("WebSocket-Location: " + location + "\r\n") buf.WriteString("WebSocket-Location: " + location + "\r\n")
protocol, found := req.Header["Websocket-Protocol"] protocol := strings.TrimSpace(req.Header.Get("Websocket-Protocol"))
// canonical header key of WebSocket-Protocol. // canonical header key of WebSocket-Protocol.
if found { if protocol != "" {
buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n") buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n")
} }
buf.WriteString("\r\n") buf.WriteString("\r\n")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment