comibear
article thumbnail

이번에는 시험기간 중에 진행된 CodeGate에 참여했는데, 생각보다 문제가 많이 어려웠다.. 작년 코드게이트 풀이를 보면 그렇게 어렵다고 느끼진 못했는데,, 역시 아직 갈 길이 멀은 것 같다. 그래도 이상한 브포로 한문제를 풀긴 했는데, 제발 본선에 들어갔으면 좋겠다 ㅎㅎ (대회가 끝나지 않은 시점에서 쓰는 글 >< )

 

이 문제 말고도 my file encryptor 이라는 문제와 anti kerckhoffs 라는 문제가 있었는데, 이 두 문제는 도저히 풀지 못할 것 같아서,, 미리 롸업이라도 쓰려고 쓰는 글이다. ^^

🖲️ Code Analysis

from Crypto.Util.number import *
from hashlib import sha256
import os
import signal

BITS = 512

def POW():
    b = os.urandom(32)
    print(f"b = ??????{b.hex()[6:]}")
    print(f"SHA256(b) = {sha256(b).hexdigest()}")
    prefix = input("prefix > ")
    b_ = bytes.fromhex(prefix + b.hex()[6:])
    return sha256(b_).digest() == sha256(b).digest()

def generate_server_key():
    while True:
        p = getPrime(1024)
        q = getPrime(1024)
        e = 0x10001
        if (p-1) % e == 0 or (q-1) % e == 0:
            continue
        d = pow(e, -1, (p-1)*(q-1))
        n = p*q
        return e, d, n

