Fixed tests

This commit is contained in:
Jae Kwon
2014-12-17 01:37:13 -08:00
parent 6cacf6f09b
commit 61d1635085
40 changed files with 781 additions and 592 deletions

View File

@ -1,6 +1,7 @@
package binary
import (
"errors"
"fmt"
"io"
"reflect"
@ -10,6 +11,9 @@ type TypeInfo struct {
Type reflect.Type // The type
Encoder Encoder // Optional custom encoder function
Decoder Decoder // Optional custom decoder function
HasTypeByte bool
TypeByte byte
}
// If a type implements TypeByte, the byte is included
@ -25,18 +29,33 @@ type HasTypeByte interface {
var typeInfos = map[reflect.Type]*TypeInfo{}
func RegisterType(info *TypeInfo) bool {
func RegisterType(info *TypeInfo) *TypeInfo {
// Register the type info
typeInfos[info.Type] = info
// Also register the underlying struct's info, if info.Type is a pointer.
// Or, if info.Type is not a pointer, register the pointer.
if info.Type.Kind() == reflect.Ptr {
rt := info.Type.Elem()
typeInfos[rt] = info
} else {
ptrRt := reflect.PtrTo(info.Type)
typeInfos[ptrRt] = info
}
return true
// See if the type implements HasTypeByte
if info.Type.Implements(reflect.TypeOf((*HasTypeByte)(nil)).Elem()) {
zero := reflect.Zero(info.Type)
typeByte := zero.Interface().(HasTypeByte).TypeByte()
if info.HasTypeByte && info.TypeByte != typeByte {
panic(fmt.Sprintf("Type %v expected TypeByte of %X", info.Type, typeByte))
}
info.HasTypeByte = true
info.TypeByte = typeByte
}
return info
}
func readReflect(rv reflect.Value, rt reflect.Type, r io.Reader, n *int64, err *error) {
@ -54,15 +73,29 @@ func readReflect(rv reflect.Value, rt reflect.Type, r io.Reader, n *int64, err *
rv, rt = rv.Elem(), rt.Elem()
}
// Custom decoder
// Get typeInfo
typeInfo := typeInfos[rt]
if typeInfo != nil && typeInfo.Decoder != nil {
if typeInfo == nil {
typeInfo = RegisterType(&TypeInfo{Type: rt})
}
// Custom decoder
if typeInfo.Decoder != nil {
decoded := typeInfo.Decoder(r, n, err)
decodedRv := reflect.Indirect(reflect.ValueOf(decoded))
rv.Set(decodedRv)
return
}
// Read TypeByte prefix
if typeInfo.HasTypeByte {
typeByte := ReadByte(r, n, err)
if typeByte != typeInfo.TypeByte {
*err = errors.New(fmt.Sprintf("Expected TypeByte of %X but got %X", typeInfo.TypeByte, typeByte))
return
}
}
switch rt.Kind() {
case reflect.Slice:
elemRt := rt.Elem()
@ -86,7 +119,7 @@ func readReflect(rv reflect.Value, rt reflect.Type, r io.Reader, n *int64, err *
numFields := rt.NumField()
for i := 0; i < numFields; i++ {
field := rt.Field(i)
if field.Anonymous {
if field.PkgPath != "" {
continue
}
fieldRv := rv.Field(i)
@ -97,6 +130,26 @@ func readReflect(rv reflect.Value, rt reflect.Type, r io.Reader, n *int64, err *
str := ReadString(r, n, err)
rv.SetString(str)
case reflect.Int64:
num := ReadUInt64(r, n, err)
rv.SetInt(int64(num))
case reflect.Int32:
num := ReadUInt32(r, n, err)
rv.SetInt(int64(num))
case reflect.Int16:
num := ReadUInt16(r, n, err)
rv.SetInt(int64(num))
case reflect.Int8:
num := ReadUInt8(r, n, err)
rv.SetInt(int64(num))
case reflect.Int:
num := ReadUVarInt(r, n, err)
rv.SetInt(int64(num))
case reflect.Uint64:
num := ReadUInt64(r, n, err)
rv.SetUint(uint64(num))
@ -124,9 +177,14 @@ func readReflect(rv reflect.Value, rt reflect.Type, r io.Reader, n *int64, err *
func writeReflect(rv reflect.Value, rt reflect.Type, w io.Writer, n *int64, err *error) {
// Custom encoder
// Get typeInfo
typeInfo := typeInfos[rt]
if typeInfo != nil && typeInfo.Encoder != nil {
if typeInfo == nil {
typeInfo = RegisterType(&TypeInfo{Type: rt})
}
// Custom encoder, say for an interface type rt.
if typeInfo.Encoder != nil {
typeInfo.Encoder(rv.Interface(), w, n, err)
return
}
@ -135,14 +193,21 @@ func writeReflect(rv reflect.Value, rt reflect.Type, w io.Writer, n *int64, err
if rt.Kind() == reflect.Ptr {
rt = rt.Elem()
rv = rv.Elem()
// RegisterType registers the ptr type,
// so typeInfo is already for the ptr.
} else if rt.Kind() == reflect.Interface {
rv = rv.Elem()
rt = rv.Type()
typeInfo = typeInfos[rt]
// If interface type, get typeInfo of underlying type.
if typeInfo == nil {
typeInfo = RegisterType(&TypeInfo{Type: rt})
}
}
// Write TypeByte prefix
if rt.Implements(reflect.TypeOf((*HasTypeByte)(nil)).Elem()) {
WriteByte(rv.Interface().(HasTypeByte).TypeByte(), w, n, err)
if typeInfo.HasTypeByte {
WriteByte(typeInfo.TypeByte, w, n, err)
}
switch rt.Kind() {
@ -167,7 +232,7 @@ func writeReflect(rv reflect.Value, rt reflect.Type, w io.Writer, n *int64, err
numFields := rt.NumField()
for i := 0; i < numFields; i++ {
field := rt.Field(i)
if field.Anonymous {
if field.PkgPath != "" {
continue
}
fieldRv := rv.Field(i)
@ -177,6 +242,21 @@ func writeReflect(rv reflect.Value, rt reflect.Type, w io.Writer, n *int64, err
case reflect.String:
WriteString(rv.String(), w, n, err)
case reflect.Int64:
WriteInt64(rv.Int(), w, n, err)
case reflect.Int32:
WriteInt32(int32(rv.Int()), w, n, err)
case reflect.Int16:
WriteInt16(int16(rv.Int()), w, n, err)
case reflect.Int8:
WriteInt8(int8(rv.Int()), w, n, err)
case reflect.Int:
WriteVarInt(int(rv.Int()), w, n, err)
case reflect.Uint64:
WriteUInt64(rv.Uint(), w, n, err)