UltrafastSecp256k1 3.50.0
Ultra high-performance secp256k1 elliptic curve cryptography library
Loading...
Searching...
No Matches
ct_utils.hpp
Go to the documentation of this file.
1#ifndef SECP256K1_CT_UTILS_HPP
2#define SECP256K1_CT_UTILS_HPP
3
4// ============================================================================
5// Constant-Time Utilities -- High-Level API
6// ============================================================================
7// Provides byte-level constant-time operations for use in protocol
8// implementations (ECDSA, Schnorr, ECDH, etc.).
9//
10// All functions have fixed execution time regardless of input values.
11// No secret-dependent branches or memory access patterns.
12//
13// Audit Status:
14// - Value barriers: compiler optimization fence via inline asm / volatile
15// - Mask generation: arithmetic (no branches)
16// - Conditional ops: bitwise (no cmov -- explicit XOR/AND/OR)
17// - Table lookup: full scan (no early exit)
18//
19// Layers:
20// ct/ops.hpp -- 64-bit primitives (value_barrier, masks, cmov, cswap)
21// ct/field.hpp -- FieldElement CT ops (add/sub/mul/sqr/inv)
22// ct/scalar.hpp -- Scalar CT ops (add/sub/neg/cmov/cswap)
23// ct/point.hpp -- Point CT ops (complete addition, CT scalar_mul)
24// ct_utils.hpp -- THIS FILE: byte-level utilities for protocols
25// ============================================================================
26
27#include <cstdint>
28#include <cstddef>
29#include <cstring>
30#include <array>
31#include "secp256k1/ct/ops.hpp"
32
33namespace secp256k1::ct {
34
35// -- Byte-level Constant-Time Compare -----------------------------------------
36// Returns true if a[0..len) == b[0..len). Constant-time (no early exit).
37inline bool ct_equal(const void* a, const void* b, std::size_t len) noexcept {
38 const auto* pa = static_cast<const std::uint8_t*>(a);
39 const auto* pb = static_cast<const std::uint8_t*>(b);
40
41 std::uint64_t diff = 0;
42 // Process 8 bytes at a time
43 std::size_t i = 0;
44 for (; i + 8 <= len; i += 8) {
45 std::uint64_t va = 0, vb = 0;
46 std::memcpy(&va, pa + i, 8);
47 std::memcpy(&vb, pb + i, 8);
48 diff |= va ^ vb;
49 }
50 // Remaining bytes
51 for (; i < len; ++i) {
52 diff |= static_cast<std::uint64_t>(pa[i] ^ pb[i]);
53 }
54
55 return is_zero_mask(diff) != 0;
56}
57
58// Template version for fixed-size arrays
59template<std::size_t N>
60inline bool ct_equal(const std::array<std::uint8_t, N>& a,
61 const std::array<std::uint8_t, N>& b) noexcept {
62 return ct_equal(a.data(), b.data(), N);
63}
64
65// -- Byte-level Constant-Time Conditional Copy --------------------------------
66// if (flag) memcpy(dst, src, len); -- constant-time
67inline void ct_memcpy_if(void* dst, const void* src, std::size_t len,
68 bool flag) noexcept {
69 auto mask = bool_to_mask(flag);
70 auto mask8 = static_cast<std::uint8_t>(mask & 0xFF);
71
72 auto* pd = static_cast<std::uint8_t*>(dst);
73 const auto* ps = static_cast<const std::uint8_t*>(src);
74
75 for (std::size_t i = 0; i < len; ++i) {
76 pd[i] ^= (pd[i] ^ ps[i]) & mask8;
77 }
78}
79
80// -- Constant-Time Conditional Swap -------------------------------------------
81// if (flag) swap(a[0..len), b[0..len)); -- constant-time
82inline void ct_memswap_if(void* a, void* b, std::size_t len,
83 bool flag) noexcept {
84 auto mask = bool_to_mask(flag);
85 auto mask8 = static_cast<std::uint8_t>(mask & 0xFF);
86
87 auto* pa = static_cast<std::uint8_t*>(a);
88 auto* pb = static_cast<std::uint8_t*>(b);
89
90 for (std::size_t i = 0; i < len; ++i) {
91 std::uint8_t const diff = (pa[i] ^ pb[i]) & mask8;
92 pa[i] ^= diff;
93 pb[i] ^= diff;
94 }
95}
96
97// -- Constant-Time Zero Check -------------------------------------------------
98// Returns true if all bytes are zero. Constant-time.
99inline bool ct_is_zero(const void* data, std::size_t len) noexcept {
100 const auto* p = static_cast<const std::uint8_t*>(data);
101 std::uint64_t acc = 0;
102 for (std::size_t i = 0; i < len; ++i) {
103 acc |= static_cast<std::uint64_t>(p[i]);
104 }
105 return is_zero_mask(acc) != 0;
106}
107
108template<std::size_t N>
109inline bool ct_is_zero(const std::array<std::uint8_t, N>& data) noexcept {
110 return ct_is_zero(data.data(), N);
111}
112
// -- Constant-Time Memory Set -------------------------------------------------
// Overwrites len bytes at data with zero in a way the compiler is not allowed
// to drop as a dead store.
//
//   data  -- buffer to wipe, len bytes
//   len   -- number of bytes to clear
inline void ct_memzero(void* data, std::size_t len) noexcept {
    // Stores through a volatile-qualified pointer must be emitted one by one.
    volatile std::uint8_t* cursor = static_cast<volatile std::uint8_t*>(data);
    for (std::size_t remaining = len; remaining != 0; --remaining) {
        *cursor = 0;
        ++cursor;
    }
#if defined(__GNUC__) || defined(__clang__)
    // Compiler-level memory clobber: tells GCC/Clang the buffer may be
    // observed after this point, on top of the volatile stores above.
    asm volatile("" : : "r"(data) : "memory");
#endif
}
124
125// -- Constant-Time Byte Select ------------------------------------------------
126// Returns a if flag is true, b otherwise. No branch.
127inline std::uint8_t ct_select_byte(std::uint8_t a, std::uint8_t b,
128 bool flag) noexcept {
129 auto mask = bool_to_mask(flag);
130 auto mask8 = static_cast<std::uint8_t>(mask & 0xFF);
131 return static_cast<std::uint8_t>((a & mask8) | (b & ~mask8));
132}
133
134// -- Constant-Time Lexicographic Compare --------------------------------------
135// Returns: -1 if a < b, 0 if a == b, 1 if a > b. Fully branchless.
136//
137// For the common 32-byte case: fully unrolled, no loops, no loop-carried
138// dependencies. This prevents the compiler from inserting branches on
139// the "decided" flag to short-circuit iterations.
140
141namespace ct_compare_detail {
142
// Branchless unsigned 64-bit compare.
//
//   wa, wb  -- the two words to compare
//   gt      -- out: 1 if wa > wb, else 0
//   lt      -- out: 1 if wa < wb, else 0
//
// Both inputs and both outputs pass through an optimization barrier (an empty
// asm statement with "+r" operands on GCC/Clang). The input barrier stops
// the compiler from reasoning about wa/wb before the compare (Clang was
// observed inserting beq/bne ahead of sltu on RISC-V); the output barrier
// stops it from learning gt == lt == 0 for equal inputs and turning
// downstream masking code into branches. On other compilers no barrier is
// emitted here.
inline void ct_cmp_pair(std::uint64_t wa, std::uint64_t wb,
                        std::uint64_t& gt, std::uint64_t& lt) noexcept {
#if defined(__GNUC__) || defined(__clang__)
    // Input barrier: wa/wb become opaque register values from here on.
    asm volatile("" : "+r"(wa), "+r"(wb));
#endif

    std::uint64_t gt_bit = 0;
    std::uint64_t lt_bit = 0;

#if defined(__riscv) && (__riscv_xlen == 64)
    // sltu rd, rs1, rs2 sets rd = (rs1 < rs2), unsigned.
    asm volatile("sltu %0, %2, %1" : "=r"(gt_bit) : "r"(wa), "r"(wb));
    asm volatile("sltu %0, %2, %1" : "=r"(lt_bit) : "r"(wb), "r"(wa));
#else
    // Fully arithmetic comparison: extracts the borrow bit from
    // unsigned subtraction. Avoids x86 seta/setb which read FLAGS --
    // some Intel uarchs have data-dependent latency for FLAG reads
    // depending on the flag state (CF+ZF pattern).
    {
        std::uint64_t const diff = wa - wb;
        // Top bit is the borrow out of wa - wb, i.e. (wa < wb).
        lt_bit = (wa ^ ((wa ^ wb) | (diff ^ wa))) >> 63;
    }
    {
        std::uint64_t const diff = wb - wa;
        // Top bit is the borrow out of wb - wa, i.e. (wa > wb).
        gt_bit = (wb ^ ((wb ^ wa) | (diff ^ wb))) >> 63;
    }
#endif

#if defined(__GNUC__) || defined(__clang__)
    // Output barrier: hide the relationship between the two result bits.
    asm volatile("" : "+r"(gt_bit), "+r"(lt_bit));
#endif

    lt = lt_bit;
    gt = gt_bit;
}
172
// Loads the 8 bytes at p as a big-endian 64-bit integer, so that unsigned
// comparison of two loaded words matches lexicographic comparison of the
// underlying bytes.
//
//   p  -- first of 8 readable bytes
//   returns p[0] in the most significant byte ... p[7] in the least
//
// Host byte order is taken into account: the byte swap is applied only on
// little-endian hosts. On a big-endian host the in-memory order already is
// the desired order, and swapping would reverse the lexicographic ordering.
//
// On RISC-V (little-endian): the load goes through a may_alias-typed pointer
// to avoid GCC decomposing memcpy into an 8x lbu + sb chain.
// Callers MUST pass 8-byte-aligned pointers on that path.
inline std::uint64_t ct_load_be(const std::uint8_t* p) noexcept {
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
    // Big-endian host: a native load is already big-endian.
    std::uint64_t v = 0;
    std::memcpy(&v, p, 8);
    return v;
#elif defined(__riscv) && (__riscv_xlen == 64)
    typedef std::uint64_t u64_alias __attribute__((__may_alias__));
    std::uint64_t const v = *reinterpret_cast<const u64_alias*>(p);
    return __builtin_bswap64(v);
#elif defined(__GNUC__) || defined(__clang__)
    std::uint64_t v = 0;
    std::memcpy(&v, p, 8);
    return __builtin_bswap64(v);
#elif defined(_MSC_VER)
    std::uint64_t v = 0;
    std::memcpy(&v, p, 8);
    return _byteswap_uint64(v);
#else
    // Unknown toolchain, unknown byte order: assemble the value from
    // individual bytes, which is correct on any host.
    return (static_cast<std::uint64_t>(p[0]) << 56) |
           (static_cast<std::uint64_t>(p[1]) << 48) |
           (static_cast<std::uint64_t>(p[2]) << 40) |
           (static_cast<std::uint64_t>(p[3]) << 32) |
           (static_cast<std::uint64_t>(p[4]) << 24) |
           (static_cast<std::uint64_t>(p[5]) << 16) |
           (static_cast<std::uint64_t>(p[6]) << 8) |
           (static_cast<std::uint64_t>(p[7]));
#endif
}
198
199} // namespace ct_compare_detail
200
201inline int ct_compare(const void* a, const void* b, std::size_t len) noexcept {
202 const auto* pa = static_cast<const std::uint8_t*>(a);
203 const auto* pb = static_cast<const std::uint8_t*>(b);
204
205 // ---- Fast path: 32 bytes (fully unrolled, zero branches) ----
206 // Algorithm: reverse-scan accumulation.
207 // Process words 3->2->1->0 (least significant first).
208 // Each differing word OVERRIDES the running result.
209 // Final result reflects the FIRST (most significant) differing word.
210 // value_barrier after every step prevents Clang from injecting
211 // beq/bne branches (observed with Clang 21 RISC-V).
212 if (len == 32) {
213 using namespace ct_compare_detail;
214
215 // Load all 4 word pairs in big-endian (lexicographic order)
216 const std::uint64_t w0a = ct_load_be(pa + 0), w0b = ct_load_be(pb + 0);
217 const std::uint64_t w1a = ct_load_be(pa + 8), w1b = ct_load_be(pb + 8);
218 const std::uint64_t w2a = ct_load_be(pa + 16), w2b = ct_load_be(pb + 16);
219 const std::uint64_t w3a = ct_load_be(pa + 24), w3b = ct_load_be(pb + 24);
220
221 std::uint64_t result = 0;
222
223 // Word 3 (bytes 24-31, least significant)
224 {
225 std::uint64_t gt = 0, lt = 0;
226 ct_cmp_pair(w3a, w3b, gt, lt);
227 std::uint64_t differs = gt | lt; // 0 or 1
228 ct::value_barrier(differs);
229 std::uint64_t mask = 0ULL - differs;
230 ct::value_barrier(mask);
231 result = (gt - lt) & mask; // result was 0
232 }
233 ct::value_barrier(result);
234
235 // Word 2 (bytes 16-23)
236 {
237 std::uint64_t gt = 0, lt = 0;
238 ct_cmp_pair(w2a, w2b, gt, lt);
239 std::uint64_t differs = gt | lt;
240 ct::value_barrier(differs);
241 std::uint64_t mask = 0ULL - differs;
242 ct::value_barrier(mask);
243 result = ((gt - lt) & mask) | (result & ~mask);
244 }
245 ct::value_barrier(result);
246
247 // Word 1 (bytes 8-15)
248 {
249 std::uint64_t gt = 0, lt = 0;
250 ct_cmp_pair(w1a, w1b, gt, lt);
251 std::uint64_t differs = gt | lt;
252 ct::value_barrier(differs);
253 std::uint64_t mask = 0ULL - differs;
254 ct::value_barrier(mask);
255 result = ((gt - lt) & mask) | (result & ~mask);
256 }
257 ct::value_barrier(result);
258
259 // Word 0 (bytes 0-7, most significant -- overrides all)
260 {
261 std::uint64_t gt = 0, lt = 0;
262 ct_cmp_pair(w0a, w0b, gt, lt);
263 std::uint64_t differs = gt | lt;
264 ct::value_barrier(differs);
265 std::uint64_t mask = 0ULL - differs;
266 ct::value_barrier(mask);
267 result = ((gt - lt) & mask) | (result & ~mask);
268 }
269
270#if defined(__GNUC__) || defined(__clang__)
271 asm volatile("" : "+r"(result));
272#endif
273 return static_cast<int>(static_cast<std::int64_t>(result));
274 }
275
276 // ---- General path: arbitrary length ----
277 std::uint64_t result = 0;
278 std::uint64_t decided = 0;
279
280 std::size_t i = 0;
281 for (; i + 8 <= len; i += 8) {
282 std::uint64_t wa = ct_compare_detail::ct_load_be(pa + i);
283 std::uint64_t wb = ct_compare_detail::ct_load_be(pb + i);
284 // Barrier inputs: prevent compiler from comparing wa/wb directly
285 // (Clang 21 RISC-V inserts beq before sltu without these)
288 std::uint64_t const xor_val = wa ^ wb;
289 // nz = 1 if words differ, 0 otherwise
290#if defined(__riscv) && (__riscv_xlen == 64)
291 std::uint64_t nz;
292 asm volatile("snez %0, %1" : "=r"(nz) : "r"(xor_val));
293#else
294 std::uint64_t const nz = ((xor_val | (0ULL - xor_val)) >> 63) & 1ULL;
295#endif
296
297 // Barrier on decided only: prevent compiler from short-circuiting
298 ct::value_barrier(decided);
299
300 // take = 1 only for the very first differing word
301 std::uint64_t const take = nz & (1ULL - decided);
302
303 // mask = all-ones when take==1, zero when take==0
304 std::uint64_t mask = 0ULL - take;
305 ct::value_barrier(mask);
306
307 // Branchless unsigned compare: ct_cmp_pair-style barriers on inputs
308 std::uint64_t gt = 0, lt = 0;
311#if defined(__riscv) && (__riscv_xlen == 64)
312 asm volatile("sltu %0, %2, %1" : "=r"(gt) : "r"(wa), "r"(wb));
313 asm volatile("sltu %0, %2, %1" : "=r"(lt) : "r"(wb), "r"(wa));
314#else
315 gt = static_cast<std::uint64_t>(wa > wb);
316 lt = static_cast<std::uint64_t>(wa < wb);
317#endif
318 // diff_sign encodes: 1 = a>b, 0 = equal, -1 (0xFFFF...) = a<b
319 std::uint64_t const diff_sign = gt - lt;
320
321 result = (diff_sign & mask) | (result & ~mask);
322 decided |= nz;
323
324#if defined(__GNUC__) || defined(__clang__)
325 asm volatile("" : "+r"(result), "+r"(decided));
326#endif
327 }
328
329 // Remaining bytes (< 8) -- byte-by-byte fallback
330 for (; i < len; ++i) {
331 std::uint64_t ai = pa[i];
332 std::uint64_t bi = pb[i];
333 std::uint64_t const diff = ai ^ bi;
334
335#if defined(__riscv) && (__riscv_xlen == 64)
336 std::uint64_t nz;
337 asm volatile("snez %0, %1" : "=r"(nz) : "r"(diff));
338#else
339 std::uint64_t const nz = ((diff | (0ULL - diff)) >> 63) & 1ULL;
340#endif
341 ct::value_barrier(decided);
342 std::uint64_t const take = nz & (1ULL - decided);
343 std::uint64_t mask = 0ULL - take;
344 ct::value_barrier(mask);
345
346 std::uint64_t gt_b = 0, lt_b = 0;
349#if defined(__riscv) && (__riscv_xlen == 64)
350 asm volatile("sltu %0, %2, %1" : "=r"(gt_b) : "r"(ai), "r"(bi));
351 asm volatile("sltu %0, %2, %1" : "=r"(lt_b) : "r"(bi), "r"(ai));
352#else
353 gt_b = static_cast<std::uint64_t>(ai > bi);
354 lt_b = static_cast<std::uint64_t>(ai < bi);
355#endif
356 std::uint64_t const diff_sign = gt_b - lt_b;
357
358 result = (diff_sign & mask) | (result & ~mask);
359 decided |= nz;
360
361#if defined(__GNUC__) || defined(__clang__)
362 asm volatile("" : "+r"(result), "+r"(decided));
363#endif
364 }
365
366 // Normalise to {-1, 0, 1} without branches.
367 // result is 0, 1, or 0xFFFFFFFFFFFFFFFF (-1 as uint64)
368 return static_cast<int>(static_cast<std::int64_t>(result));
369}
370
371} // namespace secp256k1::ct
372
373#endif // SECP256K1_CT_UTILS_HPP
std::uint64_t ct_load_be(const std::uint8_t *p) noexcept
Definition ct_utils.hpp:177
void ct_cmp_pair(std::uint64_t wa, std::uint64_t wb, std::uint64_t &gt, std::uint64_t &lt) noexcept
Definition ct_utils.hpp:146
bool ct_is_zero(const void *data, std::size_t len) noexcept
Definition ct_utils.hpp:99
void ct_memswap_if(void *a, void *b, std::size_t len, bool flag) noexcept
Definition ct_utils.hpp:82
void ct_memzero(void *data, std::size_t len) noexcept
Definition ct_utils.hpp:115
void value_barrier(std::uint64_t &v) noexcept
Definition ops.hpp:94
void ct_memcpy_if(void *dst, const void *src, std::size_t len, bool flag) noexcept
Definition ct_utils.hpp:67
SECP256K1_CT_NO_STACK_PROTECTOR std::uint64_t bool_to_mask(bool flag) noexcept
Definition ops.hpp:143
std::uint8_t ct_select_byte(std::uint8_t a, std::uint8_t b, bool flag) noexcept
Definition ct_utils.hpp:127
bool ct_equal(const void *a, const void *b, std::size_t len) noexcept
Definition ct_utils.hpp:37
int ct_compare(const void *a, const void *b, std::size_t len) noexcept
Definition ct_utils.hpp:201
SECP256K1_CT_NO_STACK_PROTECTOR std::uint64_t is_zero_mask(std::uint64_t v) noexcept
Definition ops.hpp:107