Sometimes, you can find patterns in seemingly random things…

Category: crypto

Solver: 3mb0, nh1729

Flag: HTB{n01sy_LF5R-1s_n0t_l0ud_3n0ugh}

Writeup

For this challenge, we were given a Python script that processes the flag, along with its output.

import random
from hashlib import sha512

class LFSR:
    def __init__(self, state, taps):
        self.state = list(map(int, list("{:0128b}".format(state))))
        self.taps = taps
                
    def clock(self):
        outbit = self.state[0]
        newbit = sum([self.state[t] for t in self.taps]) & 1
        self.state = self.state[1:] + [newbit]
        return outbit

key = random.getrandbits(128)
G = LFSR(key, [0, 1, 2, 7, 3])
[G.clock() for _ in range(256)]
stream = [G.clock() for _ in range(5000)]
noise = [int(random.random() > 0.95) for _ in range(5000)]
stream = [x ^ y for x, y in zip(stream, noise)]
print(stream)

flag = open("flag.txt", "rb").read()
enc = bytes([x ^ y for x, y in zip(sha512(str(key).encode()).digest(), flag)])
print(enc.hex())

The script generates a random 128-bit key, uses it as the initial state of a linear feedback shift register (LFSR), and encrypts the flag by XORing it with the SHA-512 digest of the key's decimal string. We are given the encrypted flag and 5000 bits of LFSR output, produced after the first 256 output bits are discarded. However, each of those 5000 bits is flipped independently with a probability of 5% before it is printed.

LFSRs are easy to run both forwards and backwards, so we extended the class with a method that undoes a clock operation:

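    def unclock(self):
        # state[-1] is the feedback bit (covers tap 0 via t-1 == -1);
        # state[t-1] are the other tapped bits, shifted by one position
        oldbit = sum([self.state[t-1] for t in self.taps]) & 1
        output = self.state[-1]
        self.state = [oldbit] + self.state[:-1]
        return output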
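As a sanity check, the following minimal sketch confirms that unclocking really undoes clocking. It uses a trimmed copy of the class; the helper `roundtrip` and its seeded random start state are our own illustration and not part of the challenge. Note that this unclock trick only works because tap 0 is present, and that this version returns the recovered bit rather than the old feedback bit.

```python
import random

TAPS = [0, 1, 2, 7, 3]  # same taps as in the challenge script


class LFSR:
    def __init__(self, state, taps):
        self.state = list(state)
        self.taps = taps

    def clock(self):
        outbit = self.state[0]
        newbit = sum(self.state[t] for t in self.taps) & 1
        self.state = self.state[1:] + [newbit]
        return outbit

    def unclock(self):
        # state[-1] is the last feedback bit; for t > 0, state[t - 1] is the
        # bit that sat at tap position t before the shift. XORing them all
        # yields the bit that was shifted out (requires tap 0 to be present).
        oldbit = sum(self.state[t - 1] for t in self.taps) & 1
        self.state = [oldbit] + self.state[:-1]
        return oldbit


def roundtrip(n_clocks, seed=0):
    """Hypothetical helper: clock n times, unclock n times, compare states."""
    rng = random.Random(seed)
    start = [rng.getrandbits(1) for _ in range(128)]
    g = LFSR(start, TAPS)
    for _ in range(n_clocks):
        g.clock()
    for _ in range(n_clocks):
        g.unclock()
    return g.state == start


print(roundtrip(256))
```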

So to decrypt the flag, we have to initialize an LFSR with a state reconstructed from the 5000 noisy bits and wind it back 256 clocks until it contains the original key. Because this LFSR outputs the front bit of its state on every clock, an LFSR initialized with a state S emits S as its first len(S) output bits. Since only 6.4 bits are expected to be flipped in any slice of 128 bits (128 × 0.05), we use the first 128 bits of the output as the state of our LFSR and improve from there.
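Both claims in this paragraph are easy to check with a small sketch. The function `lfsr_output` is a hypothetical stand-alone re-implementation of the challenge's clocking, and the seeded random state is only for illustration.

```python
import random

TAPS = [0, 1, 2, 7, 3]  # same taps as in the challenge script


def lfsr_output(state, taps, n):
    """Return the first n output bits of the LFSR started in `state`."""
    state = list(state)
    out = []
    for _ in range(n):
        out.append(state[0])
        state = state[1:] + [sum(state[t] for t in taps) & 1]
    return out


rng = random.Random(1)
state = [rng.getrandbits(1) for _ in range(128)]
out = lfsr_output(state, TAPS, 300)

# The first 128 output bits are exactly the initial state.
first_block_is_state = out[:128] == state

# Noise model from the challenge: every bit is flipped with probability 0.05.
p_flip = 0.05
expected_flips = 128 * p_flip            # flips expected in a 128-bit window
p_clean_window = (1 - p_flip) ** 128     # chance that a window has no flip

print(first_block_is_state, expected_flips, p_clean_window)
```

The last value shows why we cannot simply hope for a clean window: a given 128-bit slice is error-free with a probability of only about 0.14%.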

A simple way to find out which bits might have been flipped is to check which output bits following our chosen state are not reproduced by our LFSR. Every output bit depends on only 5 bits, the so-called taps (TABS in our code), of the state the LFSR was in 128 clocks before that bit was generated. We can therefore identify individual bits in our chosen state that are involved in generating particularly many bits that disagree with the observed output.

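def test_approximate(s, threshold):
    # Score per state bit: +1 for every mismatching output bit it feeds into, -1 for every matching one
    error_resp = {i: 0 for i in range(128)}
    l = LFSR(s[:128], TABS)
    # Skip the first 128 outputs, which merely repeat the state itself
    for _ in range(128):
        l.clock()
    # Compare the following outputs with the observed stream
    for i in range(128 - max(TABS)):
        wrong_bit = (s[128+i] != l.clock())
        for t in TABS:
            error_resp[t + i] += 1 if wrong_bit else -1
    # Flip every state bit whose score reaches the threshold
    filtered_resp = [i for i in error_resp if error_resp[i] >= threshold]
    s = flip_bits(s[:128], filtered_resp)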

Because of the interaction of multiple incorrect bits and the noise in the output we compare to, this is not sufficient to reconstruct the state. However, we can use it to refine the state once we eliminate some flipped bits.
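To illustrate why the vote does work for an isolated error, here is a minimal sketch on a synthetic, otherwise noise-free stream with a single flipped state bit. The functions `lfsr_output` and `vote`, the seed, and the flipped position 50 are our own assumptions for the demonstration; `vote` computes the same scores as `test_approximate`, but evaluates the parity relation directly instead of clocking an LFSR object.

```python
import random

TAPS = [0, 1, 2, 7, 3]  # same taps as in the challenge script
N = 128                 # state size


def lfsr_output(state, taps, n):
    """Return the first n output bits of the LFSR started in `state`."""
    state = list(state)
    out = []
    for _ in range(n):
        out.append(state[0])
        state = state[1:] + [sum(state[t] for t in taps) & 1]
    return out


def vote(s, taps=TAPS, n=N):
    """Score each of the first n bits by the parity checks it takes part in."""
    scores = [0] * n
    for i in range(n - max(taps)):
        # Output bit n+i must equal the XOR of the tapped bits i+t.
        parity = sum(s[i + t] for t in taps) & 1
        delta = 1 if parity != s[n + i] else -1
        for t in taps:
            scores[i + t] += delta
    return scores


rng = random.Random(2)
clean = lfsr_output([rng.getrandbits(1) for _ in range(N)], TAPS, 2 * N)

noisy = clean[:]
noisy[50] ^= 1  # a single, isolated error inside the state window

scores = vote(noisy)
suspect = max(range(N), key=scores.__getitem__)
print(suspect, scores[suspect])
```

With one flipped bit, all five checks that involve it fail, so it collects the maximum score while every other bit takes part in at most three failing checks. With several flipped bits close together, or with noise in the bits we compare against, the failing checks overlap and the scores become ambiguous.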

From here on, there is probably an elegant way to solve the challenge. However, we did not feel like digging through theoretical papers and decided to systematically test possibly flipped bits in the input stream until our refinement finds the flag.

def solve():
    for f in itertools.chain.from_iterable((itertools.combinations(range(128), i) for i in range(5))):
        print(f'Testing with flipped bits {f}')
        for t in range(5):
            test_approximate(flip_bits(stream, f), t)
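For a rough idea of how large this search is in the worst case, the number of candidates can be counted with a short sketch. The constant names are ours; the values are read off the loops above (`range(5)` yields subsets of 0 to 4 bits and thresholds 0 to 4).

```python
from math import comb

STATE_BITS = 128
MAX_FLIPS = 4     # itertools.combinations is called with sizes 0..4
THRESHOLDS = 5    # thresholds 0..4 are tried per candidate

# Number of ways to pick at most MAX_FLIPS positions out of STATE_BITS.
candidates = sum(comb(STATE_BITS, k) for k in range(MAX_FLIPS + 1))
calls = candidates * THRESHOLDS

print(candidates, calls)
```

This is an upper bound only: the solver exits as soon as one candidate decrypts to something containing the flag prefix.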

Solver

#!/usr/bin/env python3

import itertools
from hashlib import sha512

stream = []
ENC = b''

TABS = [0, 1, 2, 7, 3]

class LFSR:
    def __init__(self, state, taps):
        self.state = state
        self.taps = taps
                
    def clock(self):
        outbit = self.state[0]
        newbit = sum([self.state[t] for t in self.taps]) & 1
        self.state = self.state[1:] + [newbit]
        return outbit

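    def unclock(self):
        # state[-1] is the feedback bit (covers tap 0 via t-1 == -1);
        # state[t-1] are the other tapped bits, shifted by one position
        oldbit = sum([self.state[t-1] for t in self.taps]) & 1
        output = self.state[-1]
        self.state = [oldbit] + self.state[:-1]
        return output

    def __str__(self):
        return ''.join(map(str, self.state))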

def flip_bits(s, bits):
    s = s[:]
    for b in bits:
        s[b] = 1 - s[b]
    return s
    
def test_correctness(bits):
    G = LFSR(bits, TABS)
    for i in range(256):
        G.unclock()
    recovered_key = int(str(G), 2)
    flag = bytes([x ^ y for x, y in zip(sha512(str(recovered_key).encode()).digest(), ENC)])
    if b'HTB{' in flag:
        print(flag.decode())
        return True
    return False
    
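def test_approximate(s, threshold):
    # Score per state bit: +1 for every mismatching output bit it feeds into, -1 for every matching one
    error_resp = {i: 0 for i in range(128)}
    l = LFSR(s[:128], TABS)
    # Skip the first 128 outputs, which merely repeat the state itself
    for _ in range(128):
        l.clock()
    for i in range(128 - max(TABS)):
        wrong_bit = (s[128+i] != l.clock())
        for t in TABS:
            error_resp[t + i] += 1 if wrong_bit else -1
    # Flip every state bit whose score reaches the threshold
    filtered_resp = [i for i in error_resp if error_resp[i] >= threshold]
    s = flip_bits(s[:128], filtered_resp)
    if test_correctness(s):
        print(f'Flipped bits {filtered_resp} with high error rates')
        raise SystemExit(0)
    return s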

def solve():
    for f in itertools.chain.from_iterable((itertools.combinations(range(128), i) for i in range(5))):
        print(f'Testing with flipped bits {f}')
        for t in range(5):
            test_approximate(flip_bits(stream, f), t)
            
if __name__ == '__main__':
    with open('output.txt', 'r') as output:
        stream = [int(c) for c in output.readline() if c in ['0', '1']]
        ENC = bytes.fromhex(output.readline())
    solve()

Shortened output (took about 90 seconds):

Testing with flipped bits (0, 108, 109)
Testing with flipped bits (0, 108, 110)
Testing with flipped bits (0, 108, 111)
HTB{n01sy_LF5R-1s_n0t_l0ud_3n0ugh}
Flipped bits [34, 54, 72, 89, 106, 114] with high error rates