Commit c4ba3f34 authored by Kirill Smelkov's avatar Kirill Smelkov

.

parent 85666824
This diff is collapsed.
...@@ -91,11 +91,13 @@ var ErrDecodeOverflow = errors.New("decode: bufer overflow") ...@@ -91,11 +91,13 @@ var ErrDecodeOverflow = errors.New("decode: bufer overflow")
// NEOEncoder is interface for marshaling objects to wire format // NEOEncoder is interface for marshaling objects to wire format
type NEOEncoder interface { type NEOEncoder interface {
// compute how much space is needed to encode // NEOEncodedInfo returns message code needed to be used for the packet
NEOEncodedLen() int // on the wire and how much space is needed to encode payload
// XXX naming?
NEOEncodedInfo() (msgCode uint16, payloadLen int)
// perform the encoding. // perform the encoding.
// len(buf) must be >= NEOEncodedLen() // len(buf) must be >= payloadLen returned by NEOEncodedInfo
NEOEncode(buf []byte) NEOEncode(buf []byte)
} }
......
...@@ -82,7 +82,11 @@ func testPktMarshal(t *testing.T, pkt NEOCodec, encoded string) { ...@@ -82,7 +82,11 @@ func testPktMarshal(t *testing.T, pkt NEOCodec, encoded string) {
}() }()
// pkt.encode() == expected // pkt.encode() == expected
n := pkt.NEOEncodedLen() msgCode, n := pkt.NEOEncodedInfo()
msgType := pktTypeRegistry[msgCode]
if msgType != typ {
t.Errorf("%v: msgCode = %v which corresponds to %v", typ, msgCode, msgType)
}
if n != len(encoded) { if n != len(encoded) {
t.Errorf("%v: encodedLen = %v ; want %v", typ, n, len(encoded)) t.Errorf("%v: encodedLen = %v ; want %v", typ, n, len(encoded))
} }
...@@ -266,7 +270,7 @@ func TestPktMarshalAllOverflowLightly(t *testing.T) { ...@@ -266,7 +270,7 @@ func TestPktMarshalAllOverflowLightly(t *testing.T) {
for _, typ := range pktTypeRegistry { for _, typ := range pktTypeRegistry {
// zero-value for a type // zero-value for a type
pkt := reflect.New(typ).Interface().(NEOCodec) pkt := reflect.New(typ).Interface().(NEOCodec)
l := pkt.NEOEncodedLen() _, l := pkt.NEOEncodedInfo()
zerol := make([]byte, l) zerol := make([]byte, l)
// decoding will turn nil slice & map into empty allocated ones. // decoding will turn nil slice & map into empty allocated ones.
// we need it so that reflect.DeepEqual works for pkt encode/decode comparison // we need it so that reflect.DeepEqual works for pkt encode/decode comparison
......
...@@ -19,7 +19,7 @@ This program generates marshalling code for packet types defined in proto.go . ...@@ -19,7 +19,7 @@ This program generates marshalling code for packet types defined in proto.go .
For every type 3 methods are generated in accordance with NEOEncoder and For every type 3 methods are generated in accordance with NEOEncoder and
NEODecoder interfaces: NEODecoder interfaces:
NEOEncodedLen() int NEOEncodedInfo() (msgCode uint16, payloadLen int)
NEOEncode(buf []byte) NEOEncode(buf []byte)
NEODecode(data []byte) (nread int, err error) NEODecode(data []byte) (nread int, err error)
...@@ -192,7 +192,7 @@ import ( ...@@ -192,7 +192,7 @@ import (
case *ast.StructType: case *ast.StructType:
fmt.Fprintf(&buf, "// %d. %s\n\n", pktCode, typename) fmt.Fprintf(&buf, "// %d. %s\n\n", pktCode, typename)
buf.WriteString(generateCodecCode(typespec, &sizer{})) buf.WriteString(generateCodecCode(typespec, &sizer{msgCode: pktCode}))
buf.WriteString(generateCodecCode(typespec, &encoder{})) buf.WriteString(generateCodecCode(typespec, &encoder{}))
buf.WriteString(generateCodecCode(typespec, &decoder{})) buf.WriteString(generateCodecCode(typespec, &decoder{}))
...@@ -455,13 +455,18 @@ func (o *OverflowCheck) AddExpr(format string, a ...interface{}) { ...@@ -455,13 +455,18 @@ func (o *OverflowCheck) AddExpr(format string, a ...interface{}) {
type sizer struct { type sizer struct {
commonCodeGen commonCodeGen
size SymSize // currently accumulated packet size size SymSize // currently accumulated packet size
// which code to also return as packet msgCode
// (sizer does not compute this - it is emitted as-is given by caller)
msgCode int
} }
// encoder generates code to encode a packet // encoder generates code to encode a packet
// //
// when type is recursively walked, for every case code to update `data[n:]` is generated. // when type is recursively walked, for every case code to update `data[n:]` is generated.
// no overflow checks are generated as by NEOEncoder interface provided data // no overflow checks are generated as by NEOEncoder interface provided data
// buffer should have at least NEOEncodedLen() length (the size computed by sizer). // buffer should have at least payloadLen length returned by NEOEncodedInfo()
// (the size computed by sizer).
// //
// the code emitted looks like: // the code emitted looks like:
// //
...@@ -517,7 +522,7 @@ var _ CodeGenerator = (*decoder)(nil) ...@@ -517,7 +522,7 @@ var _ CodeGenerator = (*decoder)(nil)
func (s *sizer) generatedCode() string { func (s *sizer) generatedCode() string {
code := Buffer{} code := Buffer{}
// prologue // prologue
code.emit("func (%s *%s) NEOEncodedLen() int {", s.recvName, s.typeName) code.emit("func (%s *%s) NEOEncodedInfo() (uint16, int) {", s.recvName, s.typeName)
if s.varUsed["size"] { if s.varUsed["size"] {
code.emit("var %s int", s.var_("size")) code.emit("var %s int", s.var_("size"))
} }
...@@ -529,7 +534,7 @@ func (s *sizer) generatedCode() string { ...@@ -529,7 +534,7 @@ func (s *sizer) generatedCode() string {
if s.varUsed["size"] { if s.varUsed["size"] {
size += " + " + s.var_("size") size += " + " + s.var_("size")
} }
code.emit("return %v", size) code.emit("return %v, %v", s.msgCode, size)
code.emit("}\n") code.emit("}\n")
return code.String() return code.String()
......
...@@ -126,8 +126,13 @@ func RecvAndDecode(conn *Conn) (interface{}, error) { // XXX interface{} -> NEOD ...@@ -126,8 +126,13 @@ func RecvAndDecode(conn *Conn) (interface{}, error) { // XXX interface{} -> NEOD
func EncodeAndSend(conn *Conn, pkt NEOEncoder) error { func EncodeAndSend(conn *Conn, pkt NEOEncoder) error {
// TODO encode pkt // TODO encode pkt
l := pkt.NEOEncodedLen() msgCode, l := pkt.NEOEncodedInfo()
buf := PktBuf{make([]byte, PktHeadLen + l)} // XXX -> freelist l += PktHeadLen
buf := PktBuf{make([]byte, l)} // XXX -> freelist
h := buf.Header()
h.MsgCode = hton16(msgCode)
h.Len = hton32(uint32(l)) // XXX casting: think again
pkt.NEOEncode(buf.Payload()) pkt.NEOEncode(buf.Payload())
return conn.Send(&buf) // XXX why pointer? return conn.Send(&buf) // XXX why pointer?
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment