Light mode Light mode Dark mode Dark mode

Fast RSA

ACSC 2025 - Final

Category: crypto

Description #

Hey choom! I got my hands on some fresh new hardware, a chip that can do military-grade encryption at lightning speed.

I'm testing out its signature features, and since you're a god-tier netrunner, maybe you can learn how to sign stuff yourself if you watch my bot do it enough?

Attachments: [ fast-rsa.zip ]

Problem #

The program generates an RSA keypair and outputs the modulus \(n\) and the uppermost 32 bits of \(d\).

It then repeatedly reads \(m, b \in \mathbb Z\) and calculates the signature \(s = m^d \mod n\). This is done in Montgomery form. It then reveals \(s\), together with whether either of the two multiplications in the \(b\)-th step of the exponentiation required the conditional subtraction.

The goal is to determine \(d\).

Solution #

Disclosure: I was not able to solve this challenge at the competition.

We can almost use an attack by Walter & Thompson [1]. You can find my implementation here.

The only difference is that in each step it is not known which of the two subtractions was needed, i.e. the observation vector is “squashed”. The solution is simple: the algorithm has to be run on multiple squashed observation vectors simultaneously. In practice, two are enough. The algorithm calculates \(d\) in ~1 s; however, getting the observation vectors from the remote takes considerably more time.

Implementation #

#!/usr/bin/env python3

import sys
import time
import pwn

# attack_impl recurses one level per recovered exponent bit, so the search
# depth grows with the bit length of the modulus; raise the interpreter's
# recursion limit accordingly.
sys.setrecursionlimit(2500)

# https://en.wikipedia.org/wiki/Montgomery_modular_multiplication
class Montgomery:
    """Montgomery arithmetic modulo an odd integer.

    The radix R is the smallest power of two above the modulus, i.e.
    ``R = 2 ** N.bit_length()``.

    Attributes:
        N: the modulus.
        nbits: bit length of the modulus (``R == 1 << nbits``).
        R: the Montgomery radix.
        R_inv: inverse of R modulo N.
        N_prime: negated inverse of N modulo R (a non-positive integer).
    """

    def __init__(self, n):
        # An even modulus has no inverse modulo a power of two.
        assert n & 1
        width = n.bit_length()
        radix = 1 << width
        self.N = n
        self.nbits = width
        self.R = radix
        self.R_inv = pow(radix, -1, n)
        self.N_prime = -pow(n, -1, radix)

    def to_montgomery(self, x):
        """Return ``x * R mod N``, the Montgomery representation of x."""
        shifted = x << self.nbits
        return shifted % self.N

    def from_montgomery(self, x):
        """Return ``x * R_inv mod N``, undoing :meth:`to_montgomery`."""
        scaled = x * self.R_inv
        return scaled % self.N

    def multiply(self, x, y):
        """Montgomery-multiply two values given in Montgomery form.

        Returns:
            A pair ``(product, subtracted)`` where ``product`` is
            ``x * y * R_inv mod N`` and ``subtracted`` tells whether the final
            conditional subtraction of N was needed.
        """
        product = x * y
        low = product % self.R
        q = (low * self.N_prime) % self.R
        # product + q * N is divisible by R, so the shift is an exact division.
        reduced = (product + q * self.N) >> self.nbits
        if reduced >= self.N:
            return reduced - self.N, True
        return reduced, False

def parse_res(res, a, b):
    """Extract the hexadecimal number found between markers a and b in res.

    The text following the first occurrence of ``a`` is cut at the first
    ``b``; its first two characters are dropped (presumably a ``0x`` prefix)
    and the remainder is parsed as base-16.
    """
    after_marker = res.split(a)[1]
    token = after_marker.split(b)[0]
    digits = token[2:]
    return int(digits, 16)

# Number of least-significant exponent bits that are not searched bit by bit:
# attack_impl stops CUTOFF bits before the end and attack_bruteforce tries all
# 2**CUTOFF completions against the known signature.
CUTOFF = 3

