Commit 01cd22c6 authored by Josselin Costanzi's avatar Josselin Costanzi Committed by Keith Randall

bytes: add optimized countByte for amd64

Use SSE/AVX2 when counting a single byte.
Inspired from runtime indexbyte implementation.

Benchmark against previous implementation, where
1 byte in every 8 is the one we are looking for:

* On a machine without AVX2
name               old time/op   new time/op     delta
CountSingle/10-4    61.8ns ±10%     15.6ns ±11%    -74.83%  (p=0.000 n=10+10)
CountSingle/32-4     100ns ± 4%       17ns ±10%    -82.54%  (p=0.000 n=10+9)
CountSingle/4K-4    9.66µs ± 3%     0.37µs ± 6%    -96.21%  (p=0.000 n=10+10)
CountSingle/4M-4    11.0ms ± 6%      0.4ms ± 4%    -96.04%  (p=0.000 n=10+10)
CountSingle/64M-4    194ms ± 8%        8ms ± 2%    -95.64%  (p=0.000 n=10+10)

name               old speed     new speed       delta
CountSingle/10-4   162MB/s ±10%    645MB/s ±10%   +297.00%  (p=0.000 n=10+10)
CountSingle/32-4   321MB/s ± 5%   1844MB/s ± 9%   +474.79%  (p=0.000 n=10+9)
CountSingle/4K-4   424MB/s ± 3%  11169MB/s ± 6%  +2533.10%  (p=0.000 n=10+10)
CountSingle/4M-4   381MB/s ± 7%   9609MB/s ± 4%  +2421.88%  (p=0.000 n=10+10)
CountSingle/64M-4  346MB/s ± 7%   7924MB/s ± 2%  +2188.78%  (p=0.000 n=10+10)

* On a machine with AVX2
name               old time/op   new time/op     delta
CountSingle/10-8    37.1ns ± 3%      8.2ns ± 1%    -77.80%  (p=0.000 n=10+10)
CountSingle/32-8    66.1ns ± 3%      9.8ns ± 2%    -85.23%  (p=0.000 n=10+10)
CountSingle/4K-8    7.36µs ± 3%     0.11µs ± 1%    -98.54%  (p=0.000 n=10+10)
CountSingle/4M-8    7.46ms ± 2%     0.15ms ± 2%    -97.95%  (p=0.000 n=10+9)
CountSingle/64M-8    124ms ± 2%        6ms ± 4%    -95.09%  (p=0.000 n=10+10)

name               old speed     new speed       delta
CountSingle/10-8   269MB/s ± 3%   1213MB/s ± 1%   +350.32%  (p=0.000 n=10+10)
CountSingle/32-8   484MB/s ± 4%   3277MB/s ± 2%   +576.66%  (p=0.000 n=10+10)
CountSingle/4K-8   556MB/s ± 3%  37933MB/s ± 1%  +6718.36%  (p=0.000 n=10+10)
CountSingle/4M-8   562MB/s ± 2%  27444MB/s ± 3%  +4783.43%  (p=0.000 n=10+9)
CountSingle/64M-8  543MB/s ± 2%  11054MB/s ± 3%  +1935.81%  (p=0.000 n=10+10)

Fixes #19411

