Skip to content

Commit 16d3192

Browse files
committed
Use z3 to find a state mapping that only need u32 transition table
This shrinks the table size to 1KiB for less cache pressure, and produces faster code on platforms that only support 32-bit shift. Though, it does not affect the throughput on 64-bit platforms when the table is already fully in cache.
1 parent 33c076e commit 16d3192

File tree

2 files changed

+138
-110
lines changed

2 files changed

+138
-110
lines changed

library/core/src/str/solve_dfa.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#!/usr/bin/env python3
2+
# Use z3 to solve UTF-8 validation DFA for offset and transition table,
3+
# in order to encode transition table into u32.
4+
# We minimize the output variables in the solution to make it deterministic.
5+
# Ref: <https://gist.github.com/dougallj/166e326de6ad4cf2c94be97a204c025f>
6+
# See more detail explanation in `./validations.rs`.
7+
#
8+
# It is expected to find a solution in <30s on a modern machine, and the
9+
# solution is appended to the end of this file.
10+
from z3 import *
11+
12+
STATE_CNT = 9
13+
14+
# The transition table.
15+
# A value X on column Y means state Y should transition to state X on some
16+
# input bytes. We assign state 0 as ERROR and state 1 as ACCEPT (initial).
17+
# Eg. first line: for input byte 00..=7F, transition S1 -> S1, others -> S0.
18+
TRANSITIONS = [
19+
# 0 1 2 3 4 5 6 7 8
20+
# First bytes
21+
((0, 1, 0, 0, 0, 0, 0, 0, 0), "00-7F"),
22+
((0, 2, 0, 0, 0, 0, 0, 0, 0), "C2-DF"),
23+
((0, 3, 0, 0, 0, 0, 0, 0, 0), "E0"),
24+
((0, 4, 0, 0, 0, 0, 0, 0, 0), "E1-EC, EE-EF"),
25+
((0, 5, 0, 0, 0, 0, 0, 0, 0), "ED"),
26+
((0, 6, 0, 0, 0, 0, 0, 0, 0), "F0"),
27+
((0, 7, 0, 0, 0, 0, 0, 0, 0), "F1-F3"),
28+
((0, 8, 0, 0, 0, 0, 0, 0, 0), "F4"),
29+
# Continuation bytes
30+
((0, 0, 1, 0, 2, 2, 0, 4, 4), "80-8F"),
31+
((0, 0, 1, 0, 2, 2, 4, 4, 0), "90-9F"),
32+
((0, 0, 1, 2, 2, 0, 4, 4, 0), "A0-BF"),
33+
# Illegal
34+
((0, 0, 0, 0, 0, 0, 0, 0, 0), "C0-C1, F5-FF"),
35+
]
36+
37+
o = Optimize()
38+
offsets = [BitVec(f'o{i}', 32) for i in range(STATE_CNT)]
39+
trans_table = [BitVec(f't{i}', 32) for i in range(len(TRANSITIONS))]
40+
41+
# Add some guiding constraints to make solving faster.
42+
o.add(offsets[0] == 0)
43+
o.add(trans_table[-1] == 0)
44+
45+
for i in range(len(offsets)):
46+
o.add(offsets[i] < 32 - 5) # Do not over-shift. It's not necessary but makes solving faster.
47+
for j in range(i):
48+
o.add(offsets[i] != offsets[j])
49+
for trans, (targets, _) in zip(trans_table, TRANSITIONS):
50+
for src, tgt in enumerate(targets):
51+
o.add((LShR(trans, offsets[src]) & 31) == offsets[tgt])
52+
53+
# Minimize ordered outputs to get a unique solution.
54+
goal = Concat(*offsets, *trans_table)
55+
o.minimize(goal)
56+
print(o.check())
57+
print('Offset[]= ', [o.model()[i].as_long() for i in offsets])
58+
print('Transitions:')
59+
for (_, label), v in zip(TRANSITIONS, [o.model()[i].as_long() for i in trans_table]):
60+
print(f'{label:14} => {v:#10x}, // {v:032b}')
61+
62+
# Output should be deterministic:
63+
# sat
64+
# Offset[]= [0, 6, 16, 19, 1, 25, 11, 18, 24]
65+
# Transitions:
66+
# 00-7F => 0x180, // 00000000000000000000000110000000
67+
# C2-DF => 0x400, // 00000000000000000000010000000000
68+
# E0 => 0x4c0, // 00000000000000000000010011000000
69+
# E1-EC, EE-EF => 0x40, // 00000000000000000000000001000000
70+
# ED => 0x640, // 00000000000000000000011001000000
71+
# F0 => 0x2c0, // 00000000000000000000001011000000
72+
# F1-F3 => 0x480, // 00000000000000000000010010000000
73+
# F4 => 0x600, // 00000000000000000000011000000000
74+
# 80-8F => 0x21060020, // 00100001000001100000000000100000
75+
# 90-9F => 0x20060820, // 00100000000001100000100000100000
76+
# A0-BF => 0x860820, // 00000000100001100000100000100000
77+
# C0-C1, F5-FF => 0x0, // 00000000000000000000000000000000

