#!/usr/bin/env bash
#
# Boot a throwaway VM whose initrd unlocks a LUKS volume, runs luksSuspend and
# suspends to RAM, then dump the guest memory and report whether the volume key
# can still be found in it.

set -o errexit -o nounset -o pipefail

# Pin nixpkgs for reproducibility: this is the commit of the most recent NixOS
# 26.05 channel release at the time of writing.
readonly nixpkgs_rev=a0374025a863d007d98e3297f6aa46cc3141c2f0
export NIX_PATH="nixpkgs=https://github.com/NixOS/nixpkgs/archive/${nixpkgs_rev}.tar.gz"

# Build the initrd (kernel modules, static busybox + cryptsetup, and an /init
# that opens /dev/sda with cryptsetup, runs luksSuspend, then suspends to RAM)
# unless a cached initrd.gz already exists. Its checksum is verified further down.
if ! [ -e initrd.gz ]; then
  rootdir=$(nix-build --no-out-link -E '
    with import <nixpkgs> {
      overlays = [ (final: prev: {
        cryptsetup = prev.cryptsetup.overrideAttrs (old: {
          version = "git-main";
          # Pin the exact commit via rev; ref only names the branch to fetch it from.
          src = builtins.fetchGit {
            url = "https://gitlab.com/cryptsetup/cryptsetup.git";
            ref = "main";
            rev = "ea34c8ae85408b5d6d77f568d80f9268345ae1d4";
          };
          patches = [ ];
          nativeBuildInputs = (old.nativeBuildInputs or []) ++ [ final.autoreconfHook final.gettext ];
          outputs = [ "bin" "out" "dev" ];
        });
      }) ];
    };
    let
      kernel = linuxPackages_testing.kernel;

      modules = makeModulesClosure {
        kernel = kernel.modules;
        firmware = kernel;
        allowMissing = true;
        rootModules = [
          "ahci" "sd_mod"
          "dm_mod" "dm_crypt"
          "aes" "aesni_intel" "xts" "cbc" "sha256" "sha512"
        ];
      };
    in
    runCommand "initrd" { nativeBuildInputs = [ xz kmod ]; } '"''"'
      mkdir -p $out/bin $out/sbin

      cp -aL ${pkgs.pkgsStatic.cryptsetup}/bin/cryptsetup $out/bin/
      cp -aL ${pkgs.pkgsStatic.busybox}/bin/*             $out/bin/

      cp -a ${modules}/lib $out/lib
      chmod -R u+w $out/lib
      find $out/lib/modules -name "*.ko.xz" -exec xz -d {} +
      depmod -b $out "$(ls $out/lib/modules)"

      ln -s ../bin/busybox $out/sbin/modprobe

      cat > $out/init <<EOF
      #!/bin/sh
      set -e

      export PATH=/bin:/sbin

      mkdir -p /proc /sys /dev /run
      mount -t proc     none /proc
      mount -t sysfs    none /sys
      mount -t devtmpfs none /dev
      exec >/dev/ttyS0 2>&1

      for m in ahci sd_mod dm_mod dm_crypt aes aesni_intel xts cbc sha256 sha512; do
        modprobe \$m 2>/dev/null || true
      done

      until [ -b /dev/sda ]; do
        sleep 1
      done

      set -x

      uname -a
      cryptsetup --version

      echo -n abc | cryptsetup luksOpen /dev/sda foo
      cryptsetup luksSuspend foo
      cat /proc/keys
      echo mem > /sys/power/state
      EOF

      chmod +x $out/init
    '"''"'
  ')

  # Pack the tree as a reproducible newc cpio archive: sorted entries and
  # cpio --reproducible. Write to a temporary name and rename afterwards, so an
  # interrupted run cannot leave a truncated initrd.gz behind that the
  # existence check above would then keep trusting.
  (cd "$rootdir"; find . -mindepth 1 -print0) \
    | LC_ALL=C sort -z \
    | (cd "$rootdir"; cpio --null --create --format=newc --reproducible) \
    | gzip -9 \
    > ./initrd.gz.tmp
  mv -f ./initrd.gz.tmp ./initrd.gz
fi

# Fetch the kernel image from the pinned nixpkgs (NIX_PATH above) unless a
# cached bzImage already exists. Its checksum is verified further down.
if ! [ -e bzImage ]; then
  # Assign separately: a failing command substitution embedded in cp's argument
  # would be masked (only cp's status counts) and surface as a confusing
  # "cp: /bzImage" error. As a plain assignment, `set -e` aborts right here.
  kernel_dir=$(nix-build '<nixpkgs>' --no-out-link -A linuxPackages_testing.kernel)
  cp "$kernel_dir/bzImage" ./bzImage
fi

# Refuse to continue unless the kernel and initrd are bit-identical to the
# known-good builds.
if ! sha256sum --check <<'EOF'
622e8e940982315dad7f3497e8b8907296c88473b3a4bb5d766cd5269b098cf0  bzImage
374342f8bde60de27fb8a96a06d3c284519b9a53bd344299361c3187fd54dcab  initrd.gz
EOF
then
  exit 1
fi

# Create a 50 MiB sparse backing file, format it as LUKS with passphrase "abc"
# and a fixed UUID, then export its volume key to volume-key.bin so that the
# guest memory dump can later be searched for it.
passphrase=abc
dd if=/dev/zero of=disk.img bs=1 seek=50M count=0
printf '%s' "$passphrase" \
  | cryptsetup luksFormat --batch-mode --uuid 55555555-5555-5555-5555-555555555555 ./disk.img
rm -f -- volume-key.bin
printf '%s' "$passphrase" \
  | cryptsetup luksDump --dump-volume-key --volume-key-file ./volume-key.bin ./disk.img

# Start from a clean slate. A `finished` marker left over from a previous run
# would satisfy the completion check below immediately, and QEMU would be shut
# down while the new memory dump is still being written. Removing the old serial
# log keeps stale console output from being mirrored.
rm -f -- ./finished ./after-suspend.raw ./serial.log

# Boot the guest in the background. Its serial console goes to serial.log, the
# HMP monitor listens on qmon.sock, and disk.img is attached through AHCI.
qemu-system-x86_64 \
  -enable-kvm \
  -m 2048M \
  -nographic -serial "file:./serial.log" \
  -kernel ./bzImage \
  -initrd ./initrd.gz \
  -monitor unix:./qmon.sock,server,nowait \
  -drive file=./disk.img,format=raw,if=none,id=d0 -device ahci,id=ahci -device ide-hd,drive=d0,bus=ahci.0 \
  -append "rdinit=/init init=/init console=ttyS0" &
qemu_pid=$!
# Do not leave QEMU running if the script aborts early; disarmed once it exits.
trap 'kill "$qemu_pid" 2>/dev/null || true' EXIT

# Mirror the guest serial console to our stdout until QEMU exits.
tail -fF --pid="$qemu_pid" ./serial.log &
tail_pid=$!

# Send one HMP command to the QEMU monitor and print its reply.
qmon() {
  printf '%s\n' "$1" | socat - UNIX-CONNECT:./qmon.sock
}

# Abort if QEMU itself has exited (e.g. it could not start); the polling loops
# below would otherwise wait forever.
ensure_qemu_running() {
  if ! kill -0 "$qemu_pid" 2>/dev/null; then
    echo "error: QEMU exited unexpectedly; see serial.log" >&2
    exit 1
  fi
}

# Wait until guest has entered suspended state. The reply is captured before
# matching so an early-exiting `grep -q` cannot make socat fail with EPIPE and,
# under pipefail, turn a match into a failure. Connection errors are expected
# (and silenced) while QEMU is still creating the monitor socket.
until [[ $(qmon "info status" 2>/dev/null) == *suspended* ]]; do
  ensure_qemu_running
  sleep 1
done

# Dump guest physical memory from address 0. 0x80000000 bytes == the 2048M
# given to -m above; assumes that RAM is mapped contiguously from 0.
qmon "pmemsave 0 0x80000000 after-suspend.raw"

# Relies on the monitor running commands in order: this zero-length dump
# creates `finished` only after the dump above has completed.
qmon "pmemsave 0 0 finished"
until [ -e finished ]; do
  ensure_qemu_running
  sleep 1
done

qmon "quit"
# Reap QEMU and the log mirror so their output cannot interleave with the
# results printed below. QEMU's exit status after `quit` is not of interest.
wait "$qemu_pid" || true
wait "$tail_pid" || true
trap - EXIT

# Report every offset (overlapping matches included) at which the exported
# volume key occurs in the memory dump. The dump is memory-mapped read-only
# instead of being read into a bytes object.
python3 -c '
import mmap
import sys

key_path, dump_path = sys.argv[1], sys.argv[2]

with open(key_path, "rb") as key_file:
    needle = key_file.read()
# An empty needle would "match" at every single offset of the dump.
if not needle:
    sys.exit(f"error: {key_path} is empty, nothing to search for")

with open(dump_path, "rb") as dump_file:
    with mmap.mmap(dump_file.fileno(), 0, prot=mmap.PROT_READ) as dump:
        found = False
        pos = dump.find(needle)
        while pos != -1:
            found = True
            print(f"Found volume key at position {hex(pos)}.")
            pos = dump.find(needle, pos + 1)
        if not found:
            print("Volume key not found.")
' ./volume-key.bin ./after-suspend.raw
