Performance of ITE Expressions (incomplete)
June 14, 2021
A branch is an expensive operation even in modern CPUs because the CPU knows which of the paths is taken only in the late stages of the pipeline.
In the meantime, the CPU stalls.
Modern CPUs use branch prediction, speculative execution and instruction reordering to minimize the impact of a branch.
They do a good job, but a branch is still potentially expensive, so it is common to replace branches with branchless variants.
If-Then-Else expressions, or ITE for short, are symbolic expressions that denote a value chosen from two possible values based on a condition. They are the symbolic branch.
Naturally, we could rewrite a symbolic ITE as a symbolic branchless expression.
The question is: which is better for a solver like Z3? Which makes the SMT/SAT solver faster?
After two weeks working on this post I still don't have an answer, but at least I know some of the unknowns.
Z3 If-Then-Else
In Z3 we use z3.If to build such symbolic expressions.
Take for example the following Python function xtime:
>>> def xtime(a):
... thenval = (((a << 1) ^ 0x1B) & 0xFF)
... elseval = (a << 1)
... condval = (a & 0x80)
... return thenval if condval else elseval
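In case the constant 0x1B looks arbitrary: xtime is the multiplication by x (the byte {02}) in the finite field GF(2^8) used by AES, and 0x1B comes from the reduction polynomial x^8 + x^4 + x^3 + x + 1. The plain-Python sketch below repeats the doubling example of {57} from FIPS-197 with the xtime defined above. The names value and doublings are only for illustration.

```python
def xtime(a):
    # Same definition as above, on plain Python ints in 0..255
    thenval = (((a << 1) ^ 0x1B) & 0xFF)
    elseval = (a << 1)
    condval = (a & 0x80)
    return thenval if condval else elseval

# Repeated doubling of {57}, the worked example used in FIPS-197
value = 0x57
doublings = []
for _ in range(4):
    value = xtime(value)
    doublings.append(value)

print([hex(v) for v in doublings])  # ['0xae', '0x47', '0x8e', '0x7']
```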
Symbolically, we could rewrite it as follows:
>>> from z3 import If, BitVec, simplify
>>> def xtime_branch(a):
... thenval = (((a << 1) ^ 0x1B) & 0xFF)
... elseval = (a << 1)
... condval = (a & 0x80)
... return If(condval != 0, thenval, elseval)
Remember that in Python, thenval if condval else elseval is evaluated at runtime, but in Z3 the input is symbolic, so the condition cannot be evaluated.
So we need to model the fact that the output of xtime may be thenval or elseval, depending on the condition.
Let's see the result of xtime_branch:
>>> T = BitVec('T', 8)
>>> xtime_branch(T)
If(T & 128 != 0, (T << 1 ^ 27) & 255, T << 1)
>>> simplify(xtime_branch(T))
If(Extract(7, 7, T) == 0,
Concat(Extract(6, 0, T), 0),
Concat(Extract(6, 4, T),
~Extract(3, 2, T),
Extract(1, 1, T),
~Extract(0, 0, T),
1))
Before continuing, I would like to simplify xtime_branch a little:
- the input is always an 8-bit vector, so the x & 0xFF mask is not needed;
- thenval can reuse elseval.
>>> def xtime_branch(a):
... elseval = (a << 1)
... thenval = (elseval ^ 0x1B)
... condval = (a & 0x80)
... return If(condval != 0, thenval, elseval)
>>> xtime_branch(T)
If(T & 128 != 0, T << 1 ^ 27, T << 1)
>>> simplify(xtime_branch(T))
If(Extract(7, 7, T) == 0,
Concat(Extract(6, 0, T), 0),
Concat(Extract(6, 4, T),
~Extract(3, 2, T),
Extract(1, 1, T),
~Extract(0, 0, T),
1))
As you can see, this xtime_branch and the previous one yield the same result after applying z3.simplify.
However, I'm going to keep those simplifications explicit in xtime_branch to enable further optimizations later.
Branchless ITE
The (a & 0x80) != 0 condition is equivalent to (a >> 7) != 0.
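The key point to notice is that when (a & 0x80) != 0 then a >> 7 == 1, and when (a & 0x80) == 0 then a >> 7 == 0. This holds for a Python integer in 0..255. On a z3 BitVec, however, >> is an arithmetic shift, so when the top bit is set a >> 7 is 0xFF (eight 1 bits), and LShR(a, 7) is the logical shift that gives 1.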
With this single-bit boolean we can get rid of the If using a branchless bithack:
>>> def xtime_branchless(a):
... elseval = (a << 1)
... thenval = (elseval ^ 0x1B)
... condval = (a >> 7) # 0 or 1 only with a logical shift; z3py's >> is arithmetic
... return elseval ^ ((thenval ^ elseval) & -(condval))
>>> xtime_branchless(T)
T << 1 ^ (T << 1 ^ 27 ^ T << 1) & -(T >> 7)
>>> simplify(xtime_branchless(T))
Concat(Extract(6, 4, T),
Extract(3, 2, T) ^ Extract(4, 3, 255*(T >> 7)),
Extract(1, 1, T),
Extract(0, 0, T) ^
Extract(1, 1, 3*Extract(1, 0, T >> 7)),
Extract(0, 0, T >> 7))
We no longer have an ITE expression!
But there is a catch…
Bit broadcasting
The catch is that we have some multiplications:
- 255*(T >> 7)
- 3*Extract(1, 0, T >> 7)
These come from -(condval).
When condval is 0, -(condval) is also 0, represented as eight 0 bits, so ((thenval ^ elseval) & -(condval)) becomes 0 and the expression reduces to the left operand of the main xor: elseval.
When condval is 1, -(condval) is -1, represented as eight 1 bits because in Z3 (and in a lot of other languages) negative numbers use two's complement representation.
This all-ones mask lets thenval ^ elseval through unchanged, so the expression becomes elseval ^ thenval ^ elseval, which reduces to thenval.
This is why the branchless bithack works and, moreover, where those multiplications come from: two's complement negation, since in 8 bits -x is the same as 255*x (and, in a 2-bit slice, the same as 3*x).
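The select idiom and the origin of those multiplications can be checked with plain Python integers masked to 8 bits. This is a sketch that emulates bit-vector arithmetic with & 0xFF, assuming the condition is exactly 0 or 1; select_branchless and MASK8 are hypothetical names.

```python
MASK8 = 0xFF  # emulate 8-bit bit-vector arithmetic with plain Python ints

def select_branchless(cond, thenval, elseval):
    """Return thenval if cond == 1 and elseval if cond == 0, without branching.

    cond must be 0 or 1; the result is reduced to 8 bits.
    """
    mask = (-cond) & MASK8  # 0 -> 0b00000000, 1 -> 0b11111111 (two's complement)
    return (elseval ^ ((thenval ^ elseval) & mask)) & MASK8

# In 8 bits, -1 is 255, so negating c is the same as multiplying it by 255
assert all((-c) & MASK8 == (255 * c) & MASK8 for c in range(256))
# In a 2-bit slice, -1 is 3, which matches the 3*Extract(1, 0, ...) term
assert all((-c) & 0b11 == (3 * c) & 0b11 for c in range(4))

print(hex(select_branchless(0, 0x1B, 0xAE)))  # 0xae
print(hex(select_branchless(1, 0x1B, 0xAE)))  # 0x1b
```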
z3.simplify was not smart enough to see that this is just a broadcast of the least significant bit of a >> 7.
We can do better by broadcasting the most significant bit of a ourselves and building the condition mask directly:
>>> from z3 import Extract, Concat
>>> def xtime_broadcasted(a):
... elseval = (a << 1)
... thenval = (elseval ^ 0x1B)
... msb = Extract(7, 7, a)
... condmask = Concat(*([msb] * 8)) # broadcast a single bit to 8 bits
... return elseval ^ ((thenval ^ elseval) & condmask)
>>> xtime_broadcasted(T)
T << 1 ^
(T << 1 ^ 27 ^ T << 1) &
Concat(Concat(Concat(Concat(Concat(Concat(Concat(Extract(7,
7,
T),
Extract(7, 7, T)),
Extract(7, 7, T)),
Extract(7, 7, T)),
Extract(7, 7, T)),
Extract(7, 7, T)),
Extract(7, 7, T)),
Extract(7, 7, T))
>>> simplify(xtime_broadcasted(T))
Concat(Extract(6, 4, T),
Extract(3, 3, T) ^ Extract(7, 7, T),
Extract(2, 2, T) ^ Extract(7, 7, T),
Extract(1, 1, T),
Extract(0, 0, T) ^ Extract(7, 7, T),
Extract(7, 7, T))
Ugly, but once simplified with z3.simplify, xtime_broadcasted turns out to be quite simple: only bit picking and xor.
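As a sanity check on how to read that Concat, here is a plain-Python transcription of the simplified xtime_broadcasted, using only bit picking and xor, compared against the original xtime on all 256 inputs. bit and xtime_bitpicked are hypothetical helper names.

```python
def xtime(a):
    # Reference definition from the top of the post (a is an int in 0..255)
    thenval = (((a << 1) ^ 0x1B) & 0xFF)
    elseval = (a << 1)
    condval = (a & 0x80)
    return thenval if condval else elseval

def bit(a, i):
    # Value (0 or 1) of bit i of a
    return (a >> i) & 1

def xtime_bitpicked(a):
    # Transcribes the simplified Concat above, most significant part first
    t7 = bit(a, 7)
    return (((a >> 4) & 0b111) << 5      # Extract(6, 4, T)
            | (bit(a, 3) ^ t7) << 4      # Extract(3, 3, T) ^ Extract(7, 7, T)
            | (bit(a, 2) ^ t7) << 3      # Extract(2, 2, T) ^ Extract(7, 7, T)
            | bit(a, 1) << 2             # Extract(1, 1, T)
            | (bit(a, 0) ^ t7) << 1      # Extract(0, 0, T) ^ Extract(7, 7, T)
            | t7)                        # Extract(7, 7, T)

assert all(xtime_bitpicked(a) == xtime(a) for a in range(256))
```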
One last hack
xtime_broadcasted can be simplified further by cancelling elseval in thenval ^ elseval, because thenval == elseval ^ 0x1B.
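The cancellation follows from x ^ x == 0 and the associativity and commutativity of xor. Write e for elseval, m for condmask, \(\oplus\) for xor and \(\wedge\) for bitwise and, and use thenval = e ^ 0x1B as defined in xtime_broadcasted:

\[
\begin{aligned}
e \oplus \big((\mathit{thenval} \oplus e) \wedge m\big)
  &= e \oplus \big(((e \oplus \mathtt{0x1B}) \oplus e) \wedge m\big) \\
  &= e \oplus \big((\mathtt{0x1B} \oplus (e \oplus e)) \wedge m\big) \\
  &= e \oplus (\mathtt{0x1B} \wedge m)
\end{aligned}
\]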
So elseval ^ ((thenval ^ elseval) & condmask) reduces to elseval ^ (0x1B & condmask):
>>> def xtime_cancelled(a):
... elseval = (a << 1)
... msb = Extract(7, 7, a)
... condmask = Concat(*([msb] * 8)) # broadcast a single bit to 8 bits
... return elseval ^ (0x1B & condmask)
>>> xtime_cancelled(T)
T << 1 ^
27 &
Concat(Concat(Concat(Concat(Concat(Concat(Concat(Extract(7,
7,
T),
Extract(7, 7, T)),
Extract(7, 7, T)),
Extract(7, 7, T)),
Extract(7, 7, T)),
Extract(7, 7, T)),
Extract(7, 7, T)),
Extract(7, 7, T))
>>> simplify(xtime_cancelled(T))
Concat(Extract(6, 4, T),
Extract(3, 3, T) ^ Extract(7, 7, T),
Extract(2, 2, T) ^ Extract(7, 7, T),
Extract(1, 1, T),
Extract(0, 0, T) ^ Extract(7, 7, T),
Extract(7, 7, T))
Note how z3.simplify was smart enough to do the cancellation by itself: once simplified by Z3, xtime_broadcasted and xtime_cancelled are the same expression.
Correctness of xtime*
Let’s verify that we didn’t screw up.
The search space is only \(2^8\), so we can compare each xtime_X with the original xtime over all the possible inputs.
>>> from z3 import Solver, And, Or, BitVec
>>> a = BitVec('a', 8)
>>> solver = Solver()
>>> full_search = [And(a == i, xtime_branch(a) == xtime(i)) for i in range(256)]
>>> solver.check(Or(*full_search))
sat
>>> full_search = [And(a == i, xtime_branchless(a) == xtime(i)) for i in range(256)]
>>> solver.check(Or(*full_search))
sat
>>> full_search = [And(a == i, xtime_cancelled(a) == xtime(i)) for i in range(256)]
>>> solver.check(Or(*full_search))
sat
>>> full_search = [And(a == i, xtime_broadcasted(a) == xtime(i)) for i in range(256)]
>>> solver.check(Or(*full_search))
sat
Everything is in order.
Experiments setup
The 4 functions were tested in 4 different experiments or scenarios:
- null_experiment: an 8-bit vector and a simple bitmask operation on it, without using xtime*. Intended to see the performance of Z3 in a trivial case.
- single_bitvec_experiment: a call to xtime* on an 8-bit vector and the verification of the results, testing the 256 possible values.
- mix_two_bitvec_experiment: call xtime* twice on two 8-bit vectors, perform a few bitmask operations on them and verify the correctness with a full search of the 65536 possible values.
- encrypt_rounds_experiment: call xtime* several times, doing several bitmask and shift operations on 32 8-bit vectors. This represents a simplified version of a single round of the AES cipher.
For each experiment, each xtime* function was tested in both its simplified and non-simplified variants.
Each experiment consisted of creating and setting up a new z3.Solver with its own z3.Context and measuring the time it took to check the model: the check-elapsed time.
Because Z3 is not deterministic, we ran each experiment at least 20 times and at most 100 times, and collected not only the check-elapsed time but also the solver statistics that Z3 provides through solver.statistics().
The null_experiment does not use any xtime* function; it is there to give an idea of how small the check-elapsed time can be.
Experiments results
The first thing that we can see is how each xtime* performed in each experiment.
check-elapsed time (y axis) per xtime* function (x axis). Each subplot corresponds to a different experiment.
A few remarks:
- The null_experiment shows a quite stable plot regardless of the xtime* used, as expected.
- For single_bitvec_experiment and mix_two_bitvec_experiment there is little difference whether xtime* was simplified or not, but it made a real difference for the encrypt_rounds_experiment.
- The ITE expression of xtime_branch performed better than the others in single_bitvec_experiment, but it was twice as slow in mix_two_bitvec_experiment. Why?
- The encrypt_rounds_experiment shows some weird results: a simplified xtime_branchless is incredibly slow while the non-simplified version is incredibly fast, even faster than the rest.
- Moreover, in encrypt_rounds_experiment the simplified xtime_broadcasted and xtime_cancelled have different performance but, as we showed before, they are the same!
This last item makes me think: are we seeing an outlier affecting the mean?
We can rule that out by measuring the minimum instead of the mean.
check-elapsed time (y axis) per xtime* function (x axis). Each subplot corresponds to a different experiment. Note how the plot has the same shape as before.
Nope, same thing.
Could this discrepancy be just luck? We need a measure independent of time, and Z3 tracks several statistics that could serve that purpose.
However, it's unclear what they mean.
Exploring a little, there seems to be a relationship between 'added eqs' and the elapsed time.
Relationship between the time that check() took and the amount of added eqs: they follow an almost perfectly linear relationship.
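One way to put a number on "almost perfectly linear" is a least-squares fit together with Pearson's correlation coefficient r, where r close to 1 means a nearly linear relationship. This is a plain-Python sketch; linear_fit is a hypothetical helper, and it assumes the added eqs and check-elapsed values have already been extracted into two equal-length lists, for example from the Z3 stats DataFrame.

```python
import math

def linear_fit(xs, ys):
    """Least-squares line y = slope * x + intercept, plus Pearson's r."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    syy = sum((y - mean_y) ** 2 for y in ys)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx                 # change in time per added eq
    intercept = mean_y - slope * mean_x
    r = sxy / math.sqrt(sxx * syy)    # 1.0 for a perfectly linear relation
    return slope, intercept, r
```

A call such as linear_fit(added_eqs, check_elapsed) would then summarize the scatter plot in three numbers.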
Let’s see how many eqs were added in the encrypt_rounds_experiment:
added eqs (y axis) per xtime* function (x axis). Each subplot corresponds to a different experiment. Note how the plot has the same shape as before, showing a strong relationship between added eqs and the check-elapsed time.
Same shape as before: for some reason Z3 added more eqs for xtime_broadcasted than for xtime_cancelled (both simplified), even though both are the same Z3 expression.
So the discrepancy is not due to noise: Z3 indeed saw these two as different things.
Code and results
- Experiments (Python code)
- Plotting (Python code)
- Runtime results (Pandas DataFrame in Parquet format)
- Z3 stats (Pandas DataFrame in Parquet format)
Conclusions
None.
I’m still missing a lot of pieces of this puzzle.
Related tags: z3, smt, sat, solver, if ITE, bithack, performance

