이번에는 시험기간 중에 진행된 CodeGate에 참여했는데, 생각보다 문제가 많이 어려웠다.. 작년 코드게이트 풀이를 보면 그렇게 어렵다고 느끼진 못했는데,, 역시 아직 갈 길이 멀은 것 같다. 그래도 이상한 브포로 한문제를 풀긴 했는데, 제발 본선에 들어갔으면 좋겠다 ㅎㅎ (대회가 끝나지 않은 시점에서 쓰는 글 >< )
이 문제 말고도 my file encryptor 이라는 문제와 anti kerckhoffs 라는 문제가 있었는데, 이 두 문제는 도저히 풀지 못할 것 같아서,, 미리 롸업이라도 쓰려고 쓰는 글이다. ^^
🖲️ Code Analysis
from Crypto.Util.number import *
from hashlib import sha256
import os
import signal
BITS = 512
def POW():
b = os.urandom(32)
print(f"b = ??????{b.hex()[6:]}")
print(f"SHA256(b) = {sha256(b).hexdigest()}")
prefix = input("prefix > ")
b_ = bytes.fromhex(prefix + b.hex()[6:])
return sha256(b_).digest() == sha256(b).digest()
def generate_server_key():
while True:
p = getPrime(1024)
q = getPrime(1024)
e = 0x10001
if (p-1) % e == 0 or (q-1) % e == 0:
continue
d = pow(e, -1, (p-1)*(q-1))
n = p*q
return e, d, n
# Generate N = (p1+p2) * (q1+q2) where p2 and q2 are shares chosen by the client
def generate_shared_modulus():
SMALL_PRIMES = [2, 3, 5, 7, 11, 13]
print(f"{SERVER_N = }")
print(f"{SERVER_E = }")
p1_remainder_candidates = {}
q1_remainder_candidates = {}
# We prevent p1+p2 is divided by small primes
# by asking the client for a possible remainders of p1
for prime in SMALL_PRIMES:
remainder_candidates = set(map(int, input(f"Candidates of p1 % {prime} > ").split()))
assert len(remainder_candidates) == (prime+1) // 2, f"[-] wrong candidates for {prime}"
p1_remainder_candidates[prime] = remainder_candidates
while True:
p1 = bytes_to_long(os.urandom(BITS // 8))
for prime in SMALL_PRIMES:
if p1 % prime not in p1_remainder_candidates[prime]:
break
else:
break
# and same goes for q1
for prime in SMALL_PRIMES:
remainder_candidates = set(map(int, input(f"Candidates of q1 % {prime} > ").split()))
assert len(remainder_candidates) == (prime+1) // 2, f"[-] wrong candidates for {prime}"
q1_remainder_candidates[prime] = remainder_candidates
while True:
q1 = bytes_to_long(os.urandom(BITS // 8))
for prime in SMALL_PRIMES:
if q1 % prime not in q1_remainder_candidates[prime]:
break
else:
break
p1_enc = pow(p1, SERVER_E, SERVER_N)
q1_enc = pow(q1, SERVER_E, SERVER_N)
print(f"{p1_enc = }")
print(f"{q1_enc = }")
X = list(map(int, input("X > ").split()))
assert len(X) == 12
N = (p1*q1 + sum(pow(x, SERVER_D, SERVER_N) for x in X)) % SERVER_N
assert N.bit_length() >= 1024, f"[-] too short.., {N.bit_length()}"
print(f"{N = }")
return p1, q1, N
# check whether N is a product of two primes
def N_validity_check(p1, q1, N):
for _ in range(20):
b = bytes_to_long(os.urandom(2 * BITS // 8))
print(f"{b = }")
# print(pow(b, N+1-p1-q1, N))
client_digest = input("Client digest > ")
server_digest = sha256(long_to_bytes(pow(b, N+1-p1-q1, N))).hexdigest()
if server_digest != client_digest:
print("N is not a product of two primes I guess..")
return False
else:
print("good!")
return True
if not POW():
exit(-1)
signal.alarm(60)
SERVER_E, SERVER_D, SERVER_N = generate_server_key()
p1, q1, N = generate_shared_modulus()
if not N_validity_check(p1, q1, N):
exit(-1)
FLAG = open("flag", 'rb').read()
FLAG += b'\x00' + os.urandom(128 - 2 - len(FLAG))
FLAG_ENC = pow(bytes_to_long(FLAG), 0x10001, N)
print(f"{FLAG_ENC = }")
역시 코드게이트인 만큼 만만치 않은 난이도였다.. 그래도 비교적 쉬운 코드이니 이해해보자.
1. small prime [2, 3, 5, 7, 11, 13] 들에 대해서 후보 나머지를 받고, 이를 만족하는 512 비트 p, q 생성한다.
2. 이를 통해 생성된 p, q 를 이용해서 p^e (mod n), q^e (mod n) 의 값을 계산해서 알려준다.
3. 추가로 x 를 12 개 입력받아 입력한 x 들을 이용해서 n = p * q + sum(x^d) for x in X 를 계산한다.
4. 그리고 n 과 SERVER_E, SERVER_N 의 값을 알려준다.
5. 그 후에 총 20번 반복하는데, b 를 랜덤 생성 후 알려주고 b^(n-p-q+1) (mod n) 의 값을 입력받는다.
6. 계산한 값과 우리가 입력한 값이 일치하면 넘어간다.
5. 20번의 모든 for 문을 통과하면 flag^e (mod n) 알려준다.
따라서 관건은 어떻게 for 문을 통과할지, 또 마지막에 flag^e 를 이용해서 어떻게 flag 의 값을 역연산할 수 있는지가 되겠다.
💡 Main Idea
사실 문제에는 client_example 이라는 python3 파일이 하나 더 존재했다. 보니까 문제를 풀 수 있는 기본적인 틀은 제공해주는 것 같다 ㅎㅎㅎㅎ 감사합니다 :)
생각했던 것도 풀이에 나와있긴 했지만, 첫 번째 생각해야 할 점은 $ n = p \times q + \sum x^d $ 부분이다. 이 x^d 들을 단순히 더한다고 하면 각 항들이 합쳐지지 않고 흩어져 곱하기 연산으로 나타내는 것을 좋아하는 rsa 관점에서 보기에는 살짝 어려울 것 같다.
따라서 x^d 들을 모두 더하더라도 SERVER_N 에 의한 나머지를 구할 때 모두 나누어 떨어지거나, 우리가 원하는 값들로만 나오도록 설정해주는 것이 중요하다고 생각했다. 우리는 모르지만, SERVER_D 에 의해 복호화 되는 것이기 때문에 SERVER_E 값으로 암호화한 값을 넣어주면, 평문이 나올 것 같다는 생각을 했다.
$$ plaintext^e = ciphertext \text{ (mod n)} \rightarrow ciphertext^d = plaintext \text{ (mod n)}$$
그 후에는, $b^{N - p_1 - q_1 + 1}$ 이라는 값을 예측해야 했는데, 앞서 만들어낸 n 의 $\phi(n)$ 의 값을 이용해서 문제를 해결해야 했다. 여기서 phi(n) 값을 하나로 정해버리기에는 무리가 있을 것 같아서 정확한 $N - p_1 - q_1 + 1$ 의 값을 구하고자 했다.
그렇게 되면 $p_1 + q_1$ 의 값을 구하는 문제로 뒤바뀌어 버린다. 정해져 있는 $p_1 \times q_1$ 의 값을 포함한 채로 $p_1 + q_1$ 의 값을 구하려면, 제곱을 사용하는 방법밖에 없었다.
따라서 $(p_1 + q_1)^2$ 를 사용하면, 문제를 해결해 $p_1 + q_1$ 을 역연산할 수 있을 것 같다. 여기서 중요한 것은, $p_1, q_1$ 값은 모르더라도, $p_1^e, q_1^e$ 의 값은 주어졌기 때문에 n 을 조작하기에는 무리가 없다는 것이다.
마지막으로 $flag^e$ 를 역연산하는 과정에서는, N 이 소수가 되어버리면 쉽게 복구가 가능했다. 따라서 N 이 소수가 되기 위해서 홀수로 만들어주고, 소수가 되기를 기도하며 브루트 포싱을 진행하면 되지 않을까..??
📖 Exploit Code
from Crypto.Util.number import *
from hashlib import sha256
from pwn import *
from itertools import product
import random
from gmpy2 import iroot
# from sage.all import *
context.log_level = "debug"
def get_additive_shares(x, n, mod):
shares = [0] * n
shares[n-1] = x
for i in range(n-1):
shares[i] = random.randrange(mod)
shares[n-1] = (shares[n-1] - shares[i]) % mod
assert sum(shares) % mod == x
return shares
BITS = 512
def POW():
print("[DEBUG] POW...")
b_postfix = r.recvline().decode().split(' = ')[1][6:].strip()
h = r.recvline().decode().split(' = ')[1].strip()
for brute in product('0123456789abcdef', repeat=6):
b_prefix = ''.join(brute)
b_ = b_prefix + b_postfix
if sha256(bytes.fromhex(b_)).hexdigest() == h:
r.sendlineafter(b' > ', b_prefix.encode())
return True
assert 0, "Something went wrong.."
def generate_shared_modulus():
SMALL_PRIMES = [2, 3, 5, 7, 11, 13]
print("[DEBUG] generate_shared_modulus...")
p2 = random.randrange(2 ** BITS, 2 ** (BITS+1))
q2 = random.randrange(2 ** BITS, 2 ** (BITS+1))
base = eval('*'.join([str(n) for n in SMALL_PRIMES]))
p2, q2 = (p2//base)*base, (q2//base)*base
# Candidates of p1
for prime in SMALL_PRIMES:
remainder_candidates = []
# c = (-p2 % prime) should not be chosen
while len(remainder_candidates) < (prime+1) // 2:
c = random.randrange(prime)
if c == -p2 % prime or c in remainder_candidates:
continue
remainder_candidates.append(c)
r.sendlineafter(b' > ', ' '.join(str(c) for c in remainder_candidates).encode())
# Candidates of q1
for prime in SMALL_PRIMES:
remainder_candidates = []
# c = (-q2 % prime) should not be chosen
while len(remainder_candidates) < (prime+1) // 2:
c = random.randrange(prime)
if c == -q2 % prime or c in remainder_candidates:
continue
remainder_candidates.append(c)
r.sendlineafter(b' > ', ' '.join(str(c) for c in remainder_candidates).encode())
p1_enc = int(r.recvline().decode().split(' = ')[1])
q1_enc = int(r.recvline().decode().split(' = ')[1])
p2_enc = pow(p2, SERVER_E, SERVER_N)
q2_enc = pow(q2, SERVER_E, SERVER_N)
X = []
shares_a = get_additive_shares(4, 4, SERVER_N)
shares_b = get_additive_shares(7, 4, SERVER_N)
shares_c = get_additive_shares(4, 3, SERVER_N)
shares_d = get_additive_shares(1, 1, SERVER_N)
# N = p1*q1 + sum(pow(x, SERVER_D, SERVER_N) for x in X) = p1*q1 + p1*q2 + p2*q1 * p2*q2 = (p1+p2)*(q1+q2)
for i in range(4):
X.append(pow(shares_a[i], SERVER_E, SERVER_N) * p1_enc * p1_enc % SERVER_N)
X.append(pow(shares_b[i], SERVER_E, SERVER_N) * p1_enc * q1_enc % SERVER_N)
for i in range(3):
X.append(pow(shares_c[i], SERVER_E, SERVER_N) * q1_enc * q1_enc % SERVER_N)
X.append(1)
random.shuffle(X)
r.sendlineafter(b' > ', ' '.join(str(x) for x in X).encode())
N = int(r.recvline().decode().split(' = ')[1])
return p2, q2, N
# STEP 2 - N_validity_check
def N_validity_check_client(p2, q2, N):
print("[DEBUG] N_validity_check_client...")
for _ in range(20):
b = int(r.recvline().decode().split(' = ')[1])
client_digest = sha256(long_to_bytes(pow(b, N - p_q + 1, N))).hexdigest()
r.sendlineafter(b' > ', client_digest.encode())
msg = r.recvline().decode()
if msg != "good!\n":
print(msg)
return -1
flag_enc = int(r.recvline().decode().split(' = ')[1])
return flag_enc
while True:
try:
REMOTE = sys.argv[1]
except:
REMOTE = None
if REMOTE:
r = remote("13.125.181.74", 9001)
else:
r = process(["python3", "./prob.py"])
POW()
SERVER_N = int(r.recvline().decode().split(' = ')[1])
SERVER_E = int(r.recvline().decode().split(' = ')[1])
# print(SERVER_N, SERVER_E)
p2, q2, N = generate_shared_modulus()
assert iroot(N-1, 2)[1] == True
p_q = int(iroot(N-1, 2)[0])//2
# pause()
flag_enc = N_validity_check_client(p2, q2, N)
# we know pow(p1, e, SERVER_N)
print(N)
if not isPrime(N):
print("No Prime!!")
r.close()
continue
if flag_enc == -1:
exit(-1)
print(f"{N = }")
print(f"{flag_enc = }")
d = inverse(0x10001, N-1)
flag = pow(flag_enc, d, N)
flag = long_to_bytes(flag)
print(f"{flag = }")
break
브포하는데에 시간이 조금 걸리긴 했는데,, 그래도 문제를 해결할 수 있었다. 추가로 n 이 1024 비트 이상이 되도록 하는 코드가 존재해서 마지막으로 $n = 2^2 \times (p_1 + q_1)^2 + 1$ 이 되도록 설정해주어 n 이 소수가 되고, $p_1 + q_1$ 을 역연산할 수 있었다.
(사실 브포가 intened solution 이라고 생각하지는 않는다.. 아마 풀이 올라오면 더 기깔난 풀이가 있지 않을까 생각한다. )
'Cryptography > CTF' 카테고리의 다른 글
[Zer0pts 2023] - easy factoring (0) | 2023.07.21 |
---|---|
[Zer0pts 2023] - squarerng (0) | 2023.07.21 |
[CCE 2023] - the miracle (0) | 2023.06.12 |
[Kaist-Postech 2020] - fixed point revenge (0) | 2023.04.19 |
[Kaist-Postech 2020] - Baby Bubmi (0) | 2023.04.18 |