Linear Equation 3

Linear Equation 3

Category: Reverse

I'm just dumping my solve script here; I might add a write-up later.

Please don’t judge my code too harshly: the goal was to solve quickly, not to write clean code.

Solution

Utility

import ctypes
import mmap
import tqdm
import z3

filename = "lineq3"

# Load the whole challenge binary once; the address helpers below slice into it.
with open(filename, mode="rb") as binary_file:
    data = binary_file.read()

# Base address used by Ghidra for this binary.  The helpers assume
# file offset == virtual address - GHIDRA_OFFSET (presumably the first LOAD
# segment maps file offset 0 here) -- TODO confirm with `readelf -l` before
# reusing these helpers on another binary.
GHIDRA_OFFSET = 0x400000


def get_addr(address):
    """Translate a Ghidra virtual address into an offset into ``data``."""
    return address - GHIDRA_OFFSET


def get_bytes(address, length=1):
    """Return ``length`` raw bytes of the loaded binary starting at ``address``.

    Args:
        address: Ghidra virtual address of the first byte.
        length: number of bytes to read (default 1).  Fewer bytes are
            returned if the range runs past the end of ``data``.

    Raises:
        ValueError: if ``address`` lies below ``GHIDRA_OFFSET`` or ``length``
            is negative.  Without this check Python slicing would silently
            treat a negative offset as an index from the *end* of ``data``.
    """
    offset = get_addr(address)
    if offset < 0:
        raise ValueError(
            f"address {address:#x} is below GHIDRA_OFFSET {GHIDRA_OFFSET:#x}"
        )
    if length < 0:
        raise ValueError(f"negative length: {length}")
    return data[offset:offset + length]


def execute_at(addr, param, code_size=0x80):
    """Run the native function at virtual address ``addr`` and return its result.

    The first ``code_size`` bytes at ``addr`` are copied into a fresh anonymous
    read/write/execute page and called as a C ``int f(int)`` with ``param``.
    Because the code runs from a different address than in the binary, this
    assumes the function fits in ``code_size`` bytes and makes no absolute or
    RIP-relative references outside that range -- TODO confirm when targeting
    new functions.

    Args:
        addr: Ghidra virtual address of the function's first instruction.
        param: value passed as the single C ``int`` argument.
        code_size: number of bytes to copy (default 0x80, the size used by
            the rest of the script); must not exceed ``mmap.PAGESIZE``.

    Returns:
        the function's return value, interpreted as a signed C ``int``.

    Raises:
        ValueError: if ``code_size`` exceeds one page.
    """
    if code_size > mmap.PAGESIZE:
        raise ValueError(
            f"code_size {code_size:#x} exceeds one page ({mmap.PAGESIZE:#x})"
        )
    code = get_bytes(addr, code_size)

    buf = mmap.mmap(-1, mmap.PAGESIZE, prot=mmap.PROT_READ | mmap.PROT_WRITE | mmap.PROT_EXEC)
    try:
        buf.write(code)
        # from_buffer() is only used to learn the mapping's address.  Drop it
        # immediately: an mmap with live buffer exports cannot be closed.
        anchor = ctypes.c_void_p.from_buffer(buf)
        entry = ctypes.addressof(anchor)
        del anchor
        func = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int)(entry)
        return func(param)
    finally:
        # Unmap even if the call raised, so pages are not leaked.
        buf.close()

Find the prototype of each function

L.py contains the decompiled source lines that reference the function tables; they are used to infer each table's prototype.

from L import L

"""
L = [
"result = result & uVar4",
"bVar1 = ((tables[\"31fc1010\"])[input_e8[4]])",
"uVar2 = ((tables[\"320c1810\"])[bVar1])",
"uVar3 = tables[\"2b770150\"](uVar2)",
"bVar1 = ((tables[\"31fc1010\"])[input_118[15]])",
...
"""

