from sage.all import *
import argparse

# Håstad's broadcast attack (identical-message case).
# The same message m is encrypted with the same small exponent e under k >= e
# pairwise-coprime moduli n_1..n_k. Since m < n_i for every i, m^e < n_1*...*n_k,
# so CRT recovers m^e exactly over the integers; taking the integer e-th root gives m.
# https://en.wikipedia.org/wiki/Coppersmith%27s_attack#H%C3%A5stad's_broadcast_attack


def hastad_broadcast(e, cts, ns):
    """
    Recover the plaintext shared by several textbook-RSA ciphertexts.

    Every c_i = m^e mod n_i uses the same public exponent e. Combining the
    ciphertexts with the CRT yields m^e modulo the product of the n_i; when
    that product exceeds m^e, the residue is m^e itself, so an exact integer
    e-th root gives back m.

    Args:
        e (int): The public exponent shared by all ciphertexts
        cts (list[int]): Ciphertexts c_i = m^e mod n_i
        ns (list[int]): Moduli n_i, in the same order as cts (expected to be
            pairwise coprime)

    Returns:
        int: The recovered message m, or None when the CRT value is not a
        perfect e-th power
    """
    count = len(cts)
    if count < e:
        # With fewer than e moduli the product may be smaller than m^e;
        # warn, but still attempt the recovery.
        print("Warning: need at least e={} ciphertexts, got {}".format(e, count))

    residues = list(map(Integer, cts))
    moduli = list(map(Integer, ns))
    power = Integer(crt(residues, moduli))  # equals m^e over Z when m^e < prod(n_i)

    root, is_exact = power.nth_root(e, truncate_mode=True)
    if is_exact:
        return int(root)

    print("No exact e-th root, moduli may not be coprime or message too long")
    return None


def test():
    """
    Self-check: encrypt a known flag under e fresh moduli and recover it.

    Raises:
        AssertionError: if the attack does not return the original message.
    """
    e = 3
    plaintext = b"flag{hastad_broadcast}"
    # bytes_to_long/long_to_bytes are pycryptodome helpers that this file never
    # imports, so use the built-in int <-> bytes conversions instead.
    m = Integer(int.from_bytes(plaintext, "big"))

    # lbound keeps each prime >= 2^511, so every modulus exceeds m and the
    # product of e moduli is guaranteed to exceed m^e.
    ns = [
        random_prime(2 ** 512, lbound=2 ** 511) * random_prime(2 ** 512, lbound=2 ** 511)
        for _ in range(e)
    ]
    cts = [pow(m, e, n) for n in ns]

    # Run the attack once and reuse the result for both the check and the output.
    recovered = hastad_broadcast(e, cts, ns)
    assert recovered == m
    print("Success:", recovered.to_bytes((recovered.bit_length() + 7) // 8, "big"))


if __name__ == "__main__":
    # Command-line entry point: e, then the ciphertexts (-c) and their matching moduli (-n).
    parser = argparse.ArgumentParser(description="Håstad's broadcast attack on RSA")
    parser.add_argument("e", type=int, help="Shared public exponent")
    parser.add_argument("-c", type=int, nargs="+", help="Ciphertexts", required=True)
    parser.add_argument("-n", type=int, nargs="+", help="Moduli", required=True)
    args = parser.parse_args()

    # Reject malformed input with a usage message instead of a traceback from
    # deep inside the CRT / nth_root computation.
    if args.e < 1:
        parser.error("e must be a positive integer, got {}".format(args.e))
    if len(args.c) != len(args.n):
        parser.error(
            "got {} ciphertexts but {} moduli; each ciphertext needs its modulus".format(
                len(args.c), len(args.n)
            )
        )

    print(hastad_broadcast(args.e, args.c, args.n))