# Generate N = (p1+p2) * (q1+q2) where p2 and q2 are shares chosen by the client
def generate_shared_modulus():
    SMALL_PRIMES = [2, 3, 5, 7, 11, 13]
    print(f"{SERVER_N = }")
    print(f"{SERVER_E = }")

    p1_remainder_candidates = {}
    q1_remainder_candidates = {}
    # We prevent p1+p2 is divided by small primes
    # by asking the client for a possible remainders of p1
    for prime in SMALL_PRIMES:
        remainder_candidates = set(map(int, input(f"Candidates of p1 % {prime} > ").split()))
        assert len(remainder_candidates) == (prime+1) // 2, f"[-] wrong candidates for {prime}"
        p1_remainder_candidates[prime] = remainder_candidates

    while True:
        p1 = bytes_to_long(os.urandom(BITS // 8))
        for prime in SMALL_PRIMES:
            if p1 % prime not in p1_remainder_candidates[prime]:
                break
        else:
            break

    # and same goes for q1
    for prime in SMALL_PRIMES:
        remainder_candidates = set(map(int, input(f"Candidates of q1 % {prime} > ").split()))
        assert len(remainder_candidates) == (prime+1) // 2, f"[-] wrong candidates for {prime}"
        q1_remainder_candidates[prime] = remainder_candidates

    while True:
        q1 = bytes_to_long(os.urandom(BITS // 8))

        for prime in SMALL_PRIMES:
            if q1 % prime not in q1_remainder_candidates[prime]:
                break
        else:
            break

    p1_enc = pow(p1, SERVER_E, SERVER_N)
    q1_enc = pow(q1, SERVER_E, SERVER_N)

    print(f"{p1_enc = }")
    print(f"{q1_enc = }")
    X = list(map(int, input("X > ").split()))
    assert len(X) == 12

    N = (p1*q1 + sum(pow(x, SERVER_D, SERVER_N) for x in X)) % SERVER_N
    assert N.bit_length() >= 1024, f"[-] too short.., {N.bit_length()}"

    print(f"{N = }")

    return p1, q1, N

# check whether N is a product of two primes
def N_validity_check(p1, q1, N):
    for _ in range(20):
        b = bytes_to_long(os.urandom(2 * BITS // 8))
        print(f"{b = }")
        # print(pow(b, N+1-p1-q1, N))
        client_digest = input("Client digest > ")
        server_digest = sha256(long_to_bytes(pow(b, N+1-p1-q1, N))).hexdigest()
        if server_digest != client_digest:
            print("N is not a product of two primes I guess..")
            return False
        else:
            print("good!")

    return True

if not POW():
    exit(-1)

signal.alarm(60)
SERVER_E, SERVER_D, SERVER_N = generate_server_key()
p1, q1, N = generate_shared_modulus()
if not N_validity_check(p1, q1, N):
    exit(-1)

FLAG = open("flag", 'rb').read()
FLAG += b'\x00' + os.urandom(128 - 2 - len(FLAG))
FLAG_ENC = pow(bytes_to_long(FLAG), 0x10001, N)

print(f"{FLAG_ENC = }")

역시 코드게이트인 만큼 만만치 않은 난이도였다.. 그래도 비교적 쉬운 코드이니 이해해보자. 


1. small prime [2, 3, 5, 7, 11, 13] 들에 대해서 후보 나머지를 받고, 이를 만족하는 512 비트 p, q 생성한다.
2. 이를 통해 생성된 p, q 를 이용해서 p^e (mod n), q^e (mod n) 의 값을 계산해서 알려준다. 
3. 추가로 x 를 12 개 입력받아 입력한 x 들을 이용해서 n = p * q + sum(x^d) for x in X 를 계산한다. 
4. 그리고 n 과 SERVER_E, SERVER_N 의 값을 알려준다. 
5. 그 후에 총 20번 반복하는데,  b 를 랜덤 생성 후 알려주고 b^(n-p-q+1) (mod n) 의 값을 입력받는다. 
6. 계산한 값과 우리가 입력한 값이 일치하면 넘어간다. 
5. 20번의 모든 for 문을 통과하면 flag^e (mod n) 알려준다. 

따라서 관건은 어떻게 for 문을 통과할지, 또 마지막에 flag^e 를 이용해서 어떻게 flag 의 값을 역연산할 수 있는지가 되겠다.

💡 Main Idea

사실 문제에는 client_example 이라는 python3 파일이 하나 더 존재했다. 보니까 문제를 풀 수 있는 기본적인 틀은 제공해주는 것 같다 ㅎㅎㅎㅎ 감사합니다 :) 

 

생각했던 것도 풀이에 나와있긴 했지만, 첫 번째 생각해야 할 점은 $ n = p \times q + \sum x^d $ 부분이다. 이 x^d 들을 단순히 더한다고 하면 각 항들이 합쳐지지 않고 흩어져 곱하기 연산으로 나타내는 것을 좋아하는 rsa 관점에서 보기에는 살짝 어려울 것 같다. 

 

따라서 x^d 들을 모두 더하더라도 SERVER_N 에 의한 나머지를 구할 때 모두 나누어 떨어지거나, 우리가 원하는 값들로만 나오도록 설정해주는 것이 중요하다고 생각했다. 우리는 모르지만, SERVER_D 에 의해 복호화 되는 것이기 때문에 SERVER_E 값으로 암호화한 값을 넣어주면, 평문이 나올 것 같다는 생각을 했다. 

 

$$ plaintext^e = ciphertext \text{ (mod n)} \rightarrow ciphertext^d  = plaintext \text{ (mod n)}$$

 

그 후에는, $b^{N - p_1 - q_1 + 1}$ 이라는 값을 예측해야 했는데, 앞서 만들어낸 n 의 $\phi(n)$ 의 값을 이용해서 문제를 해결해야 했다. 여기서 phi(n) 값을 하나로 정해버리기에는 무리가 있을 것 같아서 정확한 $N - p_1 - q_1 + 1$ 의 값을 구하고자 했다. 

 

그렇게 되면 $p_1 + q_1$ 의 값을 구하는 문제로 뒤바뀌어 버린다. 정해져 있는 $p_1 \times q_1$ 의 값을 포함한 채로 $p_1 + q_1$ 의 값을 구하려면, 제곱을 사용하는 방법밖에 없었다. 

 

따라서 $(p_1 + q_1)^2$ 를 사용하면, 문제를 해결해 $p_1 + q_1$ 을 역연산할 수 있을 것 같다. 여기서 중요한 것은, $p_1, q_1$ 값은 모르더라도, $p_1^e, q_1^e$ 의 값은 주어졌기 때문에 n 을 조작하기에는 무리가 없다는 것이다. 

 

마지막으로 $flag^e$ 를 역연산하는 과정에서는, N 이 소수가 되어버리면 쉽게 복구가 가능했다. 따라서 N 이 소수가 되기 위해서 홀수로 만들어주고, 소수가 되기를 기도하며 브루트 포싱을 진행하면 되지 않을까..??

 

📖 Exploit Code

from Crypto.Util.number import *
from hashlib import sha256
from pwn import *
from itertools import product
import random
from gmpy2 import iroot
# from sage.all import *

context.log_level = "debug"

def get_additive_shares(x, n, mod):
    shares = [0] * n
    shares[n-1] = x
    for i in range(n-1):
        shares[i] = random.randrange(mod)
        shares[n-1] = (shares[n-1] - shares[i]) % mod
    assert sum(shares) % mod == x
    return shares

BITS = 512

def POW():
    print("[DEBUG] POW...")
    b_postfix = r.recvline().decode().split(' = ')[1][6:].strip()
    h = r.recvline().decode().split(' = ')[1].strip()
    for brute in product('0123456789abcdef', repeat=6):
        b_prefix = ''.join(brute)
        b_ = b_prefix + b_postfix
        if sha256(bytes.fromhex(b_)).hexdigest() == h:
            r.sendlineafter(b' > ', b_prefix.encode())
            return True

    assert 0, "Something went wrong.."

def generate_shared_modulus():
    SMALL_PRIMES = [2, 3, 5, 7, 11, 13]
    print("[DEBUG] generate_shared_modulus...")
    p2 = random.randrange(2 ** BITS, 2 ** (BITS+1))
    q2 = random.randrange(2 ** BITS, 2 ** (BITS+1))

    base = eval('*'.join([str(n) for n in SMALL_PRIMES]))
    p2, q2 = (p2//base)*base, (q2//base)*base

    # Candidates of p1
    for prime in SMALL_PRIMES:
        remainder_candidates = []
        # c = (-p2 % prime) should not be chosen
        while len(remainder_candidates) < (prime+1) // 2:
            c = random.randrange(prime)
            if c == -p2 % prime or c in remainder_candidates:
                continue
            remainder_candidates.append(c)

        r.sendlineafter(b' > ', ' '.join(str(c) for c in remainder_candidates).encode())

    # Candidates of q1
    for prime in SMALL_PRIMES:
        remainder_candidates = []
        # c = (-q2 % prime) should not be chosen
        while len(remainder_candidates) < (prime+1) // 2:
            c = random.randrange(prime)
            if c == -q2 % prime or c in remainder_candidates:
                continue
            remainder_candidates.append(c)

        r.sendlineafter(b' > ', ' '.join(str(c) for c in remainder_candidates).encode())

    p1_enc = int(r.recvline().decode().split(' = ')[1])
    q1_enc = int(r.recvline().decode().split(' = ')[1])
    p2_enc = pow(p2, SERVER_E, SERVER_N)
    q2_enc = pow(q2, SERVER_E, SERVER_N)

    X = []
    shares_a = get_additive_shares(4, 4, SERVER_N)
    shares_b = get_additive_shares(7, 4, SERVER_N)
    shares_c = get_additive_shares(4, 3, SERVER_N)
    shares_d = get_additive_shares(1, 1, SERVER_N)

    # N = p1*q1 + sum(pow(x, SERVER_D, SERVER_N) for x in X) = p1*q1 + p1*q2 + p2*q1 * p2*q2 = (p1+p2)*(q1+q2)
    for i in range(4):
        X.append(pow(shares_a[i], SERVER_E, SERVER_N) * p1_enc * p1_enc % SERVER_N)
        X.append(pow(shares_b[i], SERVER_E, SERVER_N) * p1_enc * q1_enc % SERVER_N)
    for i in range(3):
        X.append(pow(shares_c[i], SERVER_E, SERVER_N) * q1_enc * q1_enc % SERVER_N)

    X.append(1)
    random.shuffle(X)

    r.sendlineafter(b' > ', ' '.join(str(x) for x in X).encode())

    N = int(r.recvline().decode().split(' = ')[1])

    return p2, q2, N

# STEP 2 - N_validity_check
def N_validity_check_client(p2, q2, N):
    print("[DEBUG] N_validity_check_client...")
    for _ in range(20):
        b = int(r.recvline().decode().split(' = ')[1])
        client_digest = sha256(long_to_bytes(pow(b, N - p_q + 1, N))).hexdigest()
        r.sendlineafter(b' > ', client_digest.encode())
        msg = r.recvline().decode()
        if msg != "good!\n":
            print(msg)
            return -1

    flag_enc = int(r.recvline().decode().split(' = ')[1])
    return flag_enc

while True:
    try:
        REMOTE = sys.argv[1]
    except:
        REMOTE = None

    if REMOTE:
        r = remote("13.125.181.74", 9001)
    else:
        r = process(["python3", "./prob.py"])

    POW()
    SERVER_N = int(r.recvline().decode().split(' = ')[1])
    SERVER_E = int(r.recvline().decode().split(' = ')[1])
# print(SERVER_N, SERVER_E)

    p2, q2, N = generate_shared_modulus()
    assert iroot(N-1, 2)[1] == True
    p_q = int(iroot(N-1, 2)[0])//2

# pause()
    flag_enc = N_validity_check_client(p2, q2, N)
# we know pow(p1, e, SERVER_N)
    print(N)
    if not isPrime(N):
        print("No Prime!!")
        r.close()
        continue

    if flag_enc == -1:
        exit(-1)

    print(f"{N = }")
    print(f"{flag_enc = }")

    d = inverse(0x10001, N-1)
    flag = pow(flag_enc, d, N)
    flag = long_to_bytes(flag)

    print(f"{flag = }")
    break

브포하는데에 시간이 조금 걸리긴 했는데,, 그래도 문제를 해결할 수 있었다. 추가로 n 이 1024 비트 이상이 되도록 하는 코드가 존재해서 마지막으로 $n = 2^2 \times (p_1 + q_1)^2 + 1$ 이 되도록 설정해주어 n 이 소수가 되고, $p_1 + q_1$ 을 역연산할 수 있었다. 

 

(사실 브포가 intened solution 이라고 생각하지는 않는다.. 아마 풀이 올라오면 더 기깔난 풀이가 있지 않을까 생각한다. )

 

'Cryptography > CTF' 카테고리의 다른 글

[Zer0pts 2023] - easy factoring  (0) 2023.07.21
[Zer0pts 2023] - squarerng  (0) 2023.07.21
[CCE 2023] - the miracle  (0) 2023.06.12
[Kaist-Postech 2020] - fixed point revenge  (0) 2023.04.19
[Kaist-Postech 2020] - Baby Bubmi  (0) 2023.04.18
profile

comibear

@comibear

포스팅이 좋았다면 "좋아요❤️" 또는 "구독👍🏻" 해주세요!

검색 태그