Zadanie

Zadanie z dnia 17 polega na tym żeby uruchomić program na fikcyjnym procesorze opisanym dokładnie w treści zadania. Omówię jedynie najważniejsze aspekty. Nasz procesor operuje na liczbach 3 bitowych, każda operacja składa się z kodu operacji zaraz za nią operandu (para zajmuje razem 6 bitów). Procesor posiada też 3 rejestry: A, B i C, rejestry te mogą trzymać dowolne wartości nie ograniczone przez 3 bity. Program składa się z ciągu 3 bitowych wartości, operator operand, IP jest rejestrem który na początku ma wartość 0, wskazuje on na instrukcję która powinna zostać wykonana następna. Moment w którym próbujemy wykonać instrukcją znajdującą się poza pamięcią jest momentem kiedy nasz program się kończy.

W treści zadania zdefiniowane jest 8 instrukcji które głównie manipulują rejestrami. Nie będę ich wszystkich opisywał zachęcam do zapoznania się z treścią na oficjalnej stronie Advent of Code.

Na wejściu dostajemy takie informacje:

Register A: 2024
Register B: 0
Register C: 0

Program: 0,3,5,4,3,0

Jedna z instrukcji odpowiedzialna jest za wypisywanie wartości na ekranie, powinniśmy zasymulować działanie programu i wskazać co ten program wypisał. (wypisane wartości powinny być odseparowane po przecinku)

Rozwiązanie

Zadanie nie wydaje się trudne, trzeba po prostu zdefiniować sobie stan procesora i zaimplementować wszystkie instrukcje. Późniejsze wykonanie programu polega jedynie na odpowiednim zmienianiu wartości IP i zakończeniu pracy jak wyjdzie ona poza program.

Parsowanie

Nie ma tu nic specjalnego, znowu używam wyrażeń regularnych:

def parse(data: str) -> State:
    PATTERN = (
        r"Register A: (\d+)\nRegister B: (\d+)\nRegister C: (\d+)\n\nProgram: ([0-9,]+)"
    )
    m = re.match(PATTERN, data)
    if m is None:
        raise ValueError("Invalid input")
    reg_A, reg_B, reg_C, program = m.groups()
    return State(int(reg_A), int(reg_B), int(reg_C), list(map(int, program.split(","))))


if __name__ == "__main__":
    state = parse(sys.stdin.read())

Definicja stanu procesora

Wartoci rejestrów (A, B, C, IP) i program załadowany do pamięci jednoznacznie definiują aktualny stan proceosra, dodałem też pomocniczo atrybut output, jeśli będzie on pusty to nie będę wypisywał przecinka przed liczbą na wyjściu.

@dataclass
class State:
    reg_A: int
    reg_B: int
    reg_C: int
    program: list[int]  # 3-bit numbers
    ip: int = 0

    output = ""

Symulacja programu

Niektóre operacje wykorzystują operand typu literal, czyli po prostu wartość od 0 to 7, z kolei inne korzystają z typu combo, działa on tak że wartości od 0 do 3 są traktowanet tak samo jak literał, zaś wartości od 4 do 6 biorą jako argument do operacji wartość z danego rejestru, wartość 7 nie powinna się pojawić nigdy w combo

Zdefiniujmy combo jako atrybut stanu:

def combo(self, operand: int) -> int:
    if operand >= 0 and operand <= 3:
        return operand
    elif operand == 4:
        return self.reg_A
    elif operand == 5:
        return self.reg_B
    elif operand == 6:
        return self.reg_C
    else:
        raise ValueError("Invalid operand")