# Parse L and classify every table address by the shape of the decompiled line
# that uses it.  Key: the hex address string inside tables["..."].  Value: a
# nesting code, essentially the number of "(" on the line, which the table
# generation loop further down maps to a generator
# (0 -> generate_table_double, 1 -> genetate_function_simple,
#  2 -> generate_table_simple, 3 -> generate_function_double).
nested_functions = {}
for line in L:
    if "tables" not in line:
        continue
    # Text between the first "[" and the following "]", quotes stripped.
    addr = line.split("[")[1].split("]")[0].strip("\"")
    parentheses = line.count("(")
    # A uVar2 lookup such as `uVar2 = ((tables[...])[bVar1])` has the same two
    # "(" as an input lookup `((tables[...])[input_e8[4]])`, but it needs the
    # 8 -> 16 bit handler (code 0).
    if "uVar2" in line and parentheses == 2:
        parentheses = 0
    # If an address appears on several lines, the last one wins.
    nested_functions[addr] = parentheses

Function and table generation

def generate_table_simple(addr): # bVar1 = ((tables["2ca16010"])[input_108[13]])
    """Evaluate a 256-entry table of functions into a list of byte values.

    ``addr`` points at 0x100 little-endian 8-byte function pointers.  Each
    function is executed once with the dummy argument 0x10 (presumably the
    functions ignore it and return a constant -- verify if results look odd)
    and its return value becomes the table entry.

    Returns:
        list of the 0x100 return values, indexed like the pointer array.
        A warning is printed for any value that does not fit in one byte.
    """
    table = []
    for slot in range(0x100):
        target = int.from_bytes(get_bytes(addr + slot * 0x8, 0x8), "little")
        value = execute_at(target, 0x10)
        table.append(value)
        # Entries must fit in a byte: 0x100 itself and negative c_int results
        # are out of range too (the old `> 0x100` test missed both).
        if not 0 <= value < 0x100:
            print(f"WARNING: Simple function at {hex(target)} returned {value}")
    return table

def generate_table_double(addr): # uVar2 = ((tables["2cb16810"])[bVar1])
    """Recover an 8 -> 16 bit table of functions as a Z3 multiplication.

    ``addr`` points at 0x100 little-endian 8-byte function pointers; each is
    executed with the dummy argument 0x10 and the results form ``table``.
    The table is recognised only if ``table[j] == (k * j) mod 0x10000`` for
    some constant ``k``.

    Returns:
        ``lambda x: x * k`` with ``k`` as a 16-bit Z3 constant (so ``x`` must
        be a 16-bit bit-vector), or None (after printing a message) when the
        table is not a multiplication.
    """
    table = []
    for slot in range(0x100):
        target = int.from_bytes(get_bytes(addr + slot * 0x8, 0x8), "little")
        value = execute_at(target, 0x10)
        table.append(value)
        # 16-bit values only; the old `> 0x10000` test let 0x10000 itself and
        # negative c_int results through.
        if not 0 <= value < 0x10000:
            print(f"WARNING: Double function at {hex(target)} returned {value}")
    if all(x < 0x100 for x in table):
        print("WARNING: Double function seems to be simple")

    # If table[j] == (k * j) mod 2**16 for every j, then k == table[1].  Derive
    # k directly and verify it once instead of brute-forcing k; this also
    # recognises multipliers >= 0x100, which the old 0..0xff search missed.
    multiplier = table[1]
    if 0 <= multiplier < 0x10000 and all(
        value == (multiplier * j) % 0x10000 for j, value in enumerate(table)
    ):
        factor = z3.BitVecVal(multiplier, 16)
        return lambda x: x * factor
    print("Unknown function at address: ", hex(addr))
    return None