library/core/src/str/validations.rs

Lines changed: 61 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
33
use super::Utf8Error;
44
use crate::intrinsics::{const_eval_select, unlikely};
5-
use crate::mem;
65

76
/// Returns the initial codepoint accumulator for the first byte.
87
/// The first byte is special, only want bottom 5 bits for width 2, 4 bits
@@ -112,34 +111,33 @@ where
112111
Some(ch)
113112
}
114113

115-
// The transition table of shift-based DFA for UTF-8 validation.
114+
// The shift-based DFA algorithm for UTF-8 validation.
116115
// Ref: <https://gist.github.com/pervognsen/218ea17743e1442e59bb60d29b1aa725>
117116
//
118117
// In short, we encode DFA transitions in an array `TRANS_TABLE` such that:
119118
// ```
120119
// TRANS_TABLE[next_byte] =
121-
// (target_state1 * BITS_PER_STATE) << (source_state1 * BITS_PER_STATE) |
122-
// (target_state2 * BITS_PER_STATE) << (source_state2 * BITS_PER_STATE) |
120+
// OFFSET[target_state1] << OFFSET[source_state1] |
121+
// OFFSET[target_state2] << OFFSET[source_state2] |
123122
// ...
124123
// ```
125-
// Thanks to pre-multiplication, we can execute the DFA with one statement per byte:
124+
// Where `OFFSET[]` is a compile-time map from each state to a distinct 0..32 value.
125+
//
126+
// To execute the DFA:
126127
// ```
127-
// let state = initial_state * BITS_PER_STATE;
128+
// let state = OFFSET[initial_state];
128129
// for byte in .. {
129130
// state = TRANS_TABLE[byte] >> (state & ((1 << BITS_PER_STATE) - 1));
130131
// }
131132
// ```
132-
// By choosing `BITS_PER_STATE = 6` and `state: u64`, we can replace the masking by `wrapping_shr`.
133+
// By choosing `BITS_PER_STATE = 5` and `state: u32`, we can replace the masking by `wrapping_shr`
134+
// and it becomes free on modern ISAs, including x86, x86_64 and ARM.
135+
//
133136
// ```
134137
// // shrx state, qword ptr [table_addr + 8 * byte], state # On x86-64-v3
135138
// state = TRANS_TABLE[byte].wrapping_shr(state);
136139
// ```
137140
//
138-
// On platform without 64-bit shift, especially i686, we split the `u64` next-state into
139-
// `[u32; 2]`, and each `u32` stores 5 * BITS_PER_STATE = 30 bits. In this way, state transition
140-
// can be done in only 32-bit shifts and a conditional move, which is several times faster
141-
// (in latency) than ordinary 64-bit shift (SHRD).
142-
//
143141
// The DFA is directly derived from UTF-8 syntax from the RFC3629:
144142
// <https://datatracker.ietf.org/doc/html/rfc3629#section-4>.
145143
// We assign S0 as ERROR and S1 as ACCEPT. DFA starts at S1.
@@ -152,116 +150,49 @@ where
152150
// <S1> (%xE1-EC / %xEE-EF) <S4> 2( UTF8-tail ) /
153151
// <S1> %xED <S5> %x80-9F <S2> UTF8-tail
154152
// UTF8-4 = <S1> %xF0 <S6> %x90-BF <S4> 2( UTF8-tail ) /
155-
// <S1> %xF4 <S7> %x80-8F <S4> 2( UTF8-tail ) /
156-
// <S1> %xF1-F3 <S8> UTF8-tail <S4> 2( UTF8-tail )
157-
//
153+
// <S1> %xF1-F3 <S7> UTF8-tail <S4> 2( UTF8-tail ) /
154+
// <S1> %xF4 <S8> %x80-8F <S4> 2( UTF8-tail )
158155
// UTF8-tail = %x80-BF # Inlined into above usages.
159-
const BITS_PER_STATE: u32 = 6;
156+
//
157+
// You may notice that encoding 9 states with 5bits per state into 32bit seems impossible,
158+
// but we exploit overlapping bits to find a possible `OFFSET[]` and `TRANS_TABLE[]` solution.
159+
// The SAT solver to find such (minimal) solution is in `./solve_dfa.py`.
160+
// The solution is also appended to the end of that file and is verifiable.
161+
const BITS_PER_STATE: u32 = 5;
160162
const STATE_MASK: u32 = (1 << BITS_PER_STATE) - 1;
161163
const STATE_CNT: usize = 9;
162-
#[allow(clippy::all)]
163-
const ST_ERROR: u32 = 0 * BITS_PER_STATE as u32;
164-
#[allow(clippy::all)]
165-
const ST_ACCEPT: u32 = 1 * BITS_PER_STATE as u32;
166-
167-
/// Platforms that does not have efficient 64-bit shift and should use 32-bit shift fallback.
168-
const USE_SHIFT32: bool = cfg!(all(
169-
any(target_pointer_width = "16", target_pointer_width = "32"),
170-
// WASM32 supports 64-bit shift.
171-
not(target_arch = "wasm32"),
172-
));
164+
const ST_ERROR: u32 = OFFSETS[0];
165+
const ST_ACCEPT: u32 = OFFSETS[1];
166+
// See the end of `./solve_dfa.py`.
167+
const OFFSETS: [u32; STATE_CNT] = [0, 6, 16, 19, 1, 25, 11, 18, 24];
173168