Teraz zdefiniuję metodę która dokonuje jednego kroku programu zgodnie ze specyfikacją, funkcja ta zwróci też wartość logiczną wskazującą na to czy program się zakończył czy nie:

    def step(self) -> bool:
        if self.ip >= len(self.program):
            return True

        opcode, operand = self.program[self.ip : self.ip + 2]
        match opcode:
            case 0: # adv - operacja dzielenia całkowitego przez 2^x jest 
                    # równoznaczna z przesunięciem bitowym w prawo o x
                self.reg_A >>= self.combo(operand)
            case 1:  # bxl
                self.reg_B ^= operand
            case 2:  # bst
                self.reg_B = self.combo(operand) % 8
            case 3:  # jnz
                if self.reg_A != 0:
                    self.ip = operand
                    return False
            case 4:  # bxc
                self.reg_B ^= self.reg_C
            case 5:  # out
                if self.output:
                    self.output += ","
                self.output += str(self.combo(operand) % 8)
            case 6:  # bdv - podobnie jak adv
                self.reg_B = self.reg_A >> self.combo(operand)
            case 7:  # cdv - podobnie jak adv
                self.reg_C = self.reg_A >> self.combo(operand)
            case _:
                print("Unknown opcode")
        self.ip += 2
        return False

Cały kod prezentuje się tak:

import re
import sys
from dataclasses import dataclass


@dataclass
class State:
    reg_A: int
    reg_B: int
    reg_C: int
    program: list[int]  # 3-bit numbers
    ip: int = 0

    output = ""

    def step(self) -> bool:
        if self.ip >= len(self.program):
            return True

        opcode, operand = self.program[self.ip : self.ip + 2]
        match opcode:
            case 0:  # adv
                self.reg_A >>= self.combo(operand)
            case 1:  # bxl
                self.reg_B ^= operand
            case 2:  # bst
                self.reg_B = self.combo(operand) % 8
            case 3:  # jnz
                if self.reg_A != 0:
                    self.ip = operand
                    return False
            case 4:  # bxc
                self.reg_B ^= self.reg_C
            case 5:  # out
                if self.output:
                    self.output += ","
                self.output += str(self.combo(operand) % 8)
            case 6:  # bdv
                self.reg_B = self.reg_A >> self.combo(operand)
            case 7:  # cdv
                self.reg_C = self.reg_A >> self.combo(operand)
            case _:
                print("Unknown opcode")
        self.ip += 2
        return False

    def combo(self, operand: int) -> int:
        if operand >= 0 and operand <= 3:
            return operand
        elif operand == 4:
            return self.reg_A
        elif operand == 5:
            return self.reg_B
        elif operand == 6:
            return self.reg_C
        else:
            raise ValueError("Invalid operand")


def parse(data: str) -> State:
    PATTERN = (
        r"Register A: (\d+)\nRegister B: (\d+)\nRegister C: (\d+)\n\nProgram: ([0-9,]+)"
    )
    m = re.match(PATTERN, data)
    if m is None:
        raise ValueError("Invalid input")
    reg_A, reg_B, reg_C, program = m.groups()
    return State(int(reg_A), int(reg_B), int(reg_C), list(map(int, program.split(","))))


if __name__ == "__main__":
    state = parse(sys.stdin.read())

    while not state.step():
        pass

    print(state.output)

Po uruchomieniu tego kodu otrzymujemy porpawną wartość.

Część druga

Tutaj spotykamy się z ciekawym twistem, jak można zauważyć, output programu jest też poprawnym programem. Mamy znaleźć taką początkową wartość rejestru A, aby program wypisał samego siebie.

Z początku nie wiedziałem zupełnie jak do tego zadania podejść dlatego zajęło mi ono więcej niż poprzednie, brute force nie wchodzi raczej w grę gdyż wartość rejestru może być dowolnie wysoka.

Uznałem że najlepiej będzie spróbować dokonać czegoś na kształt inżynierii wstecznej, najpierw spróbowałem przekonwertować program z kodu maszynowego na nieco bardziej czytelną dla człowieka formę.

Zauważyłem że zarówno mój input jak i testowy input kończą się instrukcją jnz 0, czyli sprawdzane jest czy A jest różne od zera i jeśli nie to cofamy się na początek pętli. Możemy potraktować to jako pętlę do while.