Change-Id: Ieaf20b1fabccabe767c55c66e242e86f3617f883
Reviewed-on: https://go-review.googlesource.com/38258
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarKeith Randall <khr@golang.org>
parent 0ebaca6b
......@@ -46,9 +46,8 @@ func explode(s []byte, n int) [][]byte {
return a[0:na]
}
// Count counts the number of non-overlapping instances of sep in s.
// If sep is an empty slice, Count returns 1 + the number of Unicode code points in s.
func Count(s, sep []byte) int {
// countGeneric actualy implements Count
func countGeneric(s, sep []byte) int {
n := 0
// special case
if len(sep) == 0 {
......
......@@ -10,6 +10,7 @@ package bytes
// indexShortStr requires 2 <= len(c) <= shortStringLen
func indexShortStr(s, c []byte) int // ../runtime/asm_$GOARCH.s
func supportAVX2() bool // ../runtime/asm_$GOARCH.s
func supportPOPCNT() bool // ../runtime/asm_$GOARCH.s
var shortStringLen int
......@@ -94,6 +95,18 @@ func Index(s, sep []byte) int {
return -1
}
// Special case for when we must count occurences of a single byte.
func countByte(s []byte, c byte) int
// Count counts the number of non-overlapping instances of sep in s.
// If sep is an empty slice, Count returns 1 + the number of Unicode code points in s.
func Count(s, sep []byte) int {
if len(sep) == 1 && supportPOPCNT() {
return countByte(s, sep[0])
}
return countGeneric(s, sep)
}
// primeRK is the prime base used in Rabin-Karp algorithm.
const primeRK = 16777619
......
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "textflag.h"
// We use:
// SI: data
// BX: data len
// AL: byte sought
// This require the POPCNT instruction
TEXT ·countByte(SB),NOSPLIT,$0-40
MOVQ s+0(FP), SI
MOVQ s_len+8(FP), BX
MOVB c+24(FP), AL
// Shuffle X0 around so that each byte contains
// the character we're looking for.
MOVD AX, X0
PUNPCKLBW X0, X0
PUNPCKLBW X0, X0
PSHUFL $0, X0, X0
CMPQ BX, $16
JLT small
MOVQ $0, R12 // Accumulator
MOVQ SI, DI
CMPQ BX, $32
JA avx2
sse:
LEAQ -16(SI)(BX*1), AX // AX = address of last 16 bytes
JMP sseloopentry
sseloop:
// Move the next 16-byte chunk of the data into X1.
MOVOU (DI), X1
// Compare bytes in X0 to X1.
PCMPEQB X0, X1
// Take the top bit of each byte in X1 and put the result in DX.
PMOVMSKB X1, DX
// Count number of matching bytes
POPCNTL DX, DX
// Accumulate into R12
ADDQ DX, R12
// Advance to next block.
ADDQ $16, DI
sseloopentry:
CMPQ DI, AX
JBE sseloop
// Get the number of bytes to consider in the last 16 bytes
ANDQ $15, BX
JZ end
// Create mask to ignore overlap between previous 16 byte block
// and the next.
MOVQ $16,CX
SUBQ BX, CX
MOVQ $0xFFFF, R10
SARQ CL, R10
SALQ CL, R10
// Process the last 16-byte chunk. This chunk may overlap with the
// chunks we've already searched so we need to mask part of it.
MOVOU (AX), X1
PCMPEQB X0, X1
PMOVMSKB X1, DX
// Apply mask
ANDQ R10, DX
POPCNTL DX, DX
ADDQ DX, R12
end:
MOVQ R12, ret+32(FP)
RET
// handle for lengths < 16
small:
TESTQ BX, BX
JEQ endzero
// Check if we'll load across a page boundary.
LEAQ 16(SI), AX
TESTW $0xff0, AX
JEQ endofpage
// We must ignore high bytes as they aren't part of our slice.
// Create mask.
MOVB BX, CX
MOVQ $1, R10
SALQ CL, R10
SUBQ $1, R10
// Load data
MOVOU (SI), X1
// Compare target byte with each byte in data.
PCMPEQB X0, X1
// Move result bits to integer register.
PMOVMSKB X1, DX
// Apply mask
ANDQ R10, DX
POPCNTL DX, DX
// Directly return DX, we don't need to accumulate
// since we have <16 bytes.
MOVQ DX, ret+32(FP)
RET
endzero:
MOVQ $0, ret+32(FP)
RET
endofpage:
// We must ignore low bytes as they aren't part of our slice.
MOVQ $16,CX
SUBQ BX, CX
MOVQ $0xFFFF, R10
SARQ CL, R10
SALQ CL, R10
// Load data into the high end of X1.
MOVOU -16(SI)(BX*1), X1
// Compare target byte with each byte in data.
PCMPEQB X0, X1
// Move result bits to integer register.
PMOVMSKB X1, DX
// Apply mask
ANDQ R10, DX
// Directly return DX, we don't need to accumulate
// since we have <16 bytes.
POPCNTL DX, DX
MOVQ DX, ret+32(FP)
RET
avx2:
CMPB runtime·support_avx2(SB), $1
JNE sse
MOVD AX, X0
LEAQ -32(SI)(BX*1), R11
VPBROADCASTB X0, Y1
avx2_loop:
VMOVDQU (DI), Y2
VPCMPEQB Y1, Y2, Y3
VPMOVMSKB Y3, DX
POPCNTL DX, DX
ADDQ DX, R12
ADDQ $32, DI
CMPQ DI, R11
JLE avx2_loop
// If last block is already processed,
// skip to the end.
CMPQ DI, R11
JEQ endavx
// Load address of the last 32 bytes.
// There is an overlap with the previous block.
MOVQ R11, DI
VMOVDQU (DI), Y2
VPCMPEQB Y1, Y2, Y3
VPMOVMSKB Y3, DX
// Exit AVX mode.
VZEROUPPER
// Create mask to ignore overlap between previous 32 byte block
// and the next.
ANDQ $31, BX
MOVQ $32,CX
SUBQ BX, CX
MOVQ $0xFFFFFFFF, R10
SARQ CL, R10
SALQ CL, R10
// Apply mask
ANDQ R10, DX
POPCNTL DX, DX
ADDQ DX, R12
MOVQ R12, ret+32(FP)
RET
endavx:
// Exit AVX mode.
VZEROUPPER
MOVQ R12, ret+32(FP)
RET
......@@ -39,3 +39,9 @@ func Index(s, sep []byte) int {
}
return -1
}
// Count counts the number of non-overlapping instances of sep in s.
// If sep is an empty slice, Count returns 1 + the number of Unicode code points in s.
func Count(s, sep []byte) int {
return countGeneric(s, sep)
}
......@@ -97,6 +97,12 @@ func Index(s, sep []byte) int {
return -1
}
// Count counts the number of non-overlapping instances of sep in s.
// If sep is an empty slice, Count returns 1 + the number of Unicode code points in s.
func Count(s, sep []byte) int {
return countGeneric(s, sep)
}
// primeRK is the prime base used in Rabin-Karp algorithm.
const primeRK = 16777619
......
......@@ -396,6 +396,79 @@ func TestIndexRune(t *testing.T) {
}
}
// test count of a single byte across page offsets
func TestCountByte(t *testing.T) {
b := make([]byte, 5015) // bigger than a page
windows := []int{1, 2, 3, 4, 15, 16, 17, 31, 32, 33, 63, 64, 65, 128}
testCountWindow := func(i, window int) {
for j := 0; j < window; j++ {
b[i+j] = byte(100)
p := Count(b[i:i+window], []byte{100})
if p != j+1 {
t.Errorf("TestCountByte.Count(%q, 100) = %d", b[i:i+window], p)
}
pGeneric := CountGeneric(b[i:i+window], []byte{100})
if pGeneric != j+1 {
t.Errorf("TestCountByte.CountGeneric(%q, 100) = %d", b[i:i+window], p)
}
}
}
maxWnd := windows[len(windows)-1]
for i := 0; i <= 2*maxWnd; i++ {
for _, window := range windows {
if window > len(b[i:]) {
window = len(b[i:])
}
testCountWindow(i, window)
for j := 0; j < window; j++ {
b[i+j] = byte(0)
}
}
}
for i := 4096 - (maxWnd + 1); i < len(b); i++ {
for _, window := range windows {
if window > len(b[i:]) {
window = len(b[i:])
}
testCountWindow(i, window)
for j := 0; j < window; j++ {
b[i+j] = byte(0)
}
}
}
}
// Make sure we don't count bytes outside our window
func TestCountByteNoMatch(t *testing.T) {
b := make([]byte, 5015)
windows := []int{1, 2, 3, 4, 15, 16, 17, 31, 32, 33, 63, 64, 65, 128}
for i := 0; i <= len(b); i++ {
for _, window := range windows {
if window > len(b[i:]) {
window = len(b[i:])
}
// Fill the window with non-match
for j := 0; j < window; j++ {
b[i+j] = byte(100)
}
// Try to find something that doesn't exist
p := Count(b[i:i+window], []byte{0})
if p != 0 {
t.Errorf("TestCountByteNoMatch(%q, 0) = %d", b[i:i+window], p)
}
pGeneric := CountGeneric(b[i:i+window], []byte{0})
if pGeneric != 0 {
t.Errorf("TestCountByteNoMatch.CountGeneric(%q, 100) = %d", b[i:i+window], p)
}
for j := 0; j < window; j++ {
b[i+j] = byte(0)
}
}
}
}
var bmbuf []byte
func valName(x int) string {
......@@ -589,6 +662,26 @@ func BenchmarkCountEasy(b *testing.B) {
})
}
func BenchmarkCountSingle(b *testing.B) {
benchBytes(b, indexSizes, func(b *testing.B, n int) {
buf := bmbuf[0:n]
step := 8
for i := 0; i < len(buf); i += step {
buf[i] = 1
}
expect := (len(buf) + (step - 1)) / step
for i := 0; i < b.N; i++ {
j := Count(buf, []byte{1})
if j != expect {
b.Fatal("bad count", j, expect)
}
}
for i := 0; i < len(buf); i++ {
buf[i] = 0
}
})
}
type ExplodeTest struct {
s string
n int
......
......@@ -7,3 +7,4 @@ package bytes
// Export func for testing
var IndexBytePortable = indexBytePortable
var EqualPortable = equalPortable
var CountGeneric = countGeneric
......@@ -17,6 +17,7 @@ runtime/asm_amd64.s: [GOARCH] cannot check cross-package assembly function: Comp
runtime/asm_amd64.s: [amd64] cannot check cross-package assembly function: indexShortStr is in package bytes
runtime/asm_amd64.s: [amd64] cannot check cross-package assembly function: supportAVX2 is in package strings
runtime/asm_amd64.s: [amd64] cannot check cross-package assembly function: supportAVX2 is in package bytes
runtime/asm_amd64.s: [amd64] cannot check cross-package assembly function: supportPOPCNT is in package bytes
// Intentionally missing declarations. These are special assembly routines.
// Some are jumped into from other routines, with values in specific registers.
......
......@@ -91,8 +91,13 @@ testbmi1:
testbmi2:
MOVB $0, runtime·support_bmi2(SB)
TESTL $(1<<8), runtime·cpuid_ebx7(SB) // check for BMI2 bit
JEQ nocpuinfo
JEQ testpopcnt
MOVB $1, runtime·support_bmi2(SB)
testpopcnt:
MOVB $0, runtime·support_popcnt(SB)
TESTL $(1<<23), runtime·cpuid_ecx(SB) // check for POPCNT bit
JEQ nocpuinfo
MOVB $1, runtime·support_popcnt(SB)
nocpuinfo:
// if there is an _cgo_init, call it.
......@@ -1697,6 +1702,11 @@ TEXT bytes·supportAVX2(SB),NOSPLIT,$0-1
MOVB AX, ret+0(FP)
RET
TEXT bytes·supportPOPCNT(SB),NOSPLIT,$0-1
MOVBLZX runtime·support_popcnt(SB), AX
MOVB AX, ret+0(FP)
RET
TEXT strings·indexShortStr(SB),NOSPLIT,$0-40
MOVQ s+0(FP), DI
// We want len in DX and AX, because PCMPESTRI implicitly consumes them
......
......@@ -728,6 +728,7 @@ var (
support_avx2 bool
support_bmi1 bool
support_bmi2 bool
support_popcnt bool
goarm uint8 // set by cmd/link on arm systems
framepointer_enabled bool // set by cmd/link
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment