#!/usr/bin/env python3
"""Day-17 puzzle: an interpreter for a small 3-bit machine, plus a search.

``VM`` interprets the eight-opcode instruction set.  The script part (run
only when the file is executed directly) reads the puzzle input from the
file ``day17``, runs the program once, and then probes values of register A
looking for one that makes the program output itself.
"""
import re
from array import array
from copy import copy
from random import randrange
from typing import Optional


class VM:
    """Interpreter for a machine with three registers and 3-bit instructions.

    Attributes:
        reg: registers A, B, C (indices 0, 1, 2) as a signed 64-bit array.
        mem: the program as a flat list of opcode/operand values.
        pc: index of the next opcode in ``mem``.
        outbuf: values emitted by the ``out`` instruction, in order.
    """

    def __init__(self, reg: Optional[array] = None, mem: Optional[list] = None, pc: int = 0):
        # Copy the inputs so running the VM never mutates the caller's register
        # array or program; ``None`` defaults avoid one mutable default object
        # being shared by every instance.
        self.reg = copy(reg) if reg is not None else array('q', [0, 0, 0])
        self.mem = copy(mem) if mem is not None else []
        self.pc = pc
        self.outbuf = []

    def _combo(self, c) -> int:
        """Return the value of combo operand ``c``.

        Operands 0-3 are literal values; 4, 5, 6 read registers A, B, C.

        Raises:
            RuntimeError: for any other operand value (e.g. 7).
        """
        if 0 <= c <= 3:
            return c
        if 4 <= c <= 6:
            return self.reg[c - 4]
        raise RuntimeError(f"invalid combo operand: {c}")

    def run(self, abort: int = 0):
        """Execute the program from ``self.pc`` until ``pc`` runs past ``mem``.

        Args:
            abort: if positive, stop early as soon as more than ``abort``
                values have been output.  When searching for a program that
                outputs itself, an output already longer than the program
                can never match, so there is no point finishing the run.

        Raises:
            RuntimeError: on an unknown opcode or an invalid combo operand.
        """
        mem = self.mem
        reg = self.reg
        while self.pc < len(mem):
            if abort > 0 and len(self.outbuf) > abort:
                return
            op = mem[self.pc]
            self.pc += 1  # self.pc now points at the operand
            # ``a >> n`` equals ``a // 2**n`` for ints and n >= 0, without
            # materialising a potentially huge power of two.
            if op == 0b000:  # adv: A = A // 2**combo
                reg[0] = reg[0] >> self._combo(mem[self.pc])
            elif op == 0b001:  # bxl: B = B xor literal
                reg[1] = reg[1] ^ mem[self.pc]
            elif op == 0b010:  # bst: B = combo mod 8
                reg[1] = self._combo(mem[self.pc]) & 0b111
            elif op == 0b011:  # jnz: jump to literal when A != 0
                if reg[0] != 0:
                    self.pc = mem[self.pc]
                    continue
            elif op == 0b100:  # bxc: B = B xor C (operand is skipped)
                reg[1] = reg[1] ^ reg[2]
            elif op == 0b101:  # out: emit combo mod 8
                self.outbuf.append(self._combo(mem[self.pc]) & 0b111)
            elif op == 0b110:  # bdv: B = A // 2**combo
                reg[1] = reg[0] >> self._combo(mem[self.pc])
            elif op == 0b111:  # cdv: C = A // 2**combo
                reg[2] = reg[0] >> self._combo(mem[self.pc])
            else:
                raise RuntimeError(f"invalid opcode {op} at address {self.pc - 1}")
            self.pc += 1  # step over the operand

    def flush(self, binfmt=False):
        """Print the collected output on one line.

        Args:
            binfmt: False prints the values comma-separated in decimal;
                True prints each value in binary followed by a comma.
        """
        if binfmt:
            print(''.join(f"{o:b}," for o in self.outbuf))
        else:
            print(','.join(map(str, self.outbuf)))

    def cmp(self, prog) -> bool:
        """Return True if the output so far is exactly ``prog``."""
        return self.outbuf == prog


