This crypto experiment will help you decrypt an RSA encrypted message.

(Attachment containing challenge.py, flag.txt, key_pub.pem)

`nc perfect-secrecy.ctfcompetition.com 1337`

Looking at the description and the given files, we can guess that `flag.txt` is the flag encrypted with RSA under `key_pub.pem`, which turns out to be correct.
Furthermore, we can guess that `challenge.py` is running on the server, which is also correct.

It seems our goal is to use the server to decrypt the flag for us.

So let’s take a look at `challenge.py`:

```python
#!/usr/bin/env python3
import sys
import random
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend


def ReadPrivateKey(filename):
    return serialization.load_pem_private_key(
        open(filename, 'rb').read(), password=None, backend=default_backend())


def RsaDecrypt(private_key, ciphertext):
    assert (len(ciphertext) <=
            (private_key.public_key().key_size // 8)), 'Ciphertext too large'
    return pow(
        int.from_bytes(ciphertext, 'big'),
        private_key.private_numbers().d,
        private_key.public_key().public_numbers().n)


def Challenge(private_key, reader, writer):
    try:
        m0 = reader.read(1)
        m1 = reader.read(1)
        ciphertext = reader.read(private_key.public_key().key_size // 8)
        dice = RsaDecrypt(private_key, ciphertext)
        for rounds in range(100):
            p = [m0, m1][dice & 1]
            k = random.randint(0, 2)
            c = (ord(p) + k) % 2
            writer.write(bytes((c,)))
        writer.flush()
        return 0
    except Exception as e:
        return 1


def main():
    private_key = ReadPrivateKey(sys.argv[1])
    return Challenge(private_key, sys.stdin.buffer, sys.stdout.buffer)


if __name__ == '__main__':
    sys.exit(main())
```

There are some (more or less) uninteresting functions:

- `ReadPrivateKey` simply reads a private key from the file given as `filename`. It works correctly and takes no user input.
- `RsaDecrypt` calculates `c^d MOD N`, with `d` being the private exponent and `N` the modulus from the public key. This is plain textbook RSA decryption of a ciphertext `c = m^e MOD N`. The check (`assert ...`) at the beginning of the function only rejects oversized ciphertexts and can be ignored; it works correctly.
- `main` is also uninteresting.

`Challenge` is where the fun begins:

First, two bytes are read: `m0` and `m1`.
Then a ciphertext is read. We can check `key_size` using the public key: it turns out to be 1024, so the server reads 128 bytes.
The ciphertext is decrypted into `dice`.

Then the server outputs 100 bytes as follows:

Depending on the least significant bit of `dice`, either `m0` or `m1` is selected into `p`.
The result of this choice is the same in every iteration.
Then the byte `(p + random.randint(0, 2)) MOD 2` is written.

This can be simplified as:
`Challenge` reads two bytes `m0, m1` and a 128-byte ciphertext `c`.
The ciphertext is decrypted, and the least significant bit of the plaintext is used to choose `m0` or `m1` as `p`.
Then 100 bytes are written, one by one: `write((p + random.randint(0, 2)) % 2)`.
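To experiment locally, the simplified server logic can be modelled in a few lines of Python 3. This is a sketch, not the challenge code: the toy key (two small Mersenne primes, far too small for real use) and the name `simplified_challenge` are our own hypothetical choices, the ciphertext is passed as an integer instead of 128 raw bytes, and `m0`/`m1` are plain integers, so the `ord(p)` step disappears.

```python
import random

# Hypothetical toy key, far too small for real use: two Mersenne primes.
P, Q = 2**31 - 1, 2**61 - 1
N = P * Q
E = 65537
D = pow(E, -1, (P - 1) * (Q - 1))  # private exponent (Python 3.8+)


def simplified_challenge(m0, m1, ciphertext, rng=random):
    """Mimic Challenge(): pick m0 or m1 by the plaintext's LSB, emit 100 noisy bits."""
    dice = pow(ciphertext, D, N)          # textbook RSA decryption
    p = [m0, m1][dice & 1]                # same choice in every round
    return [(p + rng.randint(0, 2)) % 2   # randint(0, 2) may also return 2
            for _ in range(100)]


if __name__ == "__main__":
    c = pow(42, E, N)  # encrypt an even plaintext
    out = simplified_challenge(0, 1, c)
    print(out.count(0), out.count(1))  # zeros should clearly dominate
```

Passing `m0 = 0` and `m1 = 1` turns the output bits into noisy copies of the plaintext's LSB, which is exactly what the rest of the attack exploits.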

For simplicity we will pass `m0 = 0` and `m1 = 1` to the server, so that `p` equals the least significant bit of the plaintext.
If we could determine whether the server chose 0 or 1 as `p`, we would learn the least significant bit.
But at first it seems like we are just given 100 random bytes (each either 0 or 1).

So let’s take a look at the documentation of `random.randint(a, b)`:

> Return a random integer N such that a <= N <= b.

This means the server outputs (100 times) `p + r MOD 2` with `r in {0, 1, 2}` (including 2!).
Because `p + 0 MOD 2 = p + 2 MOD 2`, it is twice as likely to get `p MOD 2` as it is to get `p + 1 MOD 2`: each output byte equals `p` with probability 2/3.
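A quick way to see why 100 bytes are plenty: enumerate the three equally likely values of `k` to get the exact output distribution, then compute the exact probability that a majority vote over 100 bytes picks the wrong bit. This is a small Python 3 sketch of that calculation (the function names are ours); it ignores the repeated queries that `getLSBit` adds on top.

```python
from fractions import Fraction
from math import comb


def output_distribution(p):
    """Exact distribution of (p + k) % 2 for k uniform in {0, 1, 2}."""
    dist = {0: Fraction(0), 1: Fraction(0)}
    for k in (0, 1, 2):          # randint(0, 2) includes both endpoints
        dist[(p + k) % 2] += Fraction(1, 3)
    return dist


def majority_error(rounds=100):
    """Probability that the wrong bit appears at least as often as the right one."""
    q = Fraction(1, 3)  # chance that a single byte shows the wrong bit
    # j = number of wrong bytes; ties are counted as errors
    return sum(comb(rounds, j) * q**j * (1 - q)**(rounds - j)
               for j in range(rounds // 2, rounds + 1))


print(output_distribution(0))   # {0: Fraction(2, 3), 1: Fraction(1, 3)}
print(float(majority_error()))  # well below 1%
```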
Thus we can learn the least significant bit (LSB) of the plaintext as follows:

```python
#!/usr/bin/env python2
# pip2 install pwntools
import binascii
from pwn import *


def itob(x):
    """transform number->hex->bytes"""
    # '%x' avoids the trailing 'L' that hex() appends to Python 2 longs
    ret = '%x' % x
    if len(ret) % 2 != 0:
        ret = '0' + ret
    return binascii.unhexlify(ret)


def getLSBit(ciphertext):
    a0 = 0
    a1 = 0
    # just to counteract randomness we sometimes will ask more than once
    while abs(a0 - a1) < 20:
        r = remote('perfect-secrecy.ctfcompetition.com', 1337)
        r.send('\x00')  # m0
        r.send('\x01')  # m1
        snd = itob(ciphertext)
        # left-pad the ciphertext with NUL bytes to 128 bytes
        # (zfill would pad with the ASCII character '0', i.e. 0x30)
        snd = snd.rjust(128, '\x00')
        r.send(snd)
        res = r.readall(timeout=2)
        # count 0 and 1
        a0 += res.count('\x00')
        a1 += res.count('\x01')
    # if more 1s than 0s return True (1)
    return a1 > a0
```

`getLSBit(c)` gives us (with good enough probability) the least significant bit of `m = c^d MOD N`.
But a single bit is a bit (no pun intended) too short for a flag, so we need a way to learn more about `m`.

Our idea was to right-shift the plaintext by multiplying the ciphertext with `(1/2)^e` (note: `1/2` is the multiplicative inverse of `2 MOD N`).
Since `c = m^e MOD N`, we get `m^e * (1/2)^e MOD N = (m * 1/2)^e MOD N`.

This would enable us to learn a bit, shift the plaintext, and learn the next bit (and repeat this). Sadly, the shifting breaks as soon as the LSB is 1.

```python
#!/usr/bin/env python2
# Just to show that shifting breaks on LSB=1
from Crypto.PublicKey import RSA


def egcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = egcd(b % a, a)
        return (g, x - (b // a) * y, y)


def modinv(a, m):
    g, x, y = egcd(a, m)
    if g != 1:
        raise Exception('modular inverse does not exist')
    else:
        return x % m


r = RSA.generate(1024)
n = r.key.n
e = r.key.e
inv2 = modinv(2, n)
inv2 = long(inv2)

m = 16L  # even: halving works
c = r.encrypt(m, None)[0]
print('m : %s' % bin(r.decrypt(c)))
print('m/2: %s' % bin(r.decrypt(c * pow(inv2, e, n) % n)))

m = 17L  # odd (0b10001): halving breaks
c = r.encrypt(m, None)[0]
print('m : %s' % bin(r.decrypt(c)))
print('m/2: %s' % bin(r.decrypt(c * pow(inv2, e, n) % n)))
```

outputs:

```
m : 0b10000
m/2: 0b1000
m : 0b10001
m/2: 0b101010001100010001100101000001001100100110100110010001111001....
```
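The garbage in the second case is not random: `1/2 MOD N` equals `(N + 1)/2`, so for odd `m` the product `m * (1/2) MOD N` is `(m + N)/2`, a number of about `N/2`, rather than `m >> 1`. Since multiplying the ciphertext by `(1/2)^e` multiplies the plaintext by `1/2 MOD N`, this can be checked with plain integers. A small Python 3 sketch; the modulus is a hypothetical toy value (any odd `N` works):

```python
N = (2**31 - 1) * (2**61 - 1)   # hypothetical odd toy modulus
INV2 = pow(2, -1, N)            # modular inverse of 2 (Python 3.8+)
assert INV2 == (N + 1) // 2


def modular_half(m):
    """What the server decrypts after we multiply the ciphertext by (1/2)^e."""
    return m * INV2 % N


for m in (16, 17):
    print(m, modular_half(m))
# 16 -> 8, a clean right shift
# 17 -> (17 + N) // 2, not 8
```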

We tried to get this working for a while, but then moved on.

As our approach didn’t work, we googled for *RSA least significant bit*, and a Crypto Stack Exchange post showed up.

For an RSA plaintext `m` it must hold that `0 < m < N`, which implies `0 < 2m < 2N`. (We can assume `m != 0`.)

Furthermore, as `N` is odd (otherwise it would be trivial to factor; we can check this using the given public key) and `2m` is even (obviously), we can conclude:

- `2m != N` (this helps us ignore some edge cases).
- `2m MOD N` is odd iff `N < 2m < 2N` (iff `N/2 < m < N`): `N < 2m < 2N` iff `2m MOD N = 2m - N`. As `2m` is even and `N` is odd, the result is odd.
- `2m MOD N` is even iff `0 < 2m < N` (iff `0 < m < N/2`).
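The case analysis above can be brute-force checked for a small odd modulus. The Python 3 sketch below uses a hypothetical toy value `N = 101 * 103` and confirms, for every `0 < m < N`, that `2m MOD N` is odd exactly when `m > N/2`.

```python
N = 101 * 103  # hypothetical small odd modulus


def upper_half_by_parity(m, n=N):
    """Parity of 2m mod n: 1 iff n/2 < m < n."""
    return (2 * m % n) & 1


mismatches = [m for m in range(1, N)
              if upper_half_by_parity(m) != (2 * m > N)]
print(len(mismatches))  # 0
```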

After this we know whether `0 < m < N/2` or `N/2 < m < N`.

Let’s assume `0 < m < N/2`.
We know `0 < 4m < 2N`.
We now check whether `4m MOD N` is even or odd:

- `4m MOD N` is odd iff `N < 4m < 2N` (`N/4 < m < N/2`)
- `4m MOD N` is even iff `0 < 4m < N` (`0 < m < N/4`)

Let’s assume `N/2 < m < N`.
We know `2N < 4m < 4N`.
We now check whether `4m MOD N` is even or odd:

- `4m MOD N` is odd iff `3N < 4m < 4N` (`3/4 N < m < 4/4 N`)
- `4m MOD N` is even iff `2N < 4m < 3N` (`2/4 N < m < 3/4 N`)
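The same brute-force style of check extends to the second query. This Python 3 sketch (hypothetical toy modulus again) reads the parities of `2m MOD N` and `4m MOD N` as a two-bit number and confirms that it equals the index of the quarter of `(0, N)` that contains `m`.

```python
N = 101 * 103  # hypothetical small odd modulus


def quarter_from_parities(m, n=N):
    """Combine the parities of 2m mod n and 4m mod n into a quarter index 0..3."""
    b1 = (2 * m % n) & 1  # 1 iff m lies in the upper half
    b2 = (4 * m % n) & 1  # 1 iff m lies in the upper half of that half
    return 2 * b1 + b2


def quarter(m, n=N):
    """Index j such that j*n/4 < m < (j+1)*n/4."""
    return 4 * m // n


print(all(quarter_from_parities(m) == quarter(m) for m in range(1, N)))  # True
```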

With the second check we get tighter bounds around `m`.
We will call them `UB` and `LB` (upper/lower bound).
With each iteration we halve the difference `UB - LB`, thus after `key_size = 1024` iterations we know `m`.
This is similar to binary search.
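Written out as a worked restatement of the argument (our own notation: `i` counts the queries so far and `j` is an integer with `0 <= j < 2^i`), the bounds after `i` queries are:

$$
LB_i = \frac{j\,N}{2^i} \;<\; m \;<\; \frac{(j+1)\,N}{2^i} = UB_i,
\qquad UB_i - LB_i = \frac{N}{2^i}
$$

$$
N < 2^{1024} \;\Longrightarrow\; UB_{1024} - LB_{1024} = \frac{N}{2^{1024}} < 1
$$

So after 1024 queries only one integer, `m`, fits strictly between the bounds. The inequalities are strict because `m = jN/2^i` would mean that `N` divides `2^i m`, and hence (as `N` is odd) that `N` divides `m`, which is impossible for `0 < m < N`.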

Now that we know what we can learn from `2m MOD N` being even or odd, we need to implement this.
We can simply test whether a plaintext is even or odd by checking its least significant bit with `getLSBit`.
Now we only need to be able to multiply `m` by 2 and get the server to give us the LSB.
This is easily possible by querying the server with `c * 2^e MOD N = (m * 2)^e MOD N`.
After we get that information, we adjust our bounds and query `4m` (i.e. `c * 4^e MOD N`).
We continue this until our bounds have a difference of 1.
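The doubling step relies on textbook RSA being multiplicatively homomorphic. A minimal Python 3 check with a hypothetical toy key (not the challenge key): multiplying the ciphertext by `2^e MOD N` once per round makes the server decrypt `2^i * m MOD N` in round `i`.

```python
# Hypothetical toy key (two Mersenne primes), far too small for real use
P, Q = 2**31 - 1, 2**61 - 1
N, E = P * Q, 65537
D = pow(E, -1, (P - 1) * (Q - 1))  # Python 3.8+

m = 123456789
c = pow(m, E, N)
TWO_E = pow(2, E, N)

for i in range(1, 6):
    c = c * TWO_E % N                             # what we send in round i
    assert pow(c, D, N) == pow(2, i, N) * m % N   # server sees 2^i * m mod N
print("doubling works")
```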

```python
#!/usr/bin/env python2
# continuation of the script above: itob() and getLSBit() are defined there
from pwn import *
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend

if __name__ == '__main__':
    pubkey = serialization.load_pem_public_key(open('key_pub.pem', 'rb').read(), backend=default_backend())
    n = pubkey.public_numbers().n
    e = pubkey.public_numbers().e
    with open('flag.txt', 'rb') as f:
        c = int(f.read().encode('hex'), 16)
    LB = 0
    UB = n
    p2 = pow(2, e, n)  # multiplying c by 2^e doubles the plaintext
    context(log_level='error')  # silence pwntools connection logs
    while UB - 1 > LB:
        M = (UB + LB) / 2  # integer midpoint
        c = p2 * c % n
        print('')
        print('%d' % (UB - LB).bit_length())
        print('LB: %s' % repr(itob(LB)))
        print('RB: %s' % repr(itob(UB)))
        if getLSBit(c):
            LB = M  # 2^i * m MOD N odd: upper half
        else:
            UB = M  # even: lower half
    print('-' * 20)
    print('Finished')
    print('LB: %s' % repr(itob(LB)))
    print('RB: %s' % repr(itob(UB)))
```

After a while we get:

```
[...]
--------------------
Finished
LB: "\x02a\xb4\x02\x0c3\xd2.:\xe7\x9eB'\xf2\xd5\x1c7`\xe4\r\xcd\xac\xd8\x7f\x02_L\x84q\xea\x8c\xb4\x1d}\x82\x90g\x170\x00.oO CTF{h3ll0__17_5_m3_1_w45_w0nd3r1n6_1f_4f73r_4ll_7h353_y34r5_y0u_d_l1k3_70_m337} On\xae"
RB: "\x02a\xb4\x02\x0c3\xd2.:\xe7\x9eB'\xf2\xd5\x1c7`\xe4\r\xcd\xac\xd8\x7f\x02_L\x84q\xea\x8c\xb4\x1d}\x82\x90g\x170\x00.oO CTF{h3ll0__17_5_m3_1_w45_w0nd3r1n6_1f_4f73r_4ll_7h353_y34r5_y0u_d_l1k3_70_m337} On\xaf"
```

So we got our flag.
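To try the attack without the server, here is an offline sketch in Python 3. The oracle `lsb_oracle`, the toy key and the message are hypothetical stand-ins for the remote service, `key_pub.pem` and `flag.txt`. Instead of the integer midpoint used above, it keeps `LB` and `UB` as exact `Fraction`s, which avoids the rounding drift that integer midpoints can introduce in the lowest bits.

```python
import math
from fractions import Fraction

# Hypothetical toy key and message standing in for key_pub.pem / flag.txt
P, Q = 2**31 - 1, 2**61 - 1
N, E = P * Q, 65537
D = pow(E, -1, (P - 1) * (Q - 1))  # Python 3.8+
SECRET = int.from_bytes(b"CTF{toy}", "big")


def lsb_oracle(c):
    """Stand-in for getLSBit(): least significant bit of the decryption."""
    return pow(c, D, N) & 1


def lsb_attack(c, n=N, e=E, oracle=lsb_oracle):
    """Recover m from c = m^e mod n using only an LSB oracle."""
    lb, ub = Fraction(0), Fraction(n)   # invariant: lb < m < ub
    two_e = pow(2, e, n)
    while ub - lb >= 1:
        c = c * two_e % n               # doubles the plaintext
        mid = (lb + ub) / 2
        if oracle(c):                   # odd: m is in the upper half
            lb = mid
        else:                           # even: m is in the lower half
            ub = mid
    # exactly one integer lies strictly between lb and ub
    return math.ceil(ub) - 1


recovered = lsb_attack(pow(SECRET, E, N))
print(recovered.to_bytes((recovered.bit_length() + 7) // 8, "big"))  # b'CTF{toy}'
```

The loop makes exactly `n.bit_length()` oracle queries, matching the 1024 queries needed against the real 1024-bit key.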

The server acts as a least significant bit oracle, because `random.randint(a, b)` returns an integer in `[a, b]` and not `[a, b)`.
This allows us to learn the plaintext.

`CTF{h3ll0__17_5_m3_1_w45_w0nd3r1n6_1f_4f73r_4ll_7h353_y34r5_y0u_d_l1k3_70_m337}`

The complete exploit script:

```python
#!/usr/bin/env python2
# pip2 install pwntools
import binascii
from pwn import *
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend


def itob(x):
    """transform number->hex->bytes"""
    # '%x' avoids the trailing 'L' that hex() appends to Python 2 longs
    ret = '%x' % x
    if len(ret) % 2 != 0:
        ret = '0' + ret
    return binascii.unhexlify(ret)


def getLSBit(ciphertext):
    a0 = 0
    a1 = 0
    # just to counteract randomness we sometimes will ask more than once
    while abs(a0 - a1) < 20:
        r = remote('perfect-secrecy.ctfcompetition.com', 1337)
        r.send('\x00')  # m0
        r.send('\x01')  # m1
        snd = itob(ciphertext)
        # left-pad the ciphertext with NUL bytes to 128 bytes
        snd = snd.rjust(128, '\x00')
        r.send(snd)
        res = r.readall(timeout=2)
        # count 0 and 1
        a0 += res.count('\x00')
        a1 += res.count('\x01')
    # if more 1s than 0s return True (1)
    return a1 > a0


if __name__ == '__main__':
    pubkey = serialization.load_pem_public_key(open('key_pub.pem', 'rb').read(), backend=default_backend())
    n = pubkey.public_numbers().n
    e = pubkey.public_numbers().e
    with open('flag.txt', 'rb') as f:
        c = int(f.read().encode('hex'), 16)
    LB = 0
    UB = n
    p2 = pow(2, e, n)  # multiplying c by 2^e doubles the plaintext
    context(log_level='error')  # silence pwntools connection logs
    while UB - 1 > LB:
        M = (UB + LB) / 2  # integer midpoint
        c = p2 * c % n
        print('')
        print('%d' % (UB - LB).bit_length())
        print('LB: %s' % repr(itob(LB)))
        print('RB: %s' % repr(itob(UB)))
        if getLSBit(c):
            LB = M  # 2^i * m MOD N odd: upper half
        else:
            UB = M  # even: lower half
    print('-' * 20)
    print('Finished')
    print('LB: %s' % repr(itob(LB)))
    print('RB: %s' % repr(itob(UB)))
```