Midnight Sun CTF 2019 Quals Writeup

この大会は2019/4/6 1:00(JST)~2019/4/7 1:00(JST)に開催されました。
今回もチームで参戦。結果は608点で432チーム中70位でした。
自分で解けた問題をWriteupとして書いておきます。

Ezdsa (crypto)

$ nc ezdsa-01.play.midnightsunctf.se 31337
Welcome to Spooners' EZDSA
Options:
1. Sign protocol
2. Quit
1
Enter data:
YQo=
(197331920949914888652115219109088718770931603447L, 512962212774096285728517071662679127614634451886L)
Options:
1. Sign protocol
2. Quit

スクリプトの内容は以下のような感じ。

m: 入力
h: mのsha1の数値
u: ランダム20*8ビット値

k = pow(g, u*m, q)
r = pow(g, k, p) % q
s = pow(k, q - 2, q) * (h + flag * r) % q

以下の定理が成り立つ。

pow(k, q - 1, q) = 1 mod q

上記の定理から以下が成り立つ。

pow(k, q - 2, q) = inv(k, q)

このことからこの問題は一般的なDSAの問題。この問題の一番多いパターンは同じrで異なるsのペアを見つけると脆弱性が発生するということ。何とかしてこのデータを入手したい。
0やq-1の倍数は指定できないので、少し調整して、(q-1)/2とq*(q-1)/2の場合で試すと、rが同じになった。あとはkを算出した後、key(flag)を算出すればよい。

import socket
from base64 import b64encode
from hashlib import sha1
from Crypto.Util.number import *

def recvuntil(s, tail):
    data = ''
    while True:
        if tail in data:
            return data
        data += s.recv(1)

q = 0x926c99d24bd4d5b47adb75bd9933de8be5932f4bL

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect(('ezdsa-01.play.midnightsunctf.se', 31337))

#### 1st try ####
data = recvuntil(s, 'Quit\n')
print data + '1'
s.sendall('1\n')
send_str1 = long_to_bytes((q-1)/2)
send_data1 = b64encode(send_str1)
data = recvuntil(s, 'data:\n')
print data + send_data1
s.sendall(send_data1 + '\n')
data = recvuntil(s, '\n').strip()
print data
r1 = int(data[1:-1].split(', ')[0].rstrip('L'))
s1 = int(data[1:-1].split(', ')[1].rstrip('L'))

#### 2nd try ####
data = recvuntil(s, 'Quit\n')
print data + '1'
s.sendall('1\n')
send_str2 = long_to_bytes(q*(q-1)/2)
send_data2 = b64encode(send_str2)
data = recvuntil(s, 'data:\n')
print data + send_data2
s.sendall(send_data2 + '\n')
data = recvuntil(s, '\n').strip()
print data
r2 = int(data[1:-1].split(', ')[0].rstrip('L'))
s2 = int(data[1:-1].split(', ')[1].rstrip('L'))

data = recvuntil(s, 'Quit\n')
print data + '2'
s.sendall('2\n')
data = recvuntil(s, '\n').strip()
print data

assert r1 == r2
assert s1 != s2

#### calculate k ####
h1 = bytes_to_long(sha1(send_str1).digest())
h2 = bytes_to_long(sha1(send_str2).digest())
k = int(((h1 - h2) % q) * inverse(((s1 - s2) % q), q))

#### calculate secret (flag) ####
key = int((((((s1 * k) % q) - h1) % q) * inverse(r1, q)) % q)
flag = long_to_bytes(key)
print flag
th4t_w4s_e4sy_eh?