Writeup Perfect Secrecy from Google CTF 2018 (Quals)

by: xomex (Sven H)

Challenge (74 solves, 158 points)

This crypto experiment will help you decrypt an RSA encrypted message.
(Attachment containing challenge.py, flag.txt, key_pub.pem)
nc perfect-secrecy.ctfcompetition.com 1337

Looking at the description and the given files, we can guess that flag.txt is the flag encrypted with RSA under key_pub.pem, which turns out to be correct. We can also guess that challenge.py is running on the server, which is correct as well.

It seems our goal is to use the server to decrypt the flag for us.

So let’s take a look at challenge.py:

challenge.py

#!/usr/bin/env python3
import sys
import random

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend


def ReadPrivateKey(filename):
  return serialization.load_pem_private_key(
      open(filename, 'rb').read(), password=None, backend=default_backend())


def RsaDecrypt(private_key, ciphertext):
  assert (len(ciphertext) <=
          (private_key.public_key().key_size // 8)), 'Ciphertext too large'
  return pow(
      int.from_bytes(ciphertext, 'big'),
      private_key.private_numbers().d,
      private_key.public_key().public_numbers().n)


def Challenge(private_key, reader, writer):
  try:
    m0 = reader.read(1)
    m1 = reader.read(1)
    ciphertext = reader.read(private_key.public_key().key_size // 8)
    dice = RsaDecrypt(private_key, ciphertext)
    for rounds in range(100):
      p = [m0, m1][dice & 1]
      k = random.randint(0, 2)
      c = (ord(p) + k) % 2
      writer.write(bytes((c,)))
    writer.flush()
    return 0

  except Exception as e:
    return 1


def main():
  private_key = ReadPrivateKey(sys.argv[1])
  return Challenge(private_key, sys.stdin.buffer, sys.stdout.buffer)


if __name__ == '__main__':
  sys.exit(main())

There are some (more or less) uninteresting functions:

  • ReadPrivateKey simply reads a private key from a specified file (given as filename).
    • working correctly, has no user input
  • RsaDecrypt calculates c^d MOD N, with d being the private exponent and N the modulus from the public key. This is simple textbook RSA decryption (no padding) of a ciphertext c=m^e MOD N.
    The check (assert ...) at the beginning of the function only limits the ciphertext length and can be ignored.
    • working correctly
  • main is also uninteresting

Challenge is where the fun begins:

First, two bytes are read: m0 and m1. Then a ciphertext is read. We can check key_size using the public key. It turns out to be 1024, so the server reads 128 bytes. The ciphertext is decrypted into dice.

Then the server outputs 100 bytes as follows:
Depending on the least significant bit of dice, m0 or m1 is selected into p. This choice is the same in every iteration. Then one byte is written: (ord(p)+random.randint(0,2)) MOD 2

This can be simplified as: Challenge reads two bytes m0,m1 and a 128-byte ciphertext c. The ciphertext is decrypted and the least significant bit of the result is used to choose m0 or m1 into p. Then 100 bytes are written, one by one: write( (ord(p)+random.randint(0,2)) %2)
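The input format described above can be sketched in Python 3 as follows. This is only an illustration under stated assumptions: build_request and parse_request are hypothetical helper names that do not appear in challenge.py, and the 128-byte length assumes the 1024-bit key mentioned above.

```python
KEY_BYTES = 128  # 1024-bit modulus, as found in key_pub.pem in this writeup


def build_request(ciphertext: int, m0: int = 0, m1: int = 1) -> bytes:
    """Client side (hypothetical helper): m0, m1, then a fixed-width ciphertext."""
    # int.to_bytes pads with leading zero bytes, so the server always
    # receives exactly KEY_BYTES bytes of big-endian ciphertext.
    return bytes((m0, m1)) + ciphertext.to_bytes(KEY_BYTES, "big")


def parse_request(data: bytes):
    """Model of the three reads at the top of Challenge (hypothetical helper)."""
    m0, m1 = data[0:1], data[1:2]
    ciphertext = int.from_bytes(data[2:2 + KEY_BYTES], "big")
    return m0, m1, ciphertext


if __name__ == "__main__":
    request = build_request(0x1234)
    print(len(request))            # 130
    print(parse_request(request))  # (b'\x00', b'\x01', 4660)
```

The fixed-width encoding matters because the server always reads key_size // 8 bytes: a ciphertext with leading zero bytes has to be padded with zero bytes, not with any other character.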

Getting the Least Significant Bit

For simplicity we will pass m0=0 and m1=1 to the server. If we could determine whether the server chose 0 or 1 as p, we would learn the least significant bit. At first, however, it seems as if we are just given 100 random bytes (each either 0 or 1).

So let’s take a look at the documentation of random.randint(a,b)

Return a random integer N such that a <= N <= b.

This means the server outputs (100 times): p+r MOD 2 with r in {0,1,2} (including 2!). Because p+0 MOD 2 = p+2 MOD 2 it is twice as likely to get p MOD 2 as it is to get p+1 MOD 2. Thus we can learn the least significant bit (LSB) of the plaintext as follows:

#!/usr/bin/env python2
#pip2 install pwntools
import binascii
from pwn import *

def itob(x):
    """transform number->hex->bytes"""
    ret = '%x' % x  # hex(x) would append an 'L' for Python 2 longs
    if len(ret)%2!=0:
        ret = '0'+ret
    return binascii.unhexlify(ret)

def getLSBit(ciphertext):
    a0=0
    a1=0
    # just to counteract randomness we sometimes will ask more than once
    while abs(a0-a1)<20:
        r = remote('perfect-secrecy.ctfcompetition.com',1337)
        r.send('\x00')#m0
        r.send('\x01')#m1
        snd = itob(ciphertext)
        # pad ciphertext to 128 bytes with leading null bytes
        # (str.zfill would pad with the ASCII character '0' instead)
        snd = snd.rjust(128, '\x00')
        r.send(snd)
        res = r.readall(timeout=2)
        # count 0 and 1
        a0+=res.count('\x00')
        a1+=res.count('\x01')
    #if more 1s than 0s return True(1)
    return a1>a0

getLSBit(c) gives us (with a good enough probability) the least significant bit of m=c^d MOD N. But a single bit is a bit (no pun intended) too short for a Flag, so we need a way to learn more about m.
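To get a feeling for how reliable the majority vote is, the server's output loop can be modelled offline. The Python 3 sketch below is a minimal illustration, not part of the challenge: simulate_outputs, majority_guess and majority_error_probability are hypothetical names, and the calculation assumes that each output byte independently equals p with probability 2/3.

```python
import random
from math import comb


def simulate_outputs(p: int, rounds: int = 100, rng=random) -> bytes:
    """Model of the server loop: each byte is (p + k) % 2 with k in {0, 1, 2}."""
    return bytes((p + rng.randint(0, 2)) % 2 for _ in range(rounds))


def majority_guess(outputs: bytes) -> int:
    """Guess p as the more frequent byte value (a tie counts as 0)."""
    return int(outputs.count(1) > outputs.count(0))


def majority_error_probability(rounds: int = 100) -> float:
    """P(at most half of the bytes equal p) if each equals p with prob. 2/3."""
    # P(exactly `hits` bytes equal p) = C(rounds, hits) * 2**hits / 3**rounds
    return sum(
        comb(rounds, hits) * 2**hits
        for hits in range(rounds // 2 + 1)
    ) / 3**rounds


if __name__ == "__main__":
    print(majority_guess(simulate_outputs(1)))
    print(majority_error_probability(100))
```

A single 100-byte response is therefore already a fairly reliable vote. getLSBit above is more conservative and keeps querying until the two counts differ by at least 20.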

Trying to learn the plaintext

Our idea was to right-shift the plaintext by multiplying the ciphertext with (1/2)^e. (Note: 1/2 is the multiplicative inverse of 2 MOD N.)
c=m^e MOD N, thus c * (1/2)^e MOD N = m^e * (1/2)^e MOD N = (m*1/2)^e MOD N

This would enable us to learn a bit, shift the text and learn a new bit (and repeat this). Sadly, the shifting breaks as soon as the LSB is 1: for an odd m the product m*1/2 MOD N wraps around the modulus and is no longer m shifted right.

#!/usr/bin/env python2
# Just to show that shifting breaks on LSB=1
from Crypto.PublicKey import RSA

def egcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = egcd(b % a, a)
        return (g, x - (b // a) * y, y)

def modinv(a, m):
    g, x, y = egcd(a, m)
    if g != 1:
        raise Exception('modular inverse does not exist')
    else:
        return x % m

r = RSA.generate(1024)
n = r.key.n
e = r.key.e
inv2 = modinv(2,n)
inv2 = long(inv2)

m = 16L
c = r.encrypt(m, None)[0]
print('m  : %s'%bin(r.decrypt(c)))
print('m/2: %s'%bin(r.decrypt(c * pow(inv2,e,n) % n)))

m = 17L  # odd plaintext (0b10001), matching the output shown below
c = r.encrypt(m, None)[0]

print('m  : %s'%bin(r.decrypt(c)))
print('m/2: %s'%bin(r.decrypt(c * pow(inv2,e,n) % n)))

outputs:

m  : 0b10000
m/2: 0b1000
m  : 0b10001
m/2: 0b101010001100010001100101000001001100100110100110010001111001....

We tried to get this working for a while, but then moved on.
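The reason the shift fails can be made explicit with small numbers. Because N is odd, the inverse of 2 is (N+1)/2. Multiplying an even m by it gives m/2, but for an odd m the result is (m+N)/2, which lies near N/2 instead of being m shifted right. A minimal Python 3.8+ sketch with a toy modulus (3233 is an arbitrary example value, not the challenge modulus, and halve_mod is a hypothetical helper name):

```python
def halve_mod(m: int, n: int) -> int:
    """Multiply m by the inverse of 2 modulo an odd n."""
    inv2 = pow(2, -1, n)  # modular inverse, available since Python 3.8
    return (m * inv2) % n


if __name__ == "__main__":
    n = 3233  # toy odd modulus (61 * 53), for illustration only
    print(halve_mod(16, n))  # 8: even m, behaves like a right shift
    print(halve_mod(17, n))  # 1625: odd m, equals (17 + 3233) // 2
```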

Using the LSB Oracle

The Theory

As our approach didn’t work, we googled for RSA least significant bit, and a crypto stackexchange post showed up.

For an RSA plaintext m it must hold that 0 <= m < N. We can assume m != 0, so 0 < m < N, which implies 0 < 2m < 2N.
Furthermore, as N is odd (otherwise it would be trivial to factor, and we can check this using the given pubkey) and 2m is even (obviously), we can conclude:

  • 2m != N (this helps us to ignore some edge cases)
  • 2m MOD N odd iff N < 2m < 2N (iff N/2 < m < N)
    • N < 2m < 2N iff 2m MOD N = 2m-N. As 2m is even and N is odd, the result is odd.
  • 2m MOD N even iff 0 < 2m < N (iff 0 < m < N/2)

After this we know whether 0<m<N/2 or N/2<m<N. Let’s assume 0<m<N/2. We know 0<4m<2N. We now will check whether 4m MOD N is even or odd.

  • 4m MOD N odd iff N < 4m < 2N (N/4 < m < N/2)
  • 4m MOD N even iff 0 < 4m < N (0 < m < N/4)

Let’s assume N/2<m<N. We know 2N < 4m < 4N. We now will check whether 4m MOD N is even or odd.

  • 4m MOD N odd iff 3N < 4m < 4N (3/4 N < m < 4/4 N)
  • 4m MOD N even iff 2N < 4m < 3N (2/4 N < m < 3/4 N)

With the second check we get tighter bounds around m. We will call them UB and LB (upper/lower bound). With each iteration we halve the difference UB-LB, so after about key_size=1024 iterations we know m. This is similar to a binary search.
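The interval halving can be tried offline against a toy key. The following Python 3.8+ sketch is an illustration under stated assumptions rather than the script used against the server: the key (p=61, q=53, e=17) is a textbook toy example, lsb_oracle is a hypothetical stand-in for the server that answers without noise, and the bounds are kept as exact fractions so that no rounding error accumulates.

```python
from fractions import Fraction

# Toy textbook RSA key, for illustration only (not the challenge key).
P, Q, E = 61, 53, 17
N = P * Q
D = pow(E, -1, (P - 1) * (Q - 1))  # private exponent


def lsb_oracle(ciphertext: int) -> int:
    """Noise-free stand-in for the server: LSB of the decrypted ciphertext."""
    return pow(ciphertext, D, N) & 1


def recover_plaintext(ciphertext: int, oracle=lsb_oracle) -> int:
    """Narrow the bounds around m by asking for the parity of 2m, 4m, ..."""
    lower, upper = Fraction(0), Fraction(N)
    double = pow(2, E, N)  # encryption of 2
    for _ in range(N.bit_length()):
        ciphertext = (ciphertext * double) % N  # plaintext is doubled MOD N
        middle = (lower + upper) / 2
        if oracle(ciphertext):  # odd: the doubling wrapped around N
            lower = middle
        else:                   # even: no wrap-around
            upper = middle
    return int(upper)  # floor of the upper bound


if __name__ == "__main__":
    secret = 1234
    print(recover_plaintext(pow(secret, E, N)))  # 1234
```

After N.bit_length() steps the interval is shorter than 1, so the floor of the upper bound is the only integer left in it.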

Running it

Now that we know what we can learn from 2m MOD N being even or odd, we need to implement this. We can test whether a decrypted value is even or odd by checking its least significant bit, so we only need to be able to multiply m by 2 and get the server to give us the LSB of the result. This is easily possible by sending the server c*2^e MOD N = (2m)^e MOD N. After we get that information, we adjust our bounds and query 4m. We continue this until our bounds have a difference of 1.

#!/usr/bin/env python2
# Uses itob() and getLSBit() from the snippet above.
from pwn import *

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend

if __name__ == '__main__':
    pubkey = serialization.load_pem_public_key(open('key_pub.pem', 'rb').read(), backend=default_backend())
    n = pubkey.public_numbers().n
    e = pubkey.public_numbers().e
    with open('flag.txt','rb') as f:
        c = int(f.read().encode('hex'),16)
    LB = 0
    UB = n
    p2 = pow(2,e,n)

    context(log_level='error')
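    while UB-1>LB:
        # midpoint of the current interval (integer division)
        M = (UB+LB)/2
        # c now encrypts twice the previous plaintext (MOD n)
        c = p2*c %n
        print('')
        print('%d'%(UB-LB).bit_length())
        print('LB: %s'%repr(itob(LB)))
        print('RB: %s'%repr(itob(UB)))
        if getLSBit(c):
            # odd: the doubling wrapped around n, m is in the upper half
            LB=M
        else:
            # even: no wrap-around, m is in the lower half
            UB=M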
    print('-'*20)
    print('Finished')
    print('LB: %s'%repr(itob(LB)))
    print('RB: %s'%repr(itob(UB)))

After a while we get:

[...]
--------------------
Finished
LB: "\x02a\xb4\x02\x0c3\xd2.:\xe7\x9eB'\xf2\xd5\x1c7`\xe4\r\xcd\xac\xd8\x7f\x02_L\x84q\xea\x8c\xb4\x1d}\x82\x90g\x170\x00.oO CTF{h3ll0__17_5_m3_1_w45_w0nd3r1n6_1f_4f73r_4ll_7h353_y34r5_y0u_d_l1k3_70_m337} On\xae"
RB: "\x02a\xb4\x02\x0c3\xd2.:\xe7\x9eB'\xf2\xd5\x1c7`\xe4\r\xcd\xac\xd8\x7f\x02_L\x84q\xea\x8c\xb4\x1d}\x82\x90g\x170\x00.oO CTF{h3ll0__17_5_m3_1_w45_w0nd3r1n6_1f_4f73r_4ll_7h353_y34r5_y0u_d_l1k3_70_m337} On\xaf"

So we got our flag.
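The loop stops with two adjacent candidates, and the integer division used for the midpoint may shift the bounds slightly in the lowest bits. One optional way to pin down the exact plaintext, which the original script does not do, is to re-encrypt the values around the bounds and compare them with the original ciphertext. The Python 3 sketch below assumes a hypothetical helper confirm_plaintext with a hypothetical slack parameter, and uses toy numbers in its example.

```python
def confirm_plaintext(lower: int, upper: int, ciphertext: int,
                      e: int, n: int, slack: int = 16):
    """Return the m near [lower, upper] with pow(m, e, n) == ciphertext, or None."""
    start = max(lower - slack, 0)
    stop = min(upper + slack, n - 1)
    for candidate in range(start, stop + 1):
        # Textbook RSA is deterministic, so re-encrypting identifies m.
        if pow(candidate, e, n) == ciphertext:
            return candidate
    return None


if __name__ == "__main__":
    n, e = 3233, 17  # toy public key, for illustration only
    c = pow(1234, e, n)
    print(confirm_plaintext(1230, 1231, c, e, n))  # 1234
```

Note that the script above overwrites c in every iteration, so the original ciphertext would have to be kept in a separate variable for such a check.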

TL;DR

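The server acts as a least significant bit oracle, because random.randint(a,b) returns an int in [a,b] and not [a,b): with three possible values for k, the output (p+k) MOD 2 equals p twice as often as not. Querying it with c*2^e, c*4^e, ... MOD N allows us to learn the plaintext.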

CTF{h3ll0__17_5_m3_1_w45_w0nd3r1n6_1f_4f73r_4ll_7h353_y34r5_y0u_d_l1k3_70_m337}

Full Script

#!/usr/bin/env python2
#pip2 install pwntools
import binascii
from pwn import *

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend

def itob(x):
    """transform number->hex->bytes"""
    ret = '%x' % x  # hex(x) would append an 'L' for Python 2 longs
    if len(ret)%2!=0:
        ret = '0'+ret
    return binascii.unhexlify(ret)

def getLSBit(ciphertext):
    a0=0
    a1=0
    # just to counteract randomness we sometimes will ask more than once
    while abs(a0-a1)<20:
        r = remote('perfect-secrecy.ctfcompetition.com',1337)
        r.send('\x00')#m0
        r.send('\x01')#m1
        snd = itob(ciphertext)
        # pad ciphertext to 128 bytes with leading null bytes
        # (str.zfill would pad with the ASCII character '0' instead)
        snd = snd.rjust(128, '\x00')
        r.send(snd)
        res = r.readall(timeout=2)
        # count 0 and 1
        a0+=res.count('\x00')
        a1+=res.count('\x01')
    #if more 1s than 0s return True(1)
    return a1>a0


if __name__ == '__main__':
    pubkey = serialization.load_pem_public_key(open('key_pub.pem', 'rb').read(), backend=default_backend())
    n = pubkey.public_numbers().n
    e = pubkey.public_numbers().e
    with open('flag.txt','rb') as f:
        c = int(f.read().encode('hex'),16)
    LB = 0
    UB = n
    p2 = pow(2,e,n)

    context(log_level='error')
    while UB-1>LB:
        M = (UB+LB)/2
        c = p2*c %n
        print('')
        print('%d'%(UB-LB).bit_length())
        print('LB: %s'%repr(itob(LB)))
        print('RB: %s'%repr(itob(UB)))
        if getLSBit(c):
            LB=M
        else:
            UB=M
    print('-'*20)
    print('Finished')
    print('LB: %s'%repr(itob(LB)))
    print('RB: %s'%repr(itob(UB)))