Dalsze fakty dotyczą mojego programu i różnią się od testowego więc nie mam pewności czy u innych wyglądało to tak samo, przedostatnią instrukcją u mnie jest adv 3, czyli po każdej iteracji nasze A będzie dzielona przez 8, myśląc na zasadzie operacji bitowych, zostanie ono przesunięte w prawo o 3 bity (ucinamy końcówkę liczby w zapisie binarnym)

Wcześniej mamy jeszcze out 5, które wypisuje na ekranie wartość rejestru B (ucięta do ostatnich trzech bitów).

Cały program jest równoznaczny do takiego pseudokodu:

do
    b = a & 7
    b ^= 5
    c = a >> b
    b ^= 6
    b ^= c
    print b & 7
    a >>= 3
while a != 0

Co ciekawe każda iteracja tej pętli zależy tylko od wartości rejestru A, wartości B i C są nadpisywane w każdej iteracji (linijka 2 i 4).

Jeszcze jeden ciekawy fakt: A zmienia się tylko w jednym miejscu i za każdym razem w niezmienny sposób. Możemy myśleć że cały program iteruje po trójkach bitów w A. Przykład: w na początku ostatniej iteracji wiemy że maksymalnie 3 bity mogły być ustawione na 1 w rejestrze A, gdyż jeśli byłoby ich więcej to program nie zakończyłby się w tej iteracji. (A nie byłoby równe 0 po przesunięciu bitowym). Z kolei w drugiej iteracji będzie maksymalnie 6 bitów i tak dalej.

Jednak wypisywana wartość nie jest zależna jedynie od 3 bitów A, ważne są też wcześniejsze bity przez instrukcję c = a >> b, właściwie to b w tym miejscu ma wartość od 0 do 7 przez instrukcję b = a & 7, więc tak na prawdę musielibyśmy tylko patrzeć na ostatnie 11 bitów, żeby określić jednoznacznie jaką wartość wypisze jedna iteracja programu. (przesunięcie o 8 a później ucinane jest do 3 bitów, 8 + 3 = 11).

Do brzegu.

Wiedząc te wszystkie rzeczy możemy zacząć analizować od końca, wartości które program powinien, zwrócić. Będą one zależne od najbardziej znaczących bitów w A. Napiszę funkcję rekurencyjną w tym celu, która dla każdej rozważaniej trójki bitów, będzie analizować wszystkie możliwości w rosnącej kolejności.

program = [..., 3, 3, 0]

def iteration(a):
    b = a & 7
    b = b ^ 5
    c = a >> b
    b = b ^ 6
    b = b ^ c
    return b & 7

def solve(current, i):
    # current - aktualna wartość A, i - pozycja w outpucie którego chcemy
    if i < 0:  # znaleźliśmy wszystkie poszukiwane wartości 
        return tuple()

    # przesuwamy aktualną wartość A w lewo, żeby zrobić miejsce dla nowych bitów
    current <<= 3
    for j in range(8): # sprawdzamy wszystkie opcje o 000 do 111
        candidate = current + j
        res = iteration(candidate)
        if res == program[i]:
            rest = solve(candidate, i - 1)  # szukamy kolejnych wartości
            if rest is not None:
                return (j,) + rest
    else:
        return None  # jest możliwość że danego outputu nie da się osiągnąć
                     # tym programem

Funkcja solve zwraca nam tupla z poszczególnymi wartościami trójek bitów w rejestrze A które sprawią że program zwróci samego siebie. Możemy je przekonwertować do jednej liczby w taki sposób:

triples = solve(0, len(program) - 1)
if triples is not None:
    res = 0
    for triple in triples:
        res <<= 3
        res += triple
    print(res)
else:
    print("Not found")

Kod natychmiastowo zwraca poprawny wynik

Podsumowanie

Zadanie było na prawdę trudne, najwięcej czasu spędziłem patrząc się w kartkę w zeszucie próbując zrozumieć naturę tego fikcyjnego programu i w szukaniu jakiegoś wzoru albo zależności, ostatecznie rozwiązanie nie było oczywiste i bardzo ciekawe.