def Z_prime_next(M, A_bar, Y, bit):
    """Simulate one square-and-multiply step and predict its observation.

    Args:
        M: object whose ``multiply(x, y)`` returns ``(product, subtracted)``.
        A_bar: the message, in Montgomery form.
        Y: the running value before this step, in Montgomery form.
        bit: the exponent bit processed in this step.

    Returns:
        ``(Y_next, flag)`` where ``flag`` is truthy when the squaring or, if
        ``bit`` is set, the following multiplication reported a subtraction.
    """
    squared, sub_square = M.multiply(Y, Y)
    if not bit:
        return squared, sub_square or False
    product, sub_multiply = M.multiply(squared, A_bar)
    return product, sub_square or sub_multiply

def attack_bruteforce(M, A_bar, Y, S_bar, dt):
    """Recover the last CUTOFF exponent bits by exhaustive search.

    Y is squared CUTOFF times and then repeatedly multiplied by A_bar. The
    number of multiplications needed to reach S_bar is the value of the
    missing low bits.

    Args:
        M: Montgomery context providing ``multiply``.
        A_bar: the message, in Montgomery form.
        Y: the running value after all but the last CUTOFF exponent bits.
        S_bar: the known signature, in Montgomery form.
        dt: the exponent recovered so far (low CUTOFF bits still zero).

    Returns:
        ``(True, d)`` on success, ``(False, None)`` if no completion matches.
    """
    acc = Y
    for _ in range(CUTOFF):
        acc, _flag = M.multiply(acc, acc)

    low_bits = 0
    while low_bits < (1 << CUTOFF):
        if acc == S_bar:
            return True, dt + low_bits
        acc, _flag = M.multiply(acc, A_bar)
        low_bits += 1

    return False, None

def attack_impl(M, As_bar, Ss_bar, Zs, Ys, Zs_prime, t, dt, H, guess = None):
    """Depth-first search for the exponent, one bit per recursion level.

    Invariant: the first t bits of dt are assumed to be correct.

    Args:
        M: Montgomery context.
        As_bar: the messages, in Montgomery form.
        Ss_bar: the matching signatures, in Montgomery form.
        Zs: one observation vector per message; all are indexed by step and
            the length of the first one is taken as the exponent bit length.
        Ys: running exponentiation value per message after t steps.
        Zs_prime: per message, the flags predicted for the first t steps.
        t: number of exponent bits fixed so far.
        dt: the exponent candidate with those t top bits filled in.
        H: the known top 32 bits of the exponent.
        guess: the bit to append in this call, or None on the initial call.

    Returns:
        ``(True, d)`` when an exponent consistent with the first signature is
        found, otherwise ``(False, None)``.
    """
    if t % 10 == 0:
        print('attack_impl', t, dt)

    nbits = len(Zs[0])

    if guess is not None:
        # Commit the guessed bit and advance every simulation by one step.
        t += 1
        dt |= guess << (nbits - t)
        stepped = [Z_prime_next(M, A_bar, Y, guess) for A_bar, Y in zip(As_bar, Ys)]
        Ys = [value for value, _ in stepped]
        Zs_prime = [seen + [flag] for seen, (_, flag) in zip(Zs_prime, stepped)]

    # The remaining low bits are cheaper to enumerate than to search.
    if t == nbits - CUTOFF:
        return attack_bruteforce(M, As_bar[0], Ys[0], Ss_bar[0], dt)

    # Prune as soon as any prediction disagrees with what was observed.
    for observed, predicted in zip(Zs, Zs_prime):
        if observed[:t] != predicted[:t]:
            return False, None

    # The top 32 bits are known, so there is nothing to branch on.
    if t < 32:
        known_bit = bool(H & (1 << (31 - t)))
        return attack_impl(M, As_bar, Ss_bar, Zs, Ys, Zs_prime, t, dt, H, known_bit)

    # Try 1 first when the latest predicted flag matches the observation for
    # at least one message, otherwise try 0 first; fall back to the other bit.
    agreements = [observed[t - 1] == predicted[t - 1] for observed, predicted in zip(Zs, Zs_prime)]
    first_choice = int(any(agreements))
    for candidate in (first_choice, not first_choice):
        found, d = attack_impl(M, As_bar, Ss_bar, Zs, Ys, Zs_prime, t, dt, H, candidate)
        if found:
            break
    return found, d