174-
// After storing STATE_CNT * BITS_PER_STATE = 54bits on 64-bit platform, or (STATE_CNT - 5)
175-
// * BITS_PER_STATE = 24bits on 32-bit platform, we still have some high bits left.
176-
// They will never be used via state transition.
177-
// We merge lookup table from first byte -> UTF-8 length, to these highest bits.
178-
const UTF8_LEN_HIBITS: u32 = 4;
179-
180-
static TRANS_TABLE: [u64; 256] = {
181-
let mut table = [0u64; 256];
169+
static TRANS_TABLE: [u32; 256] = {
170+
let mut table = [0u32; 256];
182171
let mut b = 0;
183172
while b < 256 {
184-
// Target states indexed by starting states.
185-
let mut to = [0u64; STATE_CNT];
186-
to[0] = 0;
187-
to[1] = match b {
188-
0x00..=0x7F => 1,
189-
0xC2..=0xDF => 2,
190-
0xE0 => 3,
191-
0xE1..=0xEC | 0xEE..=0xEF => 4,
192-
0xED => 5,
193-
0xF0 => 6,
194-
0xF4 => 7,
195-
0xF1..=0xF3 => 8,
196-
_ => 0,
197-
};
198-
to[2] = match b {
199-
0x80..=0xBF => 1,
200-
_ => 0,
201-
};
202-
to[3] = match b {
203-
0xA0..=0xBF => 2,
204-
_ => 0,
205-
};
206-
to[4] = match b {
207-
0x80..=0xBF => 2,
208-
_ => 0,
209-
};
210-
to[5] = match b {
211-
0x80..=0x9F => 2,
212-
_ => 0,
173+
// See the end of `./solve_dfa.py`.
174+
table[b] = match b as u8 {
175+
0x00..=0x7F => 0x180,
176+
0xC2..=0xDF => 0x400,
177+
0xE0 => 0x4C0,
178+
0xE1..=0xEC | 0xEE..=0xEF => 0x40,
179+
0xED => 0x640,
180+
0xF0 => 0x2C0,
181+
0xF1..=0xF3 => 0x480,
182+
0xF4 => 0x600,
183+
0x80..=0x8F => 0x21060020,
184+
0x90..=0x9F => 0x20060820,
185+
0xA0..=0xBF => 0x860820,
186+
0xC0..=0xC1 | 0xF5..=0xFF => 0x0,
213187
};
214-
to[6] = match b {
215-
0x90..=0xBF => 4,
216-
_ => 0,
217-
};
218-
to[7] = match b {
219-
0x80..=0x8F => 4,
220-
_ => 0,
221-
};
222-
to[8] = match b {
223-
0x80..=0xBF => 4,
224-
_ => 0,
225-
};
226-
227-
// On platforms without 64-bit shift, align states 5..10 to 32-bit boundary.
228-
// See docs above for details.
229-
let mut bits = 0u64;
230-
let mut j = 0;
231-
while j < to.len() {
232-
let to_off =
233-
to[j] * BITS_PER_STATE as u64 + if USE_SHIFT32 && to[j] >= 5 { 2 } else { 0 };
234-
let off = j as u32 * BITS_PER_STATE + if USE_SHIFT32 && j >= 5 { 2 } else { 0 };
235-
bits |= to_off << off;
236-
j += 1;
237-
}
238-
239-
let utf8_len = match b {
240-
0x00..=0x7F => 1,
241-
0xC2..=0xDF => 2,
242-
0xE0..=0xEF => 3,
243-
0xF0..=0xF4 => 4,
244-
_ => 0,
245-
};
246-
bits |= utf8_len << (64 - UTF8_LEN_HIBITS);
247-
248-
table[b] = bits;
249188
b += 1;
250189
}
251190
table
252191
};
253192

254193
#[inline(always)]
255194
const fn next_state(st: u32, byte: u8) -> u32 {
256-
if USE_SHIFT32 {
257-
// SAFETY: `u64` is more aligned than `u32`, and has the same repr as `[u32; 2]`.
258-
let [lo, hi] = unsafe { mem::transmute::<u64, [u32; 2]>(TRANS_TABLE[byte as usize]) };
259-
#[cfg(target_endian = "big")]
260-
let (lo, hi) = (hi, lo);
261-
if st & 32 == 0 { lo } else { hi }.wrapping_shr(st)
262-
} else {
263-
TRANS_TABLE[byte as usize].wrapping_shr(st as _) as _
264-
}
195+
TRANS_TABLE[byte as usize].wrapping_shr(st)
265196
}
266197

267198
/// Check if `byte` is a valid UTF-8 first byte, assuming it must be a valid first or
@@ -407,13 +338,33 @@ fn run_utf8_validation_rt(bytes: &[u8]) -> Result<(), Utf8Error> {
407338
Ok(())
408339
}
409340

341+
// https://tools.ietf.org/html/rfc3629
342+
const UTF8_CHAR_WIDTH: &[u8; 256] = &[
343+
// 1 2 3 4 5 6 7 8 9 A B C D E F
344+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0
345+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 1
346+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 2
347+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 3
348+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 4
349+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 5
350+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 6
351+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 7
352+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 8
353+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 9
354+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // A
355+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // B
356+
0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // C
357+
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // D
358+
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // E
359+
4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // F
360+
];
361+
410362
/// Given a first byte, determines how many bytes are in this UTF-8 character.
411363
#[unstable(feature = "str_internals", issue = "none")]
412364
#[must_use]
413365
#[inline]
414366
pub const fn utf8_char_width(b: u8) -> usize {
415-
// On 32-bit platforms, optimizer is smart enough to only load and operate on the high 32-bits.
416-
(TRANS_TABLE[b as usize] >> (64 - UTF8_LEN_HIBITS)) as usize
367+
UTF8_CHAR_WIDTH[b as usize] as usize
417368
}
418369

419370
/// Mask of the value bits of a continuation byte.

0 commit comments

Comments
 (0)