def parse_input(path='day17'):
    """Read the registers and the program from the puzzle input at ``path``.

    The input is expected to be three ``Register X: <n>`` lines, one line
    that is skipped, and a ``Program: a,b,c,...`` line.

    Returns:
        A tuple ``(reg_a, reg_b, reg_c, mem)`` where ``mem`` is a list of ints.
    """
    with open(path) as f:
        reg_a = int(re.match(r'Register A: ([0-9]+)', f.readline()).group(1))
        reg_b = int(re.match(r'Register B: ([0-9]+)', f.readline()).group(1))
        reg_c = int(re.match(r'Register C: ([0-9]+)', f.readline()).group(1))
        f.readline()  # skip the line between the registers and the program
        mem_str = re.match(r'Program: (.*)', f.readline()).group(1)
    # split() instead of per-character parsing so multi-digit values also work.
    mem = [int(tok) for tok in mem_str.split(',') if tok.strip()]
    return reg_a, reg_b, reg_c, mem


if __name__ == "__main__":
    pc: int = 0
    reg_a, reg_b, reg_c, mem = parse_input('day17')
    reg = array('q', [reg_a, reg_b, reg_c])

    # Part 1: run the program on the registers from the input.
    vm = VM(reg, mem, pc)
    vm.run()
    vm.flush()

    # Part 2: find a value of register A for which the program outputs itself.
    #
    # The output appears to enumerate all 3-bit values, then grow by one value
    # and enumerate them again.  Exactly how is unclear, but the lower output
    # values change before the higher ones.  So once the high values are
    # right, we can keep tightening the bounds on A, fixing one high value
    # after another.
    #
    # First we need somewhere to start.  Register A from the puzzle input is
    # too small, so start from it (presumably 17323786 is that value) and
    # grow it rapidly until the output is as long as the program.
    reg_a = 17323786
    while True:
        reg = array('q', [reg_a, reg_b, reg_c])
        vm = VM(reg, mem, pc)
        vm.run(abort=len(mem))
        if vm.cmp(mem) or len(vm.outbuf) == len(mem):
            break
        reg_a *= 10
    print(reg_a)
    vm.flush()

    # That was fast: we get 173237860000000, producing the output
    #   4,4,4,4,1,6,0,0,0,2,1,0,0,1,3,0
    # which is already the right length, and even its two highest (last)
    # values are already correct.  Next, find a first pair of bounds for a
    # binary search.
    known_bytes = vm.outbuf[-2:]  # the last two output values (3 bits each)
    init_reg_a = reg_a
    upper_bound = -1
    lower_bound = -1

    # Upper bound: grow A until the known values change or the output gets
    # longer than the program.
    reg_a = init_reg_a
    while True:
        reg = array('q', [reg_a, reg_b, reg_c])
        vm = VM(reg, mem, pc)
        vm.run(abort=len(mem))
        if vm.cmp(mem):
            break
        if vm.outbuf[-2:] != known_bytes or len(vm.outbuf) > len(mem):
            upper_bound = reg_a
            break
        reg_a *= 10
    print(reg_a)
    vm.flush()

    # Lower bound: shrink A until the known values change or the output gets
    # shorter than the program.
    reg_a = init_reg_a
    while True:
        reg = array('q', [reg_a, reg_b, reg_c])
        vm = VM(reg, mem, pc)
        vm.run(abort=len(mem))
        if vm.cmp(mem):
            break
        if vm.outbuf[-2:] != known_bytes or len(vm.outbuf) < len(mem):
            lower_bound = reg_a
            break
        reg_a //= 10
    print(reg_a)
    vm.flush()

    # That was fast again.  The bounds are not very tight, but they are bounds:
    #
    # ub:  1732378600000000
    # out: 4,4,0,7,0,7,6,0,3,5,1,5,4,2,4,4,2
    # lb:  17323786000000
    # out: 4,4,5,5,5,0,1,4,6,4,4,4,7,3,7