def attack(N, As, Ss, Zs, d_upper):
    """Recover the private exponent from signatures and observation vectors.

    Args:
        N: the RSA modulus.
        As: the signed messages.
        Ss: the signatures of those messages.
        Zs: one observation vector per message.
        d_upper: the known top 32 bits of the exponent.

    Returns:
        The recovered exponent. Fails an assertion if the search finds none.
    """
    mont = Montgomery(N)
    to_mont = mont.to_montgomery

    msgs_bar = list(map(to_mont, As))
    sigs_bar = list(map(to_mont, Ss))
    # Every simulated exponentiation starts from 1 with no predictions yet.
    start_values = [to_mont(1) for _ in As]
    predictions = [[] for _ in As]

    found, d = attack_impl(mont, msgs_bar, sigs_bar, Zs, start_values, predictions, 0, 0, d_upper)
    assert found
    return d

def get_sig(A):
    """Ask the remote to sign A (observing step 0) and return the signature."""
    request = hex(A).encode() + b'\n0'
    p.sendline(request)
    reply = p.readuntil(b'challenge: ')
    return parse_res(reply, b'Signature: ', b'\n')

def is_expensive(res):
    """Return True unless res contains one of the known "good" phrases.

    The phrases presumably mark a step that needed no conditional
    subtraction; any other reply is treated as an expensive step.
    """
    good_phrases = (
        b' is a cool number',
        b' is really quite something',
        b' likes to save me money',
    )
    for phrase in good_phrases:
        if phrase in res:
            return False
    return True

def get_vec(A, bits):
    """Collect the observation vector of message A from the remote.

    For every step index in ``range(bits)`` the remote is asked to sign A
    while observing that step, and the reply is classified with
    is_expensive. Requests are pipelined in batches to save round trips.

    Args:
        A: the message to sign.
        bits: number of exponentiation steps to observe.

    Returns:
        A list of exactly ``bits`` booleans in reverse query order, so the
        flag for step ``bits - 1`` comes first.
    """
    BATCH = 32
    Z = []
    for start in range(0, bits, BATCH):
        print(start)
        # Clamp the final batch: querying steps >= bits would append extra
        # flags and shift every entry of the reversed vector.
        steps = range(start, min(start + BATCH, bits))
        lines = [hex(A) + '\n' + str(step) for step in steps]
        p.sendline('\n'.join(lines).encode())

        for _ in steps:
            Z.append(is_expensive(p.readuntil(b'challenge: ')))
    return Z[::-1]

# Connect and read the banner, from which the modulus and the known top bits
# of d are parsed.
p = pwn.remote('127.0.0.1', 5000)
res = p.readuntil(b'challenge: ')
n = parse_res(res, b'n=', b'\n')
d_upper = parse_res(res, b'd: ', b'\n')         # crib
As = [2, 3]                                     # messages
Ss = [get_sig(A) for A in As]                   # signatures
Zs = [get_vec(A, n.bit_length()) for A in As]   # observation vectors

# Time only the offline key recovery, not the network round trips above.
t1 = time.time()
d = attack(n, As, Ss, Zs, d_upper)
t2 = time.time()

# An empty line makes the remote send the message to sign; answer it with a
# signature made from the recovered exponent.
p.sendline(b'')
msg = parse_res(p.readline(), b'message: ', b'\n')
sig = pow(msg, d, n)
p.sendline(hex(sig).encode())

print(f'{d_upper = }')
print(f'{As = }')
print(f'{Ss = }')
print(f'{n = }')
print(f'{d = }')
print(f'{msg = }')
print(f'{sig = }')
print(f'{CUTOFF = }')
print('# calculated in ', t2 - t1, 's')
print(p.readline())

References #

  1. C. D. Walter and S. Thompson, “Distinguishing exponent digits by observing modular subtractions,” in Topics in Cryptology — CT-RSA 2001, ser. Lecture Notes in Computer Science, vol. 2020. Springer Verlag, 2001, pp. 192–207. subsection 3.3