Commit 4178bf4a authored by Kirill Smelkov

Add custom Dict that mirrors Python dict behaviour

Ogórek currently represents unpickled dict via map[any]any, which is
logical, but also exhibits issues because builtin Go map behaviour is
different from Python's dict behaviour. For example:

- Python's dict allows tuples to be used in keys, while Go map
  does not (https://github.com/kisielk/og-rek/issues/50),

- Python's dict allows both long and int to be used interchangeably as
  keys, while Go map does not handle *big.Int as key with the same
  semantic (https://github.com/kisielk/og-rek/issues/55)

- Python's dict allows numbers - all int and float - to be used
  interchangeably in keys, but on Go side int(1) and float64(1.0) are
  considered by builtin map as different keys.

- In Python world bytestring (str from py2) is considered to be related
  to both unicode (str on py3) and bytes, but builtin map considers all
  string, Bytes and ByteString as different keys.

- etc...

All in all there are many differences in behaviour in builtin Python
dict and Go map that result in generally different semantics when
decoding pickled data. Those differences can be fixed only if we add
custom dict implementation that mirrors what Python does.

-> Do that: add custom Dict that implements key -> value mapping with
   mirroring Python behaviour.

For now we are only adding the Dict class itself and its tests.
Later we will use this new Dict to handle decoding dictionaries from the pickles.

For the implementation we use github.com/aristanetworks/gomap which
provides extraction of builtin go map code wrapped into generic type
Map[Key,Value] that accepts custom equal and hash functions. And it is
those equal and hash functions via which we define equality behaviour to
be the same as on Python side.

Please see documentation in the added code and tests for details.
parent 64d5926e
package ogórek
// Python-like Dict that handles keys by Python-like equality on access.
//
// For example Dict.Get() will access the same element for all keys int(1), float64(1.0) and big.Int(1).
import (
"encoding/binary"
"fmt"
"hash/maphash"
"math"
"math/big"
"reflect"
"sort"
"github.com/aristanetworks/gomap"
)
// Dict represents dict from Python.
//
// It mirrors Python with respect to which types are allowed to be used as
// keys, and with respect to keys equality. For example Tuple is allowed to be
// used as key, and all int(1), float64(1.0) and big.Int(1) are considered to be
// equal.
//
// For strings, similarly to Python3, Bytes and string are considered to be not
// equal, even if their underlying content is the same. However with same
// underlying content ByteString, because it represents str type from Python2,
// is treated equal to both Bytes and string.
//
// Note: similarly to builtin map Dict is pointer-like type: its zero-value
// represents nil dictionary that is empty and invalid to use Set on.
// Copies of a Dict value share the same underlying storage because only the
// pointer m is copied.
type Dict struct {
// m is the underlying hash map. NewDictWithSizeHint creates it with the
// package-level equal and hash functions; it is nil in the zero-value Dict.
m *gomap.Map[any, any]
}
// NewDict creates an empty dictionary without reserving any space in advance.
func NewDict() Dict {
	const noSizeHint = 0
	return NewDictWithSizeHint(noSizeHint)
}
// NewDictWithSizeHint creates an empty dictionary whose storage is
// preallocated for about size items.
//
// The underlying map is wired to the package's Python-like equal and hash
// functions.
func NewDictWithSizeHint(size int) Dict {
	backing := gomap.NewHint[any, any](size, equal, hash)
	return Dict{m: backing}
}
// NewDictWithData creates a dictionary prefilled with the given entries.
//
// kv is a flat list key₁, value₁, key₂, value₂, ... Entries are inserted
// left to right via Set, so a later key replaces earlier keys equal to it.
//
// NewDictWithData panics if kv has odd length, or if a key is not allowed to
// be used as Dict key.
func NewDictWithData(kv ...any) Dict {
	if len(kv)&1 == 1 {
		panic("odd number of arguments")
	}

	dict := NewDictWithSizeHint(len(kv) / 2)
	for pos := 0; pos+1 < len(kv); pos += 2 {
		dict.Set(kv[pos], kv[pos+1])
	}
	return dict
}
// Get looks up the entry whose key is equal to key and returns its value.
//
// Equality is Python-like, so for example Get(1.0) finds an entry stored
// under int(1).
//
// The result of Get_ is returned with the presence flag dropped; use Get_ to
// distinguish a missing key from a stored nil value.
//
// Get panics if key's type is not allowed to be used as Dict key.
func (d Dict) Get(key any) any {
	found, _ := d.Get_(key)
	return found
}
// Get_ is the comma-ok form of Get: ok reports whether an entry with key
// equal to key is present in the dictionary.
func (d Dict) Get_(key any) (value any, ok bool) {
	value, ok = d.m.Get(key)
	return value, ok
}
// Set sets key to be associated with value.
//
// Any previous keys, equal to the new key, are removed from the dictionary
// before the assignment.
//
// Set panics if key's type is not allowed to be used as Dict key.
func (d Dict) Set(key, value any) {
// ByteString and container(with ByteString) are non-transitive equal types
// so Set(ByteString) should first remove Bytes and string,
// and Set(Tuple{ByteString}) should first remove Tuple{Bytes} and Tuple{string}
//
// A plain d.m.Set would overwrite at most one matching entry, so all equal
// keys are purged via Del first and only then the new entry is stored.
d.Del(key)
d.m.Set(key, value)
}
// Del removes every entry whose key is equal to key.
//
// More than one stored key can be equal to the query (e.g. both string and
// Bytes entries are equal to a ByteString query), so deletion is repeated
// until a lookup no longer finds a match.
//
// Del panics if key's type is not allowed to be used as Dict key.
func (d Dict) Del(key any) {
	// always delete at least once, then keep going while a match remains
	for present := true; present; _, present = d.Get_(key) {
		d.m.Delete(key)
	}
}
// Len reports how many entries the dictionary currently holds.
func (d Dict) Len() int {
	count := d.m.Len()
	return count
}
// Iter returns iterator over all elements in the dictionary.
//
// The returned function calls yield once per (key, value) entry and stops
// early as soon as yield returns false.
//
// The order to visit entries is arbitrary.
//
// The returned function may be invoked multiple times; every invocation
// starts a fresh pass over the dictionary.
func (d Dict) Iter() /* iter.Seq2 */ func(yield func(any, any) bool) {
	return func(yield func(any, any) bool) {
		// The underlying iterator is created per invocation. Creating it
		// once outside of this closure would leave it exhausted after the
		// first pass, making the returned sequence silently single-use.
		it := d.m.Iter()
		for it.Next() {
			if !yield(it.Key(), it.Elem()) {
				return
			}
		}
	}
}
// String renders the dictionary as "{k1: v1, k2: v2, ...}" with keys and
// values formatted via %v.
func (d Dict) String() string {
	const verb = "%v"
	return d.sprintf(verb)
}
// GoString renders the dictionary in detailed form: the Go type name of the
// dictionary followed by "{k1: v1, ...}" with keys and values formatted via %#v.
func (d Dict) GoString() string {
	typeName := fmt.Sprintf("%T", d)
	body := d.sprintf("%#v")
	return typeName + body
}
// sprintf serves String and GoString.
//
// Every key and value is formatted with the given fmt verb (format is applied
// to a single operand), and the entries are emitted as "{k1: v1, k2: v2}".
//
// Output is deterministic: entries are ordered by formatted key and, when
// several distinct keys format to the same text (e.g. 1 and "1" under %v),
// by formatted value.
func (d Dict) sprintf(format string) string {
	type entry struct{ k, v string }

	entries := make([]entry, 0, d.Len())
	d.Iter()(func(k, v any) bool {
		entries = append(entries, entry{
			k: fmt.Sprintf(format, k),
			v: fmt.Sprintf(format, v),
		})
		return true
	})

	// Iteration order is arbitrary and sort.Slice is not stable, so a
	// key-only comparison would leave entries with identical key text in
	// random relative order. Break such ties on the value text.
	sort.Slice(entries, func(i, j int) bool {
		if entries[i].k != entries[j].k {
			return entries[i].k < entries[j].k
		}
		return entries[i].v < entries[j].v
	})

	// accumulate into a byte buffer to avoid quadratic string concatenation
	buf := []byte{'{'}
	for i, e := range entries {
		if i > 0 {
			buf = append(buf, ", "...)
		}
		buf = append(buf, e.k...)
		buf = append(buf, ": "...)
		buf = append(buf, e.v...)
	}
	buf = append(buf, '}')
	return string(buf)
}
// ---- equal ----
// kind represents to which category a type belongs.
//
// It primarily classifies bool, numbers, slices, structs and maps, and puts
// everything else into "other" category.
//
// The numeric order of the constants is significant: equal sorts its two
// operands by kind so that only one half of the comparison matrix has to be
// implemented.
type kind uint

// The constants are explicitly typed as kind, so that values of other integer
// types cannot be mixed up with them unnoticed.
const (
	kBool    kind = iota
	kInt          // int + intX
	kUint         // uint + uintX
	kFloat        // floatX
	kComplex      // complexX
	kBigInt       // *big.Int
	kSlice        // slice + array
	kMap          // map
	kStruct       // struct
	kPointer      // pointer
	kOther        // everything else
)
// kindOf classifies x into one of the k* categories.
//
// *big.Int is recognized by its concrete type and reported as kBigInt; every
// other pointer is kPointer. Anything that is not bool, number, slice/array,
// map, struct or pointer - including strings and nil - is kOther.
func kindOf(x any) kind {
	// *big.Int has reflect kind Pointer, so it never collides with the
	// non-pointer kinds below and can be singled out up front.
	if _, isBig := x.(*big.Int); isBig {
		return kBigInt
	}

	switch reflect.ValueOf(x).Kind() {
	case reflect.Bool:
		return kBool
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return kInt
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return kUint
	case reflect.Float32, reflect.Float64:
		return kFloat
	case reflect.Complex64, reflect.Complex128:
		return kComplex
	case reflect.Slice, reflect.Array:
		return kSlice
	case reflect.Map:
		return kMap
	case reflect.Struct:
		return kStruct
	case reflect.Pointer:
		return kPointer
	default:
		return kOther
	}
}
// equal implements equality matching what Python would return for a == b.
//
// Equality properties:
//
// 1) equality is extension of Go ==
//
// (a == b) ⇒ equal(a,b)
//
// 2) self equal:
//
// equal(a,a) = y
//
// 3) equality is symmetrical:
//
// equal(a,b) = equal(b,a)
//
// 4) equality is mostly transitive:
//
// EqTransitive = set of all x:
// ∀ a,b,c ∈ EqTransitive:
// equal(a,b) ^ equal(b,c) ⇒ equal(a,c)
//
// EqTransitive = all \ {ByteString + containers with ByteString}
//
// Dispatch order: strings/bytes first, then the kind-based matrix (numbers,
// slices, builtin maps), then Dict, then structs, and finally builtin ==.
func equal(xa, xb any) bool {
// strings/bytes
//
// string and Bytes are never equal to each other, while ByteString is
// equal to both when the content is the same. When xa is one of the three
// string types the answer is fully decided here.
switch a := xa.(type) {
case string:
switch b := xb.(type) {
case string: return a == b
case ByteString: return a == string(b)
case Bytes: return false
default: return false
}
case ByteString:
switch b := xb.(type) {
case string: return a == ByteString(b)
case ByteString: return a == b
case Bytes: return a == ByteString(b)
default: return false
}
case Bytes:
switch b := xb.(type) {
case string: return false
case ByteString: return a == Bytes(b)
case Bytes: return a == b
default: return false
}
}
// everything else
a := reflect.ValueOf(xa)
b := reflect.ValueOf(xb)
ak := kindOf(xa)
bk := kindOf(xb)
// since equality is symmetric, we can implement only half of comparison matrix
if ak > bk {
a, b = b, a
ak, bk = bk, ak
xa, xb = xb, xa
}
// ak ≤ bk
//
// handled stays true when ak selects one of the cases below; if that case
// does not return, the (ak, bk) combination has no equality defined and
// the operands are unequal. The default case clears handled so that
// processing continues after the switch.
handled := true
switch ak {
default:
handled = false
// numbers
case kBool:
// bool compares to numbers as 1 or 0
//
// In [1]: 1.0 == True
// Out[1]: True
//
// In [2]: 0.0 == False
// Out[2]: True
//
// In [3]: d = {1: 'abc'}
//
// In [4]: d[True]
// Out[4]: 'abc'
abint := bint(a.Bool())
switch bk {
case kBool: return eq_Int_Int (abint, bint(b.Bool()))
case kInt: return eq_Int_Int (abint, b.Int())
case kUint: return eq_Int_Uint (abint, b.Uint())
case kFloat: return eq_Int_Float (abint, b.Float())
case kComplex: return eq_Int_Complex (abint, b.Complex())
case kBigInt: return eq_Int_BigInt (abint, xb.(*big.Int))
}
case kInt:
aint := a.Int()
switch bk {
// kBool
case kInt: return eq_Int_Int (aint, b.Int())
case kUint: return eq_Int_Uint (aint, b.Uint())
case kFloat: return eq_Int_Float (aint, b.Float())
case kComplex: return eq_Int_Complex (aint, b.Complex())
case kBigInt: return eq_Int_BigInt (aint, xb.(*big.Int))
}
case kUint:
auint := a.Uint()
switch bk {
// kBool
// kInt
case kUint: return eq_Uint_Uint (auint, b.Uint())
case kFloat: return eq_Uint_Float (auint, b.Float())
case kComplex: return eq_Uint_Complex (auint, b.Complex())
case kBigInt: return eq_Uint_BigInt (auint, xb.(*big.Int))
}
case kFloat:
afloat := a.Float()
switch bk {
// kBool
// kInt
// kUint
case kFloat: return eq_Float_Float (afloat, b.Float())
case kComplex: return eq_Float_Complex (afloat, b.Complex())
case kBigInt: return eq_Float_BigInt (afloat, xb.(*big.Int))
}
case kComplex:
acomplex := a.Complex()
switch bk {
// kBool
// kInt
// kUint
// kFloat
case kComplex: return eq_Complex_Complex (acomplex, b.Complex())
case kBigInt: return eq_Complex_BigInt (acomplex, xb.(*big.Int))
}
case kBigInt:
switch bk {
// kBool
// kInt
// kUint
// kFloat
// kComplex
case kBigInt: return eq_BigInt_BigInt (xa.(*big.Int), xb.(*big.Int))
}
// slices
case kSlice:
switch bk {
case kSlice: return eq_Slice_Slice (a, b)
}
// builtin map
case kMap:
switch bk {
case kMap: return eq_Map_Map (a, b)
}
// Dict is a struct, so for map vs Dict the map is always on the a side
// after the kind ordering above (kMap < kStruct).
switch b := xb.(type) {
case Dict: return eq_Map_Dict (a, b)
}
}
if handled {
return false
}
// our types that need special handling
switch a := xa.(type) {
case Dict:
switch b := xb.(type) {
case Dict: return eq_Dict_Dict(a, b)
default: return false
}
}
// structs (also covers None, Class, Call etc...)
switch ak {
case kStruct:
switch bk {
case kStruct: return eq_Struct_Struct (a, b)
default: return false
}
}
return (xa == xb) // fallback to builtin equality
}
// equality matrix. nontrivial elements
// eq_Int_Uint reports whether signed a and unsigned b denote the same number.
//
// A negative a never equals an unsigned value; only non-negative a is
// converted to uint64 for the comparison.
func eq_Int_Uint(a int64, b uint64) bool {
	return a >= 0 && uint64(a) == b
}
// eq_Int_BigInt reports whether a and b denote the same integer.
//
// b can equal an int64 only if it fits into int64 itself.
func eq_Int_BigInt(a int64, b *big.Int) bool {
	return b.IsInt64() && b.Int64() == a
}
// eq_Uint_BigInt reports whether a and b denote the same integer.
//
// b can equal a uint64 only if it fits into uint64 itself.
func eq_Uint_BigInt(a uint64, b *big.Int) bool {
	return b.IsUint64() && b.Uint64() == a
}
// eq_Float_BigInt reports whether float a and integer b denote the same number.
//
// b is converted to float64; if that conversion is not exact, b has no
// float64 representation and cannot be equal to any float.
func eq_Float_BigInt(a float64, b *big.Int) bool {
	switch bf, accuracy := bigInt_Float64(b); accuracy {
	case big.Exact:
		return a == bf
	default:
		return false
	}
}
// eq_Complex_BigInt reports whether complex a and integer b denote the same
// number: a must have zero imaginary part and its real part must equal b.
func eq_Complex_BigInt(a complex128, b *big.Int) bool {
	if imag(a) != 0 {
		return false
	}
	return eq_Float_BigInt(real(a), b)
}
// eq_BigInt_BigInt reports whether a and b hold the same integer value,
// regardless of being distinct *big.Int instances.
func eq_BigInt_BigInt(a, b *big.Int) bool {
	order := a.Cmp(b)
	return order == 0
}
// eq_Slice_Slice reports whether two slices/arrays have the same length and
// pairwise equal elements. Elements are compared with equal, front to back.
func eq_Slice_Slice(a, b reflect.Value) bool {
	n := a.Len()
	if b.Len() != n {
		return false
	}

	for idx := 0; idx < n; idx++ {
		ai := a.Index(idx).Interface()
		bi := b.Index(idx).Interface()
		if !equal(ai, bi) {
			return false
		}
	}
	return true
}
// eq_Struct_Struct reports whether two struct values are equal.
//
// Structs of different types are never equal. Structs of the same type are
// equal if all their fields, including unexported ones, are pairwise equal
// according to equal.
func eq_Struct_Struct(a, b reflect.Value) bool {
if a.Type() != b.Type() {
return false
}
typ := a.Type()
l := typ.NumField()
for i := 0; i < l; i++ {
af := a.Field(i)
bf := b.Field(i)
// .Interface() is not allowed if the field is private.
// Work around the protection via unsafe. We may need to switch
// to struct copy if it is not addressable because Addr() is
// used in the workaround. https://stackoverflow.com/a/43918797/9456786
ftyp := typ.Field(i)
if !ftyp.IsExported() {
if !af.CanAddr() {
// switch a to addressable copy
// (a stays rebound to the copy for the remaining fields as well)
a_ := reflect.New(typ).Elem()
a_.Set(a)
a = a_
af = a.Field(i)
}
if !bf.CanAddr() {
// switch b to addressable copy
b_ := reflect.New(typ).Elem()
b_.Set(b)
b = b_
bf = b.Field(i)
}
// re-create the field values from their raw addresses: values
// obtained this way no longer carry the "unexported" restriction
af = reflect.NewAt(ftyp.Type, af.Addr().UnsafePointer()).Elem()
bf = reflect.NewAt(ftyp.Type, bf.Addr().UnsafePointer()).Elem()
}
if !equal(af.Interface(), bf.Interface()) {
return false
}
}
return true
}
// eq_Dict_Dict reports whether two dictionaries are equal.
//
// D₁ and D₂ are considered equal when all of the following holds:
//
//   - len(D₁) = len(D₂)
//   - ∀ k ∈ D₁: k is present in D₂ and equal(D₁[k], D₂[k])
//   - ∀ k ∈ D₂: k is present in D₁ and equal(D₁[k], D₂[k])
//
// This definition is cheap to check and needs no additional memory. If both
// dictionaries use only keys from the equal-transitive subset (i.e. anything
// without ByteString), it is equivalent to existence of a 1-1 mapping i<->j
// in between entries (k₁i, v₁i) of D₁ and (k₂j, v₂j) of D₂ such that
// equal(k₁i, k₂j) ^ equal(v₁i, v₂j).
func eq_Dict_Dict(a Dict, b Dict) bool {
	if a.Len() != b.Len() {
		return false
	}

	same := true

	// every entry of a must have an equal counterpart in b ...
	a.Iter()(func(key, aval any) bool {
		bval, found := b.Get_(key)
		if !(found && equal(aval, bval)) {
			same = false
		}
		return same
	})

	// ... and every entry of b must have an equal counterpart in a
	if same {
		b.Iter()(func(key, bval any) bool {
			aval, found := a.Get_(key)
			if !(found && equal(aval, bval)) {
				same = false
			}
			return same
		})
	}

	return same
}
// equal(Map, Dict) and equal(Map, Map) follow semantic of equal(Dict, Dict)

// eq_Map_Dict reports whether builtin map a (given as reflect.Value) is equal
// to Dict b: both must have the same length, every entry of a must have an
// equal counterpart in b, and every entry of b must have an equal
// counterpart in a.
func eq_Map_Dict(a reflect.Value, b Dict) bool {
if a.Len() != b.Len() {
return false
}
aKeyType := a.Type().Key()
// a -> b direction: lookup in b uses Dict (Python-like) key equality
ai := a.MapRange()
for ai.Next() {
k := ai.Key().Interface()
va := ai.Value().Interface()
vb, ok := b.Get_(k)
if !ok || !equal(va, vb) {
return false
}
}
// b -> a direction: lookup in a uses builtin map key semantic, so a key of b
// whose type cannot be assigned to a's key type cannot be present in a
// NOTE(review): a.MapIndex presumably panics for keys that are assignable to
// a's key type but are not hashable by the builtin map (e.g. Tuple with
// map[any]any) - confirm whether that can be reached.
eq := true
b.Iter()(func(k,vb any) bool {
xk := reflect.ValueOf(k)
if !xk.Type().AssignableTo(aKeyType) {
eq = false
return false
}
xva := a.MapIndex(xk)
if !(xva.IsValid() && equal(xva.Interface(), vb)) {
eq = false
return false
}
return true
})
return eq
}
// eq_Map_Map reports whether two builtin maps (given as reflect.Value) are
// equal: both must have the same length, and every entry of one map must be
// present in the other under the same key with an equal value.
//
// Keys are looked up with builtin map semantic; values are compared with equal.
func eq_Map_Map(a reflect.Value, b reflect.Value) bool {
if a.Len() != b.Len() {
return false
}
aKeyType := a.Type().Key()
bKeyType := b.Type().Key()
// a -> b direction
ai := a.MapRange()
for ai.Next() {
k := ai.Key().Interface() // NOTE xk != ai.Key() because that might have type any
xk := reflect.ValueOf(k) // while xk has type of particular contained value
va := ai.Value().Interface()
// a key that cannot be assigned to b's key type cannot be present in b
if !xk.Type().AssignableTo(bKeyType) {
return false
}
xvb := b.MapIndex(xk)
if !(xvb.IsValid() && equal(va, xvb.Interface())) {
return false
}
}
// b -> a direction
bi := b.MapRange()
for bi.Next() {
k := bi.Key().Interface() // see ^^^
xk := reflect.ValueOf(k)
vb := bi.Value().Interface()
if !xk.Type().AssignableTo(aKeyType) {
return false
}
xva := a.MapIndex(xk)
if !(xva.IsValid() && equal(xva.Interface(), vb)) {
return false
}
}
return true
}
// equality matrix. trivial elements

// eq_Int_Int reports whether two signed integers are equal.
func eq_Int_Int(a int64, b int64) bool { return a == b }

// eq_Int_Float reports whether integer a and float b denote exactly the same
// number, as Python does for int == float.
//
// The comparison is made in the integer domain: float64(a) would round a with
// |a| > 2⁵³ and so report e.g. 2⁵³+1 equal to float64(2⁵³). That would break
// transitivity of equality among integers and the equal ⇒ same-hash invariant.
func eq_Int_Float(a int64, b float64) bool {
	// b must be a whole number (this also rejects NaN) inside [-2⁶³, 2⁶³)
	// for the conversion int64(b) to be exact and well-defined.
	if b != math.Trunc(b) || b < -0x1p63 || b >= 0x1p63 {
		return false
	}
	return a == int64(b)
}

// eq_Int_Complex reports whether integer a equals complex b, i.e. b has zero
// imaginary part and its real part is exactly a.
func eq_Int_Complex(a int64, b complex128) bool {
	return imag(b) == 0 && eq_Int_Float(a, real(b))
}

// eq_Uint_Uint reports whether two unsigned integers are equal.
func eq_Uint_Uint(a uint64, b uint64) bool { return a == b }

// eq_Uint_Float reports whether unsigned integer a and float b denote exactly
// the same number. See eq_Int_Float for why float64(a) == b is not used.
func eq_Uint_Float(a uint64, b float64) bool {
	// b must be a whole number (this also rejects NaN) inside [0, 2⁶⁴)
	if b != math.Trunc(b) || b < 0 || b >= 0x1p64 {
		return false
	}
	return a == uint64(b)
}

// eq_Uint_Complex reports whether unsigned integer a equals complex b, i.e. b
// has zero imaginary part and its real part is exactly a.
func eq_Uint_Complex(a uint64, b complex128) bool {
	return imag(b) == 0 && eq_Uint_Float(a, real(b))
}

// eq_Float_Float reports whether two floats are equal (NaN is equal to nothing).
func eq_Float_Float(a float64, b float64) bool { return a == b }

// eq_Float_Complex reports whether float a equals complex b, i.e. b has zero
// imaginary part and real part equal to a.
func eq_Float_Complex(a float64, b complex128) bool { return complex(a, 0) == b }

// eq_Complex_Complex reports whether two complex numbers are equal.
func eq_Complex_Complex(a complex128, b complex128) bool { return a == b }
// ---- hash ----
// hash returns hash of x consistent with equality implemented by equal.
//
//	equal(a,b) ⇒ hash(a) = hash(b)
//
// Numbers that denote the same integer hash the same regardless of whether
// they come as bool, intX, uintX, float, complex or *big.Int. Pointers are
// hashed by address. Tuples and structs are hashed from the hashes of their
// items/fields.
//
// hash panics with "unhashable type: ..." if x is not allowed to be used as Dict key.
func hash(seed maphash.Seed, x any) uint64 {
	// strings/bytes use standard hash of string
	switch v := x.(type) {
	case string:
		return maphash_String(seed, v)
	case ByteString:
		return maphash_String(seed, string(v))
	case Bytes:
		return maphash_String(seed, string(v))
	}

	// for everything else we implement custom hashing ourselves to match equal
	var h maphash.Hash
	h.SetSeed(seed)

	hash_Uint := func(u uint64) {
		var b [8]byte
		binary.BigEndian.PutUint64(b[:], u)
		h.Write(b[:])
	}

	hash_Int := func(i int64) {
		hash_Uint(uint64(i))
	}

	hash_Float := func(f float64) {
		// A float that is a whole number in int64 or uint64 range must hash
		// the same way as the corresponding integer does. The range is
		// checked explicitly because converting an out-of-range float to an
		// integer type gives implementation-dependent result in Go, which
		// previously made e.g. float64(2⁶³) hash differently from the equal
		// uint64(2⁶³) and big.Int(2⁶³).
		integral := (f == math.Trunc(f)) // false for NaN
		switch {
		case integral && -0x1p63 <= f && f < 0x1p63:
			hash_Int(int64(f))
		case integral && 0x1p63 <= f && f < 0x1p64:
			hash_Uint(uint64(f))
		default:
			// else use raw float64 bytes representation for hashing
			hash_Uint(math.Float64bits(f))
		}
	}

	// numbers
	r := reflect.ValueOf(x)
	k := kindOf(x)
	handled := true
	switch k {
	default:
		handled = false

	case kBool:
		hash_Int(bint(r.Bool()))
	case kInt:
		hash_Int(r.Int())
	case kUint:
		hash_Uint(r.Uint())
	case kFloat:
		hash_Float(r.Float())

	case kComplex:
		// complex with zero imaginary part hashes as its real part
		c := r.Complex()
		hash_Float(real(c))
		if imag(c) != 0 {
			hash_Float(imag(c))
		}

	case kBigInt:
		b := x.(*big.Int)
		switch {
		case b.IsInt64():
			hash_Int(b.Int64())
		case b.IsUint64():
			hash_Uint(b.Uint64())
		default:
			// outside of 64-bit integer range: hash as float if there is
			// an equal float, else by raw magnitude bytes
			f, accuracy := bigInt_Float64(b)
			if accuracy == big.Exact {
				hash_Float(f)
			} else {
				h.WriteString("bigInt")
				h.Write(b.Bytes())
			}
		}

	// kSlice  - skip
	// kStruct - skip

	case kPointer:
		// Pointers are hashed by the address they hold. Value.Pointer also
		// works for nil pointer, where Elem().UnsafeAddr() would panic.
		hash_Uint(uint64(r.Pointer()))
	}

	if handled {
		return h.Sum64()
	}

	// tuple
	if v, ok := x.(Tuple); ok {
		h.WriteString("tuple")
		for _, item := range v {
			hash_Uint(hash(seed, item))
		}
		return h.Sum64()
	}

	// structs (also covers None, Class, Call etc)
	if k == kStruct {
		// Dict is a struct, but it is handled specially by equal and, like
		// Python dict, is not hashable.
		if _, isDict := x.(Dict); !isDict {
			typ := r.Type()
			h.WriteString(typ.Name())
			l := typ.NumField()
			for i := 0; i < l; i++ {
				f := r.Field(i)
				// .Interface() is not allowed if the field is private.
				// Work it around via unsafe. See eq_Struct_Struct for details.
				ftyp := typ.Field(i)
				if !ftyp.IsExported() {
					if !f.CanAddr() {
						// switch r to addressable copy
						r_ := reflect.New(typ).Elem()
						r_.Set(r)
						r = r_
						f = r.Field(i)
					}
					f = reflect.NewAt(ftyp.Type, f.Addr().UnsafePointer()).Elem()
				}
				hash_Uint(hash(seed, f.Interface()))
			}
			return h.Sum64()
		}
	}

	panic(fmt.Sprintf("unhashable type: %T", x))
}
// ---- misc ----
// bint converts bool to the integer Python treats it as.
//
//	true  -> 1
//	false -> 0
func bint(x bool) int64 {
	var i int64
	if x {
		i = 1
	}
	return i
}
package ogórek
import (
"fmt"
"hash/maphash"
"reflect"
"strings"
"testing"
)
// tStructWithPrivate is used by tests to verify handling of struct with private fields.
//
// Both fields are unexported, so equal and hash have to go through their
// workaround for reading private fields when processing this type.
type tStructWithPrivate struct {
x, y any
}
// TestEqual verifies equal and hash.
//
// The test is data-driven: testv lists sets of values that must all be equal
// to each other, and every value must be unequal to values of other sets
// unless that value also belongs to the set under test. Every equal
// invocation additionally checks self-equality, symmetry, consistency with
// Go == and consistency with hash.
func TestEqual(t *testing.T) {
// tEqualSet represents tested set of values:
// ∀ a ∈ tEqualSet:
// ∀ b ∈ tEqualSet ⇒ equal(a,b) = y
// ∀ c ∉ tEqualSet ⇒ equal(a,c) = n
//
// Intersection in between different tEqualSets is mostly empty: such
// intersections can contain elements only from all \ EqTransitive, i.e. only ByteString.
type tAllEqual []any
// E is shortcut to create tEqualSet
E := func(v ...any) tAllEqual { return tAllEqual(v) }
// D and M are shortcuts to create Dict and map[any]any
D := NewDictWithData
type M = map[any]any
// i1 and i1_ are two integer variables equal to 1 but with different address
// obj and obj_ are similar equal structures located at different memory regions
i1 := 1; i1_ := 1
obj := &Class{"a","b"}; obj_ := &Class{"a","b"}
// testv is vector of all test-cases
testv := []tAllEqual {
// numbers
E(int(0),
int64(0), int32(0), int16(0), int8(0),
uint64(0), uint32(0), uint16(0), uint8(0),
bigInt("0"),
false,
float32 (0), float64 (0),
complex64(0), complex128(0)),
E(int(1),
int64 (1), int32(1), int16(1), int8(1),
uint64(1), uint32(1), uint16(1), uint8(1),
bigInt("1"),
true,
float32 (1), float64 (1),
complex64(1), complex128(1)),
E(int(-1),
int64(-1), int32(-1), int16(-1), int8(-1),
// NOTE no uintX because they ≥ 0 only
bigInt("-1"),
// NOTE no bool because it ∈ {0,1}
float32 (-1), float64 (-1),
complex64(-1), complex128(-1)),
// intX/uintX different range
E(int(0xff),
int64(0xff), int32(0xff), int16(0xff), // int8(overflow),
uint64(0xff), uint32(0xff), uint16(0xff), // uint8(overflow),
bigInt("255"),
bigInt("255"), // two different *big.Int instances
float32 (0xff), float64 (0xff),
complex64(0xff), complex128(0xff)),
E(int(-0x80),
int64(-0x80), int32(-0x80), int16(-0x80), int8(-0x80),
//uint64(), uint32(), uint16(), uint8(), ≥ 0 only
bigInt("-128"),
float32 (-0x80), float64 (-0x80),
complex64(-0x80), complex128(-0x80)),
E(int(0xffff),
int64(0xffff), int32(0xffff), // int16(overflow), int8(overflow),
uint64(0xffff), uint32(0xffff), uint16(0xffff), // uint8(overflow),
bigInt("65535"),
float32 (0xffff), float64 (0xffff),
complex64(0xffff), complex128(0xffff)),
E(int(-0x8000),
int64(-0x8000), int32(-0x8000), int16(-0x8000), // int8(overflow),
//uint64(), uint32(), uint16(), uint8(), ≥ 0 only
bigInt("-32768"),
float32 (-0x8000), float64 (-0x8000),
complex64(-0x8000), complex128(-0x8000)),
E(int(0xffffffff),
int64(0xffffffff), // int32(overflow), int16(overflow), int8(overflow),
uint64(0xffffffff), uint32(0xffffffff), // uint16(overflow), uint8(overflow),
bigInt("4294967295"),
/* float32 (precision loss), */ float64 (0xffffffff),
/* complex64(precision loss), */ complex128(0xffffffff)),
E(int(-0x80000000),
int64(-0x80000000), int32(-0x80000000), // int16(overflow), int8(overflow),
//uint64(), uint32(), uint16(), uint8(), ≥ 0 only
bigInt("-2147483648"),
float32 (-0x80000000), float64 (-0x80000000),
complex64(-0x80000000), complex128(-0x80000000)),
E(// int(overflow),
// int64(overflow), int32(overflow), int16(overflow), int8(overflow),
uint64(0xffffffffffffffff), // uint32(overflow), uint16(overflow), uint8(overflow),
bigInt("18446744073709551615")),
// float32 (precision loss), float64 (precision loss),
// complex64(precision loss), complex128(precision loss)),
E(int(-0x8000000000000000),
int64(-0x8000000000000000), // int32(overflow), int16(overflow), int8(overflow),
//uint64(), uint32(), uint16(), uint8(), ≥ 0 only
bigInt("-9223372036854775808"),
float32 (-0x8000000000000000), float64 (-0x8000000000000000),
complex64(-0x8000000000000000), complex128(-0x8000000000000000)),
// big integer outside of 64-bit range that is exactly representable as float
E(bigInt("1"+strings.Repeat("0",22)), float64(1e22), complex128(complex(1e22,0))),
// purely imaginary and fractional numbers
E(complex64(complex(0,1)), complex128(complex(0,1))),
E(float64(1.25), float32(1.25), complex64(complex(1.25,0)), complex128(complex(1.25,0))),
// strings/bytes
// (ByteString is listed in two sets: it equals both string and Bytes,
// while string and Bytes are not equal to each other)
E("", ByteString("")), E(ByteString(""), Bytes("")),
E("a", ByteString("a")), E(ByteString("a"), Bytes("a")),
E("мир", ByteString("мир")), E(ByteString("мир"), Bytes("мир")),
// none / empty tuple|list
E(None{}),
E(Tuple{}, []any{}),
// sequences
E([]int{}, []float32{}, []any{}, Tuple{}, [0]float64{}),
E([]int{1,2}, []float32{1,2}, []any{1,2}, Tuple{1,2}, [2]float64{1,2}),
E([]any{1,"a"}, Tuple{1,"a"}, [2]any{1,"a"}, Tuple{1,ByteString("a")}),
// Dict, map
E(D(),
M{}, map[int]bool{}),
E(D(1,bigInt("2")),
M{1:2.0}, map[int]int{1:2}),
E(D(1,"a"),
M{1:"a"}, map[int]string{1:"a"}),
E(D("a",1),
M{"a":1}),
E(D("a",1, None{},2),
M{"a":1, None{}:2}),
E(D("a",1, Bytes("a"),1),
M{"a":1, Bytes("a"):1}),
E(D("a",1, Bytes("a"),2),
M{"a":1, Bytes("a"):2}),
E(D("a",1), D(ByteString("a"),1)), E(D(ByteString("a"),1), D(Bytes("a"),1)),
E(D("a",1, Bytes("a"),1, ByteString("b"),2),
D(ByteString("a"),1, "b",2, Bytes("b"),2)),
// structs
E(Class{"mod","cls"}, Class{"mod","cls"}),
E(Call{Class{"mod","cls"}, Tuple{"a","b",3}},
Call{Class{"mod","cls"}, Tuple{ByteString("a"),"b",bigInt("3")}}),
E(Ref{1}, Ref{bigInt("1")}, Ref{1.0}),
E(tStructWithPrivate{"a",1}, tStructWithPrivate{ByteString("a"),bigInt("1")}),
E(tStructWithPrivate{"b",2}, tStructWithPrivate{"b",2.0}),
// pointers, as in builtin ==, are compared only by address
E(&i1), E(&i1_), E(&obj), E(&obj_),
// nil
E(nil),
}
// automatically test equality on Tuples/list from ^^^ data
//
// for every pair of neighbouring sets (Ex, Ey) a new set is appended that
// holds 2-element Tuple and list built from the first elements of Ex and Ey,
// and Tuple and list built from their second elements (or from the first
// again, if a set has only one element)
testvAddSequences := func() {
l := len(testv)
for i := 0; i < l; i++ {
Ex := testv[i]
Ey := testv[(i+1)%l]
x0 := Ex[0]; x1 := Ex[1%len(Ex)]
y0 := Ey[0]; y1 := Ey[1%len(Ey)]
t1 := Tuple{x0,y0}; l1 := []any{x0,y0}
t2 := Tuple{x1,y1}; l2 := []any{x1,y1}
testv = append(testv, E(t1, t2, l1, l2))
}
}
testvAddSequences()
// and sequences of sequences
testvAddSequences()
// thash is used to invoke hash.
// if x is not hashable ok=false is returned instead of panic.
// any other panic is propagated further.
tseed := maphash.MakeSeed()
thash := func(x any) (h uint64, ok bool) {
defer func() {
r := recover()
if r != nil {
s, sok := r.(string)
if sok && strings.HasPrefix(s, "unhashable type: ") {
ok = false
h = 0
} else {
panic(r)
}
}
}()
return hash(tseed, x), true
}
// tequal is used to invoke equal.
// it automatically checks Go-extension, self-equal, symmetry and hash invariants:
//
// a==b ⇒ equal(a,b)
// equal(a,a) = y
// equal(a,b) = equal(b,a)
// equal(a,b) ⇒ hash(a) = hash(b)
tequal := func(a, b any) bool {
aa := equal(a, a)
bb := equal(b, b)
if !aa {
t.Errorf("not self-equal %T %#v", a,a)
}
if !bb {
t.Errorf("not self-equal %T %#v", b,b)
}
eq := equal(a, b)
qe := equal(b, a)
if eq != qe {
t.Errorf("equal not symmetric: %T %#v %T %#v; a == b: %v b == a: %v",
a,a, b,b, eq, qe)
}
// the hash invariant is checked only when both values are hashable
ah, ahOk := thash(a)
bh, bhOk := thash(b)
if eq && ahOk && bhOk && !(ah == bh) {
t.Errorf("hash different of equal %T %#v hash:%x %T %#v hash:%x",
a,a,ah, b,b,bh)
}
goeq := false
func() {
// a == b can trigger "comparing uncomparable type ..."
// even if reflect reports both types as comparable
// (see mapTryAssign for details)
defer func() {
recover()
}()
goeq = (a == b)
}()
if goeq && !eq {
t.Errorf("equal is not extension of == %T %#v %T %#v",
a,a, b,b)
}
return eq
}
// EHas returns whether x ∈ E.
EHas := func(E tAllEqual, x any) bool {
for _, a := range E {
if tequal(a, x) {
return true
}
}
return false;
}
// do the tests
for i, E1 := range testv {
// ∀ a,b ∈ tEqualSet ⇒ equal(a,b) = y
for _, a := range E1 {
for _, b := range E1 {
if !tequal(a,b) {
t.Errorf("not equal %T %#v %T %#v", a,a, b,b)
}
}
}
// ∀ a ∈ tEqualSet
// ∀ c ∉ tEqualSet ⇒ equal(a,c) = n
for j, E2 := range testv {
if j == i {
continue
}
for _, a := range E1 {
for _, c := range E2 {
// skip elements of E2 that are equal to something in E1
if EHas(E1, c) {
continue
}
if tequal(a,c) {
t.Errorf("equal %T %#v %T %#v", a,a, c,c)
}
}
}
}
}
}
// TestDict verifies Dict.
//
// The test is one long scenario over a shared dictionary d: d is modified
// step by step and after every step its full content (assertData) and the
// results of lookups with various equal/unequal keys (assertGet) are checked.
// The scenario covers numbers, strings/bytes, None and tuples, structs,
// pointers, the constructors, unhashable keys and the zero-value Dict.
func TestDict(t *testing.T) {
d := NewDict()
// assertData asserts that d has data exactly as specified by provided key,value pairs.
//
// stored keys are matched against expected keys by both exact Go type and
// equal, so that e.g. key 1 stored as float64 is not accepted where int is expected.
assertData := func(kvok ...any) {
t.Helper()
if len(kvok) % 2 != 0 {
panic("kvok % 2 != 0")
}
lok := len(kvok)/2
// kvokGet returns expected value for key k from kvok
kvokGet := func(k any) (any, bool) {
t.Helper()
for i := 0; i < lok; i++ {
kok := kvok[2*i]
vok := kvok[2*i+1]
if reflect.TypeOf(k) == reflect.TypeOf(kok) &&
equal(k, kok) {
return vok, true
}
}
return nil, false
}
// badf reports an error and remembers that d does not match kvok
bad := false
badf := func(format string, argv ...any) {
t.Helper()
bad = true
t.Errorf(format, argv...)
}
l := d.Len()
if l != lok {
badf("len: have: %d want: %d", l, lok)
}
d.Iter()(func(k,v any) bool {
t.Helper()
vok, ok := kvokGet(k)
if !ok {
badf("unexpected key %#v", k)
}
if v != vok {
badf("key %T %#v -> value %#T %#v ; want %T %#v", k,k, v,v, vok,vok)
}
return true
})
if bad {
t.Fatalf("\nd: %#v\nkvok: %#v", d, kvok)
}
}
// assertGet asserts that d.Get(k) results in exactly vok or any element from vokExtra.
//
// vokExtra is used where several stored keys are equal to k and so any of
// their values is a valid result.
assertGet := func(k any, vok any, vokExtra ...any) {
t.Helper()
v := d.Get(k)
if v == vok {
return
}
for _, eok := range vokExtra {
if v == eok {
return
}
}
emsg := fmt.Sprintf("get %#v: have: %#v want: %#v", k, v, vok)
for _, eok := range vokExtra {
emsg += fmt.Sprintf(" ∪ %#v", eok)
}
emsg += fmt.Sprintf("\nd: %#v", d)
t.Fatal(emsg)
}
// numbers
assertData()
d.Set(1, "x")
assertData(1,"x")
assertGet(1, "x")
assertGet(1.0, "x")
assertGet(bigInt("1"), "x")
assertGet(complex(1,0), "x")
// deleting absent key is a no-op
d.Del(7)
assertData(1,"x")
assertGet(1, "x")
assertGet(1.0, "x")
assertGet(bigInt("1"), "x")
assertGet(complex(1,0), "x")
d.Set(2.5, "y")
assertData(1,"x", 2.5,"y")
assertGet(1, "x")
assertGet(1.0, "x")
assertGet(bigInt("1"), "x")
assertGet(complex(1,0), "x")
assertGet(2, nil)
assertGet(2.5, "y")
assertGet(bigInt("2"), nil)
assertGet(complex(2.5,0), "y")
d.Del(1)
assertData(2.5,"y")
assertGet(1, nil)
assertGet(1.0, nil)
assertGet(bigInt("1"), nil)
assertGet(complex(1,0), nil)
assertGet(2, nil)
assertGet(2.5, "y")
assertGet(bigInt("2"), nil)
assertGet(complex(2.5,0), "y")
d.Del(2.5)
assertData()
assertGet(1, nil)
assertGet(1.0, nil)
assertGet(bigInt("1"), nil)
assertGet(complex(1,0), nil)
assertGet(2, nil)
assertGet(2.5, nil)
assertGet(bigInt("2"), nil)
assertGet(complex(2.5,0), nil)
// strings/bytes
assertData()
assertGet("abc", nil)
d.Set("abc", "a")
assertData("abc","a")
assertGet("abc", "a")
assertGet(Bytes("abc"), nil)
assertGet(ByteString("abc"), "a")
// string and Bytes keys with the same content coexist
d.Set(Bytes("abc"), "b")
assertData("abc","a", Bytes("abc"),"b")
assertGet("abc", "a")
assertGet(Bytes("abc"), "b")
assertGet(ByteString("abc"), "a", "b")
// setting ByteString key replaces both string and Bytes entries
d.Set(ByteString("abc"), "c")
assertData(ByteString("abc"),"c")
assertGet("abc", "c")
assertGet(Bytes("abc"), "c")
assertGet(ByteString("abc"), "c")
d.Del("abc")
assertData()
assertGet("abc", nil)
assertGet(Bytes("abc"), nil)
assertGet(ByteString("abc"), nil)
d.Set("abc", "a")
assertData("abc","a")
assertGet("abc", "a")
assertGet(Bytes("abc"), nil)
assertGet(ByteString("abc"), "a")
d.Set(Bytes("abc"), "b")
assertData("abc","a", Bytes("abc"),"b")
assertGet("abc", "a")
assertGet(Bytes("abc"), "b")
assertGet(ByteString("abc"), "a", "b")
// deleting by ByteString key removes both string and Bytes entries
d.Del(ByteString("abc"))
assertData()
assertGet("abc", nil)
assertGet(Bytes("abc"), nil)
assertGet(ByteString("abc"), nil)
// None, tuple
assertData()
d.Set(None{}, "n")
assertData(None{},"n")
assertGet(None{}, "n")
assertGet(Tuple{}, nil)
d.Set(Tuple{}, "t")
assertData(None{},"n", Tuple{},"t")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
d.Set(Tuple{1,2,"a"}, "t12a")
assertData(None{},"n", Tuple{},"t", Tuple{1,2,"a"},"t12a")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
assertGet(Tuple{1,2}, nil)
assertGet(Tuple{1,2,"a"}, "t12a")
assertGet(Tuple{1,2,Bytes("a")}, nil)
assertGet(Tuple{1,2,ByteString("a")}, "t12a")
d.Set(Tuple{1,2,Bytes("a")}, "t12b")
assertData(None{},"n", Tuple{},"t", Tuple{1,2,"a"},"t12a", Tuple{1,2,Bytes("a")},"t12b")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
assertGet(Tuple{1,2}, nil)
assertGet(Tuple{1,2,"a"}, "t12a")
assertGet(Tuple{1,2,Bytes("a")}, "t12b")
assertGet(Tuple{1,2,ByteString("a")}, "t12a", "t12b")
// tuple with ByteString replaces tuples with string and Bytes
d.Set(Tuple{1,2,ByteString("a")}, "t12c")
assertData(None{},"n", Tuple{},"t", Tuple{1,2,ByteString("a")},"t12c")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
assertGet(Tuple{1,2}, nil)
assertGet(Tuple{1,2,"a"}, "t12c")
assertGet(Tuple{1,2,Bytes("a")}, "t12c")
assertGet(Tuple{1,2,ByteString("a")}, "t12c")
d.Set(Tuple{1,2,"a"}, "t12a")
assertData(None{},"n", Tuple{},"t", Tuple{1,2,"a"},"t12a")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
assertGet(Tuple{1,2}, nil)
assertGet(Tuple{1,2,"a"}, "t12a")
assertGet(Tuple{1,2,Bytes("a")}, nil)
assertGet(Tuple{1,2,ByteString("a")}, "t12a")
d.Set(Tuple{1,2,Bytes("a")}, "t12b")
assertData(None{},"n", Tuple{},"t", Tuple{1,2,"a"},"t12a", Tuple{1,2,Bytes("a")},"t12b")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
assertGet(Tuple{1,2}, nil)
assertGet(Tuple{1,2,"a"}, "t12a")
assertGet(Tuple{1,2,Bytes("a")}, "t12b")
assertGet(Tuple{1,2,ByteString("a")}, "t12a", "t12b")
d.Del(Tuple{1,2,ByteString("a")})
assertData(None{},"n", Tuple{},"t")
assertGet(None{}, "n")
assertGet(Tuple{}, "t")
assertGet(Tuple{1,2}, nil)
assertGet(Tuple{1,2,"a"}, nil)
assertGet(Tuple{1,2,Bytes("a")}, nil)
assertGet(Tuple{1,2,ByteString("a")}, nil)
// structs
d = NewDict()
d.Set(Class{"a","b"}, 1)
d.Set(Class{"c","d"}, 2)
d.Set(Ref{"a"}, 3)
d.Set(tStructWithPrivate{"x","y"}, 4)
assertData(Class{"a","b"},1, Class{"c","d"},2, Ref{"a"},3, tStructWithPrivate{"x","y"},4)
assertGet(Class{"a","b"}, 1)
assertGet(Class{"c","d"}, 2)
assertGet(Class{"x","y"}, nil)
assertGet(Ref{"a"}, 3)
assertGet(Ref{"x"}, nil)
assertGet(tStructWithPrivate{"x","y"}, 4)
assertGet(tStructWithPrivate{"p","q"}, nil)
// pointers
// (keyed by address: pointers to equal values are different keys)
i := 1
j := 1
k := 1
x := Class{"a","b"}
y := Class{"a","b"}
z := Class{"a","b"}
d = NewDict()
d.Set(&i, 1)
d.Set(&j, 2)
d.Set(&x, 3)
d.Set(&y, 4)
assertData(&i,1, &j,2, &x,3, &y,4)
assertGet(&i, 1)
assertGet(&j, 2)
assertGet(&k, nil)
assertGet(&x, 3)
assertGet(&y, 4)
assertGet(&z, nil)
// NewDictWithSizeHint
d = NewDictWithSizeHint(100)
assertData()
assertGet(1, nil)
assertGet(2, nil)
assertGet("a", nil)
assertGet("b", nil)
// NewDictWithData
d = NewDictWithData("a",1, 2,"b")
assertData("a",1, 2,"b")
assertGet(1, nil)
assertGet(2, "b")
assertGet("a", 1)
assertGet("b", nil)
// unhashable types
vbad := []any{
[]any{},
[]any{1,2,3},
[]int{},
[]int{1,2,3},
NewDict(),
map[any]any{},
map[int]bool{},
Ref{[]any{}},
tStructWithPrivate{1,[]any{}},
tStructWithPrivate{[]any{},1},
tStructWithPrivate{[]any{},[]any{}},
}
// assertPanics asserts that f panics with a string that starts with errPrefix.
// any other panic is propagated further.
assertPanics := func(subj any, errPrefix string, f func()) {
t.Helper()
defer func() {
t.Helper()
r := recover()
if r == nil {
t.Errorf("%#v: no panic", subj)
return
}
s, ok := r.(string)
if ok && strings.HasPrefix(s, errPrefix) {
// ok
} else {
panic(r)
}
}()
f()
}
for _, k := range vbad {
assertUnhashable := func(f func()) {
t.Helper()
assertPanics(k, "unhashable type: ", f)
}
assertUnhashable(func() { d.Get(k) })
assertUnhashable(func() { d.Set(k, 1) })
assertUnhashable(func() { d.Del(k) })
assertUnhashable(func() { NewDictWithData(k,1) })
}
// = ~nil
// (zero-value Dict behaves as empty for reads and Del, while Set panics)
d = Dict{}
assertData()
assertGet(1, nil)
assertGet(2, nil)
assertGet("a", nil)
assertGet("b", nil)
d.Del(1)
assertData()
assertGet(1, nil)
assertGet(2, nil)
assertGet("a", nil)
assertGet("b", nil)
assertPanics("nil.Set", "Set called on nil map", func() { d.Set(1, "x") })
}
// benchmarks for map and Dict compare them from performance point of view.

// BenchmarkMapGet measures lookup by string and by int key in a builtin
// map[any]any filled with 100 int keys and one string key.
func BenchmarkMapGet(b *testing.B) {
	table := make(map[any]any)
	for n := 0; n < 100; n++ {
		table[n] = n
	}
	table["abc"] = 777

	byString := func(b *testing.B) {
		for n := 0; n < b.N; n++ {
			_ = table["abc"]
		}
	}
	byInt := func(b *testing.B) {
		for n := 0; n < b.N; n++ {
			_ = table[77]
		}
	}

	b.Run("string", byString)
	b.Run("int", byInt)
}
// BenchmarkDictGet measures lookup by string and by int key in a Dict filled
// with 100 int keys and one string key. It is the Dict counterpart of
// BenchmarkMapGet.
func BenchmarkDictGet(b *testing.B) {
	dict := NewDict()
	for n := 0; n < 100; n++ {
		dict.Set(n, n)
	}
	dict.Set("abc", 777)

	byString := func(b *testing.B) {
		for n := 0; n < b.N; n++ {
			_ = dict.Get("abc")
		}
	}
	byInt := func(b *testing.B) {
		for n := 0; n < b.N; n++ {
			_ = dict.Get(77)
		}
	}

	b.Run("string", byString)
	b.Run("int", byInt)
}
module github.com/kisielk/og-rek
go 1.18
require github.com/aristanetworks/gomap v0.0.0-20230726210543-f4e41046dced
require golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 // indirect
github.com/aristanetworks/gomap v0.0.0-20230726210543-f4e41046dced h1:HxlRMDx/VeRqzj3nvqX9k4tjeBcEIkoNHDJPsS389hs=
github.com/aristanetworks/gomap v0.0.0-20230726210543-f4e41046dced/go.mod h1:p7lmI+ecoe1RTyD11SPXWsSQ3H+pJ4cp5y7vtKW4QdM=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
//go:build !go1.21
package ogórek
import (
"math/big"
)
// bigInt_Float64 converts b to the nearest float64 and tells whether the
// result is exact.
//
// This variant goes through big.Float, for Go versions where big.Int has no
// Float64 method.
func bigInt_Float64(b *big.Int) (float64, big.Accuracy) {
	var f big.Float
	f.SetInt(b)
	return f.Float64()
}
//go:build go1.21
package ogórek
import (
"math/big"
)
// bigInt_Float64 converts b to the nearest float64 and tells whether the
// result is exact.
//
// This variant uses big.Int.Float64 directly.
func bigInt_Float64(b *big.Int) (float64, big.Accuracy) {
	f, accuracy := b.Float64()
	return f, accuracy
}
//go:build !go1.19
package ogórek
import (
"hash/maphash"
)
// maphash_String returns the hash of s for the given seed.
//
// This variant feeds s through a maphash.Hash, for Go versions without
// maphash.String.
func maphash_String(seed maphash.Seed, s string) uint64 {
	hasher := new(maphash.Hash)
	hasher.SetSeed(seed)
	hasher.WriteString(s)
	return hasher.Sum64()
}
//go:build go1.19
package ogórek
import (
"hash/maphash"
)
// maphash_String returns the hash of s for the given seed.
//
// This variant delegates to maphash.String.
func maphash_String(seed maphash.Seed, s string) uint64 {
	sum := maphash.String(seed, s)
	return sum
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment