diff --git a/config.json b/config.json index 4222aa6..eedfe3d 100644 --- a/config.json +++ b/config.json @@ -40,4 +40,5 @@ } } ] -} \ No newline at end of file +} + diff --git a/shell.nix b/shell.nix new file mode 100644 index 0000000..0deca53 --- /dev/null +++ b/shell.nix @@ -0,0 +1,23 @@ +{ pkgs ? import {} }: + +pkgs.mkShell { + buildInputs = [ + pkgs.gcc + pkgs.python311Full + pkgs.python311Packages.virtualenv + pkgs.python311Packages.pyudev + pkgs.python311Packages.inotify-simple + pkgs.python311Packages.psutil + pkgs.python311Packages.pyudev + ]; + + shellHook = '' + if [ ! -d .venv ]; then + virtualenv .venv + source .venv/bin/activate + else + source .venv/bin/activate + fi + echo "Welcome to your Python development environment." + ''; +} diff --git a/testdata/config.json b/testdata/config.json new file mode 100644 index 0000000..41a1baa --- /dev/null +++ b/testdata/config.json @@ -0,0 +1,50 @@ +{ + "usb": { + "hotplug_rules": { + "denylist": { + "0xbadb" : ["0xdada"], + "~0xbabb" : ["0xcaca"] + }, + + "allowlist" : { + "0x0b95:0x1790" : ["net-vm"] + }, + + "classlist" : { + "0x01:*:*" : ["audio-vm"], + "0x03:*:0x01" : ["gui-vm"], + "0x03:*:0x02" : ["gui-vm"], + "0x08:0x06:*" : ["gui-vm"], + "0x0b:*:*" : ["gui-vm"], + "0x11:*:*" : ["gui-vm"], + "0xe0:0x01:0x01" : ["gui-vm"], + "0x02:06:*" : ["net-vm"], + "0x0e:*:*" : ["chrome-vm"] + } + }, + "static_devices" : [ + { + "hostbus":null, + "hostport":null, + "name":"crazyradio1", + "productId":"0101", + "vendorId":"1915", + "vmUdevExtraRule":null, + "vms":["gui-vm"] + }, + { + "hostbus":null, + "hostport":null, + "name":"crazyflie0", + "productId": "5740", + "vendorId":"0483", + "vmUdevExtraRule":null, + "vms":["test-vm"] + } + ] + }, + "eventDevices": { + "targetVM": "gui-vm", + "pcieBusPrefix":"rp" + } +} diff --git a/vhotplug/device.py b/vhotplug/device.py index c4296e1..7daeb0f 100644 --- a/vhotplug/device.py +++ b/vhotplug/device.py @@ -2,6 +2,9 @@ import fcntl import struct import psutil +import os +import sys +import time from vhotplug.qemulink import * EVIOCGRAB = 0x40044590 @@ -9,6 +12,20 @@ logger = logging.getLogger("vhotplug") +def wait_target_vm(qmp_socket, timeout=30, interval=3): + start = time.time() + while True: + if os.path.exists(qmp_socket): + logger.debug(f" Found qmpSocket {qmp_socket} ...") + break + if time.time() - start > timeout: + logger.debug(f"Timeout! qmpSocket {qmp_socket} not found.") + return True + logger.debug(f"Waiting for qmpSocket {qmp_socket} ...") + time.sleep(interval) + return False + + def log_device(device, level=logging.DEBUG): try: logger.log(level, f"Device path: {device.device_path}") @@ -116,6 +133,9 @@ async def attach_usb_device(context, config, device, use_vid_pid): vm_name = vm.get("name") qmp_socket = vm.get("qmpSocket") logger.info(f"Attaching to {vm_name} ({qmp_socket})") + if wait_target_vm(qmp_socket): + logger.warning(f"VM:{vm_name} timeout! Couldn't retrieve {qmp_socket}") + return if is_boot_device(context, device): logger.info(f"USB drive {device.device_node} is used as a boot device, skipping") return @@ -143,6 +163,9 @@ async def remove_usb_device(config, device): async def attach_evdev_device(vm, busprefix, pcieport, device): vm_name = vm.get("name") qmp_socket = vm.get("qmpSocket") + if wait_target_vm(qmp_socket): + logger.warning(f"VM:{vm_name} timeout! Couldn't retrieve {qmp_socket}") + return bus = f"{busprefix}{pcieport}" logger.info(f"Attaching evdev device to {vm_name} ({qmp_socket}) on bus {bus}") qemu = QEMULink(qmp_socket) diff --git a/vhotplug/ghaf_dynamic_policy.py b/vhotplug/ghaf_dynamic_policy.py new file mode 100644 index 0000000..0431ddc --- /dev/null +++ b/vhotplug/ghaf_dynamic_policy.py @@ -0,0 +1,125 @@ +import subprocess +import logging +import socket +import time +import json +import os + +logger = logging.getLogger("vhotplug") + +class GhafDynamicPolicy: + def __init__(self, admin_name, admin_addr, admin_port, policy_query, givc_cli, cert = None, key = None, cacert = None): + + self.policy_json = None + self.admin_name = admin_name + self.admin_addr = admin_addr + self.admin_port = str(admin_port) + if cert is not None: + if not os.path.exists(cert): + raise FileNotFoundError(f"File {cert} does not exist.") + if not os.path.exists(key): + raise FileNotFoundError(f"File {key} does not exist.") + if not os.path.exists(cacert): + raise FileNotFoundError(f"File {cacert} does not exist.") + if not os.path.exists(givc_cli): + raise FileNotFoundError(f"File {givc_cli} does not exist.") + self.policy_query_cmd = [ + givc_cli, + "--cert", cert, + "--key", key, + "--cacert", cacert, + "--name", self.admin_name, + "--addr", self.admin_addr, + "--port", self.admin_port, + "policy-query", f"{policy_query}" + ] + else: + self.policy_query_cmd = [ + givc_cli, + "--notls", + "--name", self.admin_name, + "--addr", self.admin_addr, + "--port", self.admin_port, + "policy-query", f"{policy_query}" + ] + + + def __remove_comments(self, json_as_string): + result = "" + for line in json_as_string.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith('#'): + continue + # Remove inline comment + code_part = line.split('#', 1)[0].rstrip() + if code_part: + result += code_part + "\n" + return result + + def __wait_for_admin(self, timeout=60, interval=2): + logger.info("Waiting for admin vm to become reachable...") + + end_time = time.time() + timeout + while time.time() < end_time: + try: + with socket.create_connection((self.admin_addr, self.admin_port), timeout=3): + logger.info(f"Admin vm [{self.admin_addr}:{self.admin_port}] is reachable.") + return True + except (socket.timeout, ConnectionRefusedError, OSError): + logger.info(f"Admin vm [{self.admin_addr}:{self.admin_port}] is still not reachable.") + time.sleep(interval) + + logger.error(f"Admin vm [{self.admin_addr}:{self.admin_port}] is not reachable. Timed out after {timeout} seconds!") + return False + + def __fetch_hotplug_policy(self): + if self.__wait_for_admin() == None: + return None + try: + result = subprocess.run( + self.policy_query_cmd, + capture_output=True, + text=True, + check=True, + encoding='utf-8' + ) + output_string = result.stdout.strip() + logger.debug(f"Raw USB Hotplug Policy received:\n{output_string}") + + if not output_string: + logger.error("Error: Policy fetcher command returned empty output.") + return None + + try: + outer = json.loads(output_string) + inner = None + if isinstance(outer, str): + inner = json.loads(outer) + else: + inner = outer + + if isinstance(inner, dict) and "result" in inner: + self.policy_json = inner["result"] + else: + logger.error("Policy fetcher command returned unexpected output.") + return None + return self.policy_json + + except json.JSONDecodeError as e: + logger.error(f"Failed to decode JSON from command output. JSONDecodeError: {e}") + logger.error(f"Raw output was:\n---\n{output_string}\n---") + return None + + except subprocess.CalledProcessError as e: + logger.error(f"Command execution failed with exit code {e.returncode}.") + logger.error(f"Stderr:\n---\n{result.stderr}\n---") + return None + + def get_policy(self): + if self.policy_json == None: + return self.__fetch_hotplug_policy() + else: + return self.policy_json + + def reload_policy(self): + self.policy_json = None diff --git a/vhotplug/ghaf_policy.py b/vhotplug/ghaf_policy.py new file mode 100644 index 0000000..83d27e2 --- /dev/null +++ b/vhotplug/ghaf_policy.py @@ -0,0 +1,307 @@ +import json +import logging +import os +import threading +import pprint +from vhotplug.device import * + +logger = logging.getLogger("vhotplug") + +class GhafPolicy: + def __init__(self, policy_path): + self.lock = threading.Lock() + self.extra_allowed = None; + with open(policy_path, 'r') as file: + json_data = json.load(file) + usb_rules = json_data["usb"] + self.evdev_hotplug_rules = json_data["eventDevices"] + + self.usb_hotplug_rules = usb_rules.get("hotplug_rules", {}) + self.usb_extra_devices = usb_rules.get("static_devices", []); + logger.debug(f"{self.evdev_hotplug_rules} \n{self.usb_hotplug_rules} \n{self.usb_extra_devices}") + self.denylist = self.usb_hotplug_rules.get("denylist", {}) + self.allowlist = self.usb_hotplug_rules.get("allowlist", {}) + self.class_rules = self.usb_hotplug_rules.get("classlist", {}) + self.allow_static_devices() + self.vm_list = None + + def allow_static_devices(self, force = False): + for device in self.usb_extra_devices: + vendor = device.get("vendorId", None) + product = device.get("productId", None) + vms = device.get("vms", None) + if vendor is not None and product is not None and vms is not None: + vendor_product = f"0x{vendor}:0x{product}" + + if force == True: + self.allowlist[vendor_product] = vms + elif vendor_product not in self.allowlist: + self.allowlist[vendor_product] = vms + else: + logger.info(f"Product is already in allowlist: {vendor_product} allowed VMs are {self.allowlist[vendor_product]}") + + def update_policy(self, policy): + force_static_devices = False + with self.lock: + self.usb_hotplug_rules = policy + if "denylist" in self.usb_hotplug_rules: + self.denylist = self.usb_hotplug_rules.get("denylist", self.denylist) + + if "allowlist" in self.usb_hotplug_rules: + self.allowlist = self.usb_hotplug_rules.get("allowlist", self.allowlist) + + if "classlist" in self.usb_hotplug_rules: + self.class_rules = self.usb_hotplug_rules.get("classlist", self.class_rules) + + if "static_devices" in self.usb_hotplug_rules: + self.usb_extra_devices = self.usb_hotplug_rules.get("static_devices") + force_static_devices = True + + self.allow_static_devices(force_static_devices) + self.vm_list = None + + def vm_for_evdev_devices(self): + vm = {} + busPrefix = None + if "pcieBusPrefix" in self.evdev_hotplug_rules and "targetVM" in self.evdev_hotplug_rules: + vm_name = self.evdev_hotplug_rules["targetVM"] + busPrefix = self.evdev_hotplug_rules["pcieBusPrefix"] + vm["name"] = vm_name + vm["qmpSocket"] = f"/var/lib/microvms/{vm_name}/{vm_name}.sock" + + return vm, busPrefix + + def vm_for_usb_device(self, vid, pid, vendor_name, product_name, interfaces): + with self.lock: + try: + logger.info(f"Searching for a VM for {vid}:{pid}, {vendor_name}:{product_name}") + usb_interfaces = parse_usb_interfaces(interfaces) + for interface in usb_interfaces: + device_class = interface["class"] + subclass = interface["subclass"] + protocol = interface["protocol"] + vendor = f"0x{vid}".lower() + product = f"0x{pid}".lower() + dclass = f"{device_class:#04x}".lower() + sclass = f"{subclass:#04x}".lower() + protoc = f"{protocol:#04x}".lower() + vms = self.get_allowed_vms(dclass, sclass, protoc, vendor, product) + vm = {} + if len(vms): + if len(vms) > 1: + logger.warning(f"More than one VM can access this device. Passing through to vm: {vms[0]}.") + vm["name"] = vms[0] + vm["qmpSocket"] = f"/var/lib/microvms/{vms[0]}/{vms[0]}.sock" + return vm + else: + return None + except Exception as e: + logger.error(f"Failed to find VM for USB device in the configuration file: {e}") + return None + + def get_all_vms(self): + if self.vm_list is not None: + return self.vm_list + self.vm_list = [] + vms_by_name = [] + if "targetVM" in self.evdev_hotplug_rules: + vmname = self.evdev_hotplug_rules["targetVM"] + vms_by_name.append(vmname) + for _, vms in self.allowlist.items(): + for vm in vms: + if vm not in vms_by_name: + vms_by_name.append(vm) + + for _, vms in self.class_rules.items(): + for vm in vms: + if vm not in vms_by_name: + vms_by_name.append(vm) + + for vmname in vms_by_name: + self.vm_list.append( {"name": vmname, + "qmpSocket":f"/var/lib/microvms/{vmname}/{vmname}.sock"}) + return self.vm_list + + def lookup(self, allowlist: dict, key: any) -> list: + return allowlist.get(key, []) + + def not_allowed(self, vendor_id: any, product_id: any) -> bool: + disallowed_products = self.denylist.get(vendor_id) + if disallowed_products is not None: + return product_id in disallowed_products + else: + neg_vendor = f"~{vendor_id}" + allowed_products = self.denylist.get(neg_vendor) + if allowed_products is not None: + return product_id not in allowed_products + else: + return False + + def get_allowed_vms(self, device_class: int, subclass: int, protocol: int, vendor_id: int, product_id: int): + # Check if the device is not allowed + if self.not_allowed(vendor_id, product_id): + return [] + + # Check if the device is mapped to a specific VM + device_key_0 = f"{vendor_id}:{product_id}" + device_key_1 = f"{vendor_id}:*" + vms_by_device = self.lookup(self.allowlist, device_key_0) + self.lookup(self.allowlist, device_key_1) + if len(vms_by_device) > 0: + unique_vms_by_device = list(dict.fromkeys(vms_by_device)) + return unique_vms_by_device + + # Based on class, subclass, and protocol find list VMs which can access it + class_key_0 = f"{device_class}:{subclass}:{protocol}" + class_key_1 = f"{device_class}:{subclass}:*" + class_key_2 = f"{device_class}:*:{protocol}" + class_key_3 = f"{device_class}:*:*" + cl_01_vms = self.lookup(self.class_rules, class_key_0) + self.lookup(self.class_rules, class_key_1) + cl_23_vms = self.lookup(self.class_rules, class_key_2) + self.lookup(self.class_rules, class_key_3) + vms_by_class = cl_01_vms + cl_23_vms + + # Merge VMs from all above rules + if len(vms_by_class) > 0: + unique_vms_by_class = list(dict.fromkeys(vms_by_class)) + else: + unique_vms_by_class = [] + + return unique_vms_by_class + + +############TESTS############### +class UnitTest: + def __init__(self): + current_dir = os.path.dirname(os.path.abspath(__file__)) + static_policy = os.path.join(current_dir, '../testdata/', 'config.json') + self.policy = GhafPolicy(static_policy) + + def print_vmlist(self): + pprint.pprint("VM LIST:") + pprint.pprint(self.policy.get_all_vms()) + res = self.policy.vm_for_evdev_devices() + vm = res[0] + prefix = res[1] + print("\nEVDEV Passthrough:") + pprint.pprint("VM:") + pprint.pprint(vm) + pprint.pprint("PCI Bus Prefix:") + pprint.pprint(prefix) + print("\n") + + + def compare_results(self, list1, list2): + if len(list1) == len(list2): + for elm in list1: + if elm not in list2: + return "❌ FAIL" + return "✅ PASS" + return "❌ FAIL" + + def remove_comments(self, json_as_string): + result = "" + for line in json_as_string.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith('#'): + continue + # Remove inline comment + code_part = line.split('#', 1)[0].rstrip() + if code_part: + result += code_part + "\n" + return result + + def run_test(self, test_id, device_class, subclass, vendor_id, product_id, protocol, expected_vms): + vms = self.policy.get_allowed_vms( + device_class=device_class, + subclass=subclass, + vendor_id=vendor_id, + product_id=product_id, + protocol=protocol + ) + result = self.compare_results(expected_vms, vms) + print(f"{test_id}: expected: {str(expected_vms):<30} received: {str(vms):<30} Result: {result}") + + +if __name__ == "__main__": + # To run this unittest comment this line 'from vhotplug.device import *' + unittest = UnitTest() + unittest.print_vmlist() + + unittest.run_test( + test_id="TEST1", + device_class="0xff", + subclass="0x01", + vendor_id="0x0b95", + product_id="0x1790", + protocol=0, + expected_vms=['net-vm'] + ) + + unittest.run_test( + test_id="TEST2", + device_class="0x01", + subclass="0x02", + vendor_id="0xdead", + product_id="0xbeef", + protocol="0x01", + expected_vms=['audio-vm'] + ) + + unittest.run_test( + test_id="TEST3", + device_class="0x0e", + subclass="0x02", + vendor_id="0x04f2", + product_id="0xb751", + protocol="0x01", + expected_vms=["chrome-vm"] + ) + + unittest.run_test( + test_id="TEST4", + device_class="0x0e", + subclass="0x02", + vendor_id="0x04f2", + product_id="0xb755", + protocol="0x01", + expected_vms=["chrome-vm"] + ) + + unittest.run_test( + test_id="TEST5", + device_class="0xe0", + subclass="0x01", + vendor_id="0x04f2", + product_id="0xb755", + protocol="0x01", + expected_vms=["gui-vm"] + ) + + unittest.run_test( + test_id="TEST6", + device_class="0xe0", + subclass="0x01", + vendor_id="0xbadb", + product_id="0xdada", + protocol="0x01", + expected_vms=[] + ) + + unittest.run_test( + test_id="TEST7", + device_class="0xe0", + subclass="0x01", + vendor_id="0xbabb", + product_id="0xcaca", + protocol="0x01", + expected_vms=["gui-vm"] + ) + + unittest.run_test( + test_id="TEST8", + device_class="0xe0", + subclass="0x01", + vendor_id="0xbabb", + product_id="0xb755", + protocol="0x01", + expected_vms=[] + ) diff --git a/vhotplug/vhotplug.py b/vhotplug/vhotplug.py index 6e290fe..9bdd7ca 100644 --- a/vhotplug/vhotplug.py +++ b/vhotplug/vhotplug.py @@ -7,6 +7,8 @@ from vhotplug.device import * from vhotplug.config import * from vhotplug.filewatcher import * +from vhotplug.ghaf_policy import GhafPolicy +from vhotplug.ghaf_dynamic_policy import GhafDynamicPolicy logger = logging.getLogger("vhotplug") @@ -33,13 +35,67 @@ async def device_event(context, config, device): if device.subsystem == 'power_supply': logger.info(f"Power supply device {device.sys_name} changed, this may indicate a system resume") +def handle_config(args): + if not os.path.exists(args.config): + logger.error(f"Configuration file {args.config} not found") + raise FileNotFoundError(f"The {args.config} file was not found.") + + if args.opa and args.config != "": + logger.error(f"Ghaf Policy and/or OPA can not be enabled with config.") + raise ValueError("Ghaf Policy and/or OPA can not be enabled with config.") + + config = Config(args.config) + return config + + +def handle_policy(args): + if not os.path.exists(args.policy): + logger.error(f"Policy file {args.policy} not found") + raise FileNotFoundError(f"The {args.config} file was not found.") + + policy = GhafPolicy(args.policy) + + if args.opa: + if not args.policy_query or not args.admin_addr: + parser.error("--policy-query and --admin-addr must be specified when --opa is enabled.") + raise ValueError("--policy-query and --admin-addr must be specified when --opa is enabled.") + if args.notls: + dynamic_policy = GhafDynamicPolicy( + admin_name = args.admin_name, + admin_addr = args.admin_addr, + admin_port = args.admin_port, + policy_query = args.policy_query, + givc_cli = "/run/current-system/sw/bin/givc-cli"); + else: + dynamic_policy = GhafDynamicPolicy( + admin_name = args.admin_name, + admin_addr = args.admin_addr, + admin_port = args.admin_port, + policy_query = args.policy_query, + givc_cli = "/run/current-system/sw/bin/givc-cli", + cert = "/etc/givc/cert.pem", + key = "/etc/givc/key.pem", + cacert = "/etc/givc/ca-cert.pem"); + policy.update_policy(dynamic_policy.get_policy()) + + return policy + async def async_main(): parser = argparse.ArgumentParser(description="Hot-plugging USB devices to the virtual machines") - parser.add_argument("-c", "--config", type=str, required=True, help="Path to the configuration file") + parser.add_argument("-c", "--config", type=str, default="", help="Path to the configuration file") + parser.add_argument("-p", "--policy", type=str, default="", help="Path to policy file") parser.add_argument("-a", "--attach-connected", default=False, action=argparse.BooleanOptionalAction, help="Attach connected devices on startup") parser.add_argument("-d", "--debug", default=False, action=argparse.BooleanOptionalAction, help="Enable debug messages") + parser.add_argument("--opa", action='store_true', help="Pull OPA policy for USB hotplus") + parser.add_argument("--notls", action='store_true', help="Dosable TLS for givc communication") + parser.add_argument("--admin-name", type=str, default="admin-vm", help="Name of Admin vm") + parser.add_argument("--admin-addr", type=str, default="", help="Address of admin-vm") + parser.add_argument("--admin-port", type=int, default=9001, help="Port of admin-vm") + parser.add_argument("--policy-query", type=str, default="", help="Policy query to send to admin-vm") + args = parser.parse_args() + vhotplugrules = None handler = logging.StreamHandler() handler.setFormatter(logging.Formatter("%(levelname)s %(message)s")) logger.addHandler(handler) @@ -48,20 +104,19 @@ async def async_main(): else: logger.setLevel(logging.INFO) - if not os.path.exists(args.config): - logger.error(f"Configuration file {args.config} not found") - return - - config = Config(args.config) + if args.config != "": + vhotplugrules = handle_config(args) + else: + vhotplugrules = handle_policy(args) context = pyudev.Context() if args.attach_connected: - await attach_connected_devices(context, config) + await attach_connected_devices(context, vhotplugrules) monitor = pyudev.Monitor.from_netlink(context) watcher = FileWatcher() - for vm in config.get_all_vms(): + for vm in vhotplugrules.get_all_vms(): qmp_socket = vm.get("qmpSocket") watcher.add_file(qmp_socket) @@ -70,9 +125,9 @@ async def async_main(): while True: device = monitor.poll(timeout=1) if device != None: - await device_event(context, config, device) + await device_event(context, vhotplugrules, device) if watcher.detect_restart() == True and args.attach_connected: - await attach_connected_devices(context, config) + await attach_connected_devices(context, vhotplugrules) except KeyboardInterrupt: logger.info("Ctrl+C")