In [57]:
import operator


def xorbytes(a:bytes, b:bytes) -> bytes:
    """XOR two byte strings together, return the result as a byte string
    If the strings are not the same length, the result will be the length of the shorter string
    """
    l = min(len(a), len(b))
    return bytes([a[i] ^ b[i] for i in range(l)])


class SimpleXORCipher():
    def __init__(self, key) -> None:
        self.key = key

    def encrypt(self, byte_plain_text):
        result = b""
        for i in range(len(byte_plain_text)):
            result += bytes([byte_plain_text[i] ^ self.key[i % len(self.key)]])
        return result

class SimpleXORCracker():

    def __init__(self, allowed_chars:bytes=None, best_chars:bytes=None, max_key_length:int=256, debug:bool=False) -> None:
        if allowed_chars is None:
            allowed_chars = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 ,.?!:;'-\""
        if best_chars is None:
            best_chars = b"abcdefghijklmnopqrstuvwxyz "
        if isinstance(allowed_chars, str):
            allowed_chars = allowed_chars.encode("utf-8")
        self.best_chars = best_chars
        self.allowed_chars = allowed_chars
        self.max_key_length = max_key_length
        self.debug = debug
        self.input_type = None

    def crack(self, input, key_length=None):
        bin_string = self.parse_input(input)

        if self.debug:
            print("Input type: %s" % self.input_type)
            print("Input length: %d" % len(bin_string))

        # User defined key length
        if key_length:
            dist = self.compute_avg_nrm_hamming_distance(bin_string, key_length)
            distances = [(key_length, dist)]
            if self.debug:
                print("Using key length of %d (%f)" % (key_length, dist))
            if dist > 0.42:
                print("\nWARNING: Distance is high, this may not be the right key length")
            if dist == 0:
                print("\nWARNING: Distance not found, the key length may be too long")
            
        # Find most likely key length
        else:
            distances = self.rank_key_lengths(bin_string)
            if len(distances) == 0:
                print("No key lengths found (try increasing max_key_length)")
                exit()
            if distances[0][1] > 0.42:
                print("\nWARNING: Distance is high, this may not be a repeating XOR cipher")
                print("Try increasing max_key_length or using a defined key length")
            if self.debug:
                print("\nBest ranking key lengths: (lower distance is better)")
                for d in distances[:10]:
                    print("%d bits (%.2fB) : %f" % (d[0], d[0]/8, d[1]))

        print("\nTrying to find keys that produce readable plaintext (only multiples of 8 bits)")
        byte_string = int(bin_string, 2).to_bytes((len(bin_string) + 7) // 8, byteorder='big')
        results = []
        for key_length, distance in distances[:5]:
            if key_length % 8 != 0:
                continue

            # Group the cipher text into blocks of key length
            chunks = self.make_chunks(byte_string, key_length // 8)

            # Find the most likely key, and keep only the ones with the best score
            keys = self.find_keys(chunks, len(chunks) // 10)
            keys = [key for key in keys if key[0] == keys[0][0]]
            
            results.append(keys)

            if self.debug:
                print("\nKey length: %d Bytes (%f)" % (key_length // 8, distance))
                for key in keys:
                    print("Key:(%d) %s" % (key[0], key[1]))
                    print("--> %s" % b"".join([xorbytes(key[1], c) for c in chunks[:10]]))

        return results

    def parse_input(self, input):
        """Check the type of the input and return it in binary string format"""

        # Python string
        if isinstance(input, str):

            # Already in binary
            if all(c in "01" for c in input):
                self.input_type = "binary string"
                return input
            
            # Hex
            if all(c in "0123456789abcdef" for c in input.lower()):
                self.input_type = "hex string"
                return bin(int(input, 16))[2:]
            
            # Base64
            if all(c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" for c in input):
                self.input_type = "base64 string"
                return bin(int(input, 64))[2:]
            
            # Regular string
            self.input_type = "Regular string (you should probably use byte string instead)"
            return bin(int(input.encode("hex"), 16))[2:]
        
        # Python bytes
        if isinstance(input, bytes):
            self.input_type = "byte string"
            return bin(int(input.hex(), 16))[2:]
    

    def rank_key_lengths(self, cipher_text):
        """Find the most likely key lengths by computing the normalised average hamming distance"""
        scores = {} # key = key length, value = normalised average hamming distance


        # For each key length, compute the average hamming distance between adjacent blocks of that length
        for key_length in range(2, self.max_key_length):
            dist = self.compute_avg_nrm_hamming_distance(cipher_text, key_length)
            if dist > 0:
                scores[key_length] = dist

        # key length with lowest normalised average hamming distance is most likely to be the correct one
        return sorted(scores.items(), key=operator.itemgetter(1), reverse=False)


    def compute_avg_nrm_hamming_distance(self, cipher_text, key_length):
        """Compute the normalised average hamming distance between all adjacent blocks of key_length"""
        hamming_distances = 0
        position = 0
        block_count = 0
        while(position + 2 * key_length <= len(cipher_text)):
            block1 = cipher_text[position:position+key_length]
            position += key_length
            block2 = cipher_text[position:position+key_length]
            hamming_distances += self.hamming_distance(block1, block2)
            block_count += 1
        if block_count > 0:
            # Average
            average_hamming_distance = hamming_distances / float(block_count)
            # Normalise
            return average_hamming_distance / float(key_length)
        return 0

    def hamming_distance(self, string_one:str, string_two:str) -> int:
        """Compute the Hamming distance between two strings of equal length
        https://en.wikipedia.org/wiki/Hamming_distance
        """
        assert len(string_one) == len(string_two)
        count = 0
        for i in range(len(string_one)):
            if string_one[i] != string_two[i]:
                count += 1
        return count

    def make_chunks(self, byte_string:bytes, chunksize:int) -> list[bytes]:
        """Split the byte string into chunks of size chunksize"""
        chunks = [byte_string[i:i+chunksize] for i in range(0, len(byte_string), chunksize)]
        return chunks
    
    def group_by_pos(self, chunks:list[bytes]) -> list[bytes]:
        """Group the bytes by position in the chunks"""
        groups = []
        max_len = max([len(c) for c in chunks])
        for i in range(max_len):
            group = []
            for c in chunks:
                if len(c) > i:
                    group.append(c[i])
            groups.append(bytes(group))
        return groups
    
    def score(self, byte_string:bytes) -> int:
        """Score the byte string based on the number of allowed characters"""
        return sum([1 for c in byte_string if c in self.allowed_chars]) + sum([1 for c in byte_string if c in self.best_chars]) * 2
    
    def try_xor(self, group:bytes, byte_key:int) -> bytes:
        """Try to xor the group of bytes with the byte key"""
        new_group = b''
        for c in group:
            new_group += (c ^ byte_key).to_bytes(1, 'big')
        return new_group


    def find_keys(self, chunks:list[bytes], epsilon=1) -> list:
        groups = self.group_by_pos(chunks)
        l = len(chunks[0])
        output = []

        # For each byte position
        for k in range(l):
            results = []
            for i in range(256):
                xord = self.try_xor(groups[k], i)
                results.append((self.score(xord), i))

            # Sort by score
            results.sort(key=lambda x: x[0], reverse=True)

            # Only keep the results with score higher than l - epsilon, or the first result
            selected = [r for r in results if r[0] >= results[0][0] - epsilon][:20]
            output.append(selected)
        
        # Compute the keys with the best scores
        keys = [(0, b'')]
        # For each byte position
        for o in output:
            new_keys = []
            # For each key
            for key in keys:
                for c in o:
                    new_keys.append((key[0] + c[0], key[1] + c[1].to_bytes(1, 'big')))
            keys = new_keys
            keys.sort(key=lambda x: x[0], reverse=True)
            keys = keys[:1000]
        return keys

In [58]:
xcipher = SimpleXORCipher(b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f')
ct = xcipher.encrypt(b"In cryptography, the one-time pad (OTP) is an encryption technique that cannot be cracked, but requires the use of a single-use pre-shared key that is not smaller than the message being sent. In this technique, a plaintext is paired with a random secret key (also referred to as a one-time pad). Then, each bit or character of the plaintext is encrypted by combining it with the corresponding bit or character from the pad using modular addition.")
ct

b'Io"`v|vsgnxj|ew# ujf$jhb%}cfi-~nd!*LPU/\'az*jb-kacs{splii(}ohdcg~ud"wldr\'khdecy.me!aqefmbl%*iyy.}epwjv`u\'|ao+y~k/og"b$voioeo&y~k/psg.wmgumm*`it.{h`v#mv&ig}*xalbces"wldh\'|ao+ah}|afg#f`oio)ynby /Io"wllu\'|licbd\x7fze-"b$ujfag~nty.fs!rbmwcc(~c\x7fd-o/r`lgkh&tmjxnx-ejy!*bhvi\'zlln~\x7fkk um#ev&f(fdn!ygbe!rb`,(\'\\aoe -knci"amq&hz)icm\x7foltdp#kc&s`l*{`lgatdzw$lu\'mgiyu}zjd!`z$fijj`dbbj.ft!ujpm&s`l*hc\x7f|jsqmm`lh`(kc\x7f,b|/cicqefrbz)lyc`.{hd"sea&r{`dl,`akumcq$dbca}cdb#'

In [59]:
xor_cipher_cracker = SimpleXORCracker(debug=True)
possible_decryptions = xor_cipher_cracker.crack(ct)

Input type: byte string
Input length: 3567

Best ranking key lengths: (lower distance is better)
128 bits (16.00B) : 0.326923
160 bits (20.00B) : 0.328274
104 bits (13.00B) : 0.331294
32 bits (4.00B) : 0.332670
248 bits (31.00B) : 0.333437
8 bits (1.00B) : 0.336149
72 bits (9.00B) : 0.336516
120 bits (15.00B) : 0.337500
96 bits (12.00B) : 0.338252
224 bits (28.00B) : 0.338329

Trying to find keys that produce readable plaintext (only multiples of 8 bits)

Key length: 16 Bytes (0.326923)
Key:(1292) b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f'
--> b'In cryptography, the one-time pad (OTP) is an encryption technique that cannot be cracked, but requires the use of a single-use pre-shared key that is not small'

Key length: 20 Bytes (0.328274)
Key:(1108) b'\x04\r\x02\x07\x01\r\x06\x03\x0e\x14\n\x0f\x08\x01\x0e\x10\x04\x14\r\x07'
--> b'Mb gwqppizretdy3$aga gje$peeg9tal $\\TA" ew(mc mbmgq|xmgy,ibo`neyti$tbpx(cijugm#ja,cvdkkab1 fqx madzmrmw }li(wja gf,r bbnkhm!xsm,~gm!\x7fliei

In [60]:
xor_cipher_cracker = SimpleXORCracker(debug=True)
possible_decryptions = xor_cipher_cracker.crack(ct, 8*16)

Input type: byte string
Input length: 3567
Using key length of 128 (0.326923)

Trying to find keys that produce readable plaintext (only multiples of 8 bits)

Key length: 16 Bytes (0.326923)
Key:(1292) b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f'
--> b'In cryptography, the one-time pad (OTP) is an encryption technique that cannot be cracked, but requires the use of a single-use pre-shared key that is not small'