def genetate_function_simple(addr): # uVar3 = tables["00d35150"][uVar2]
    """Recover the one-argument native function at ``addr`` as a Z3 builder.

    The function is sampled on inputs 0..0xff (results reduced mod 0x10000)
    and matched against the affine form ``y = a*x + b (mod 2**16)``.  This
    covers x+c, x-c, x*c, -x+c and -x-c for any 16-bit constant c.  Two
    affine maps that agree on x = 0 and x = 1 are identical, so ``a`` and
    ``b`` are derived from those two samples and then verified on all 256.

    Returns:
        a one-argument lambda building the 16-bit Z3 expression (``x`` must
        be a 16-bit bit-vector), or None (after printing a message) if the
        samples are not affine.
    """
    samples = [execute_at(addr, x) % 0x10000 for x in range(0x100)]
    offset = samples[0]
    slope = (samples[1] - samples[0]) % 0x10000
    if any(y != (slope * x + offset) % 0x10000 for x, y in enumerate(samples)):
        print("Unknown function at address: ", hex(addr))
        return None

    b = z3.BitVecVal(offset, 16)
    if slope == 1:  # x + c  (and x - c, which is x + (2**16 - c))
        return lambda x: x + b
    if slope == 0xFFFF:  # -x + c  (and -x - c)
        return lambda x: -x + b
    a = z3.BitVecVal(slope, 16)
    if offset == 0:  # x * c
        return lambda x: x * a
    return lambda x: x * a + b
    
def generate_function_double(addr): # uVar3 = ((tables["2ca16810"])[uVar3])[uVar2]
    """Identify a 256 x 256 grid of native calls as a two-argument operator.

    ``addr`` holds 0x100 little-endian 8-byte function pointers.  The function
    in slot ``row`` is executed on every argument ``col`` in 0..0xff and the
    result, reduced mod 0x10000, becomes ``grid[row][col]``.  The grid is then
    compared with row+col, row-col, row*col and -row+col (all mod 0x10000),
    in that order of preference.

    Returns:
        a lambda ``(x, y)`` for the first operator that matches every cell, or
        None (after printing a message) if none does.
    """
    raw_pointers = [get_bytes(addr + slot * 0x8, 0x8)[::-1] for slot in range(0x100)]

    grid = []
    for raw in raw_pointers:
        entry = int(raw.hex(), 16)
        grid.append([execute_at(entry, col) % 0x10000 for col in range(0x100)])

    def matches(reference):
        # True when every cell equals reference(row, col) reduced mod 0x10000.
        return all(
            grid[row][col] == reference(row, col) % 0x10000
            for row in range(0x100)
            for col in range(0x100)
        )

    if matches(lambda row, col: row + col):
        return lambda x,y: x+y
    if matches(lambda row, col: row - col):
        return lambda x,y: x-y
    if matches(lambda row, col: row * col):
        return lambda x,y: x*y
    if matches(lambda row, col: -row + col):
        return lambda x,y: -x+y

    print("Unknown function at address: ", hex(addr))
    return None



def compute_target(addr, fast=False):
    """Find which of the 0x10000 function-pointer slots at ``addr`` returns 1.

    Each slot holds a little-endian 8-byte function pointer; every function is
    called with argument 1, and the index of the one returning 1 is the result
    (the solver section uses it as the value a condition must equal).

    Args:
        addr: virtual address of the pointer array.
        fast: stop at the first hit instead of scanning every slot; this also
            skips detection of multiple hits.

    Returns:
        the matching slot index, or -1 if no slot returned 1.  A warning is
        printed in the -1 case so the sentinel does not go unnoticed.
    """
    result = -1
    for i in range(0x10000):
        # Read pointers lazily, so fast mode stops reading at the first hit.
        target = int.from_bytes(get_bytes(addr + i * 0x8, 0x8), "little")
        if execute_at(target, 0x1) != 1:
            continue
        if result != -1:
            print(f"Multiple results found: {result} and {i}")
        result = i
        if fast:
            break
    if result == -1:
        print(f"WARNING: no target found for table at {hex(addr)}")
    return result

# Turn every classified table address into either a concrete lookup table
# (list of ints) or a callable that builds Z3 expressions, keyed by the same
# address string.
generators = {
    0: generate_table_double,
    1: genetate_function_simple,
    2: generate_table_simple,
    3: generate_function_double,
}

tables_values = {}
for addr, nesting in tqdm.tqdm(nested_functions.items()):
    builder = generators.get(nesting)
    if builder is None:
        print(f"Nested functions: {nesting} not supported")
        continue
    tables_values[addr] = builder(int(addr, 16))

Z3 solving

conditions.py provides the function generate_conditions, which builds the constraints for the Z3 solver.

from conditions import generate_conditions

s = z3.Solver()

# Build the Z3 view of every recovered table.  Callable recoveries (nesting
# codes 0, 1 and 3) are used as-is.  Concrete lookup tables (code 2) become Z3
# arrays from 8-bit indices to 16-bit values, pinned entry by entry.
z3_tables = {}
for key in tqdm.tqdm(tables_values.keys()):
    try:
        nesting = nested_functions[key]
        if nesting in (0, 1, 3):
            # 0: 8 -> 16 table recovered as a function, 1: simple function,
            # 3: double function.
            z3_tables[key] = tables_values[key]
        elif nesting == 2:
            # Byte-valued table, declared as BitVec(8) -> BitVec(16).
            table = z3.Array(f"table_{key}", z3.BitVecSort(8), z3.BitVecSort(16))
            z3_tables[key] = table
            for index, value in enumerate(tables_values[key]):
                s.add(table[z3.BitVecVal(index, 8)] == z3.BitVecVal(value, 16))
        else:
            print(f"Nested functions: {nesting} not supported")
    except Exception:
        print(f"Error with key: {key}")
        raise


FLAG_LENGTH = 0x40
# One 8-bit unknown per flag character.  Named flag_chars so the built-in
# input() is not shadowed; the Z3 variable names are unchanged.
flag_chars = [z3.BitVec(f"input_{i}", 8) for i in range(FLAG_LENGTH)]

# Printable ASCII only.  Z3's >= and < on bit-vectors are *signed*; the
# unsigned forms state the intended byte range 0x20..0x7e directly.
for ch in flag_chars:
    s.add(z3.UGE(ch, 0x20))
    s.add(z3.ULT(ch, 0x7f))

# Known flag format: ECW{...}
for index, char in enumerate("ECW{"):
    s.add(flag_chars[index] == ord(char))
s.add(flag_chars[FLAG_LENGTH - 1] == ord("}"))

# generate_conditions takes the flag as four 16-byte quarters.
input1 = flag_chars[:0x10]
input2 = flag_chars[0x10:0x20]
input3 = flag_chars[0x20:0x30]
input4 = flag_chars[0x30:0x40]

conditions, target_addrs = generate_conditions(z3_tables, input1, input2, input3, input4)

target_values = [
    compute_target(int(target, 16), fast=True) for target in tqdm.tqdm(target_addrs)
]
if len(target_values) < len(conditions):
    raise RuntimeError(
        f"{len(conditions)} conditions but only {len(target_values)} target values"
    )

for i, (condition, value) in enumerate(tqdm.tqdm(list(zip(conditions, target_values)))):
    # compute_target() returns -1 when no slot matched.  `condition == -1`
    # would silently pin the condition to an all-ones bit-vector, so fail
    # loudly instead.
    if value == -1:
        raise RuntimeError(f"No target value found for condition {i} ({target_addrs[i]})")
    s.add(condition == value)

print("Checking ...")
result = s.check()
print(result)
if result != z3.sat:
    raise RuntimeError(f"Solver returned {result}; no flag to extract")

m = s.model()
output = "".join(
    chr(m.eval(ch, model_completion=True).as_long()) for ch in flag_chars
)
print(output)
sat
ECW{3V3n_1nD1r3c1_c411S_AnD_M1lL10N_FUnC110ns_CaNn0T_5t0P_4_Pr0}