From 87d4973d4890eeef184bd43967d8495463ff705f Mon Sep 17 00:00:00 2001 From: Liam Malone Date: Wed, 25 Feb 2026 18:01:39 +0000 Subject: [PATCH] init --- .gitignore | 2 + README.md | 63 ++ build.zig | 50 ++ build.zig.zon | 10 + flake.lock | 113 ++++ flake.nix | 23 + src/generator.zig | 1397 +++++++++++++++++++++++++++++++++++++++++++++ src/xml.zig | 640 +++++++++++++++++++++ 8 files changed, 2298 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 build.zig create mode 100644 build.zig.zon create mode 100644 flake.lock create mode 100644 flake.nix create mode 100644 src/generator.zig create mode 100644 src/xml.zig diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..77d4f6e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/zig-out/ +/.zig-cache/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..e932c1d --- /dev/null +++ b/README.md @@ -0,0 +1,63 @@ +# Wayland Protocol Code Generator + +Code generation for Wayland protocols, for use without libwayland. +Generated code is NOT ABI compatible with libwayland. + +I wrote this to use in my own projects to avoid the callback-heavy +interface provided by libwayland and the standard wayland-scanner. + +This generator currently only supports generating Zig code, but I plan +to add an option to emit the code as a single-header C library too. + + +## Usage + +### CLI + +The binary can be produced by invoking `zig build` in the project root +and `protocols.zig` can be generated by running the program as below: + +``` +# The program can be run with as many input protocols as you'd like +$ ./zig-out/bin/wayland-protocol-generator -o protocols.zig path/to/wayland.xml path/to/protocol1.xml path/to/protocolN.xml +``` + +This will read in all provided xml files and produce a single `protocols.zig` +which will include the code for each protocol. + +The core wayland protocol can be found at +https://gitlab.freedesktop.org/wayland/wayland and additional protocol +specifications can be found at +https://gitlab.freedesktop.org/wayland/wayland-protocols + + +### Generation Via Build.zig + +You can use the zig build system to generate `protocols.zig` and expose it +as a module as follows: + +Add this repo as a dependency. You can do this manually or by invoking: + +```shell +$ zig fetch --save git+https://github.com/ptrToLiam/wayland-protocol-codegen +``` + +Then add some lines such as the following to your `build.zig`: + +```zig +const wayland_protocol_specifications = [_]std.Build.LazyPath{ + b.path("path/to/wayland.xml"), + b.path("path/to/protocol1.xml"), + b.path("path/to/protocolN.xml"), + ..., +}; + +const wayland_protocols = b.dependency("wayland_zig", .{ + .protocols = &wayland_protocol_specifications, +}).module("wayland-protocols"); +exe.root_module.addImport("wayland-protocols", wayland_protocols); +``` + +This will allow you to import the protocol code with +`@import("wayland-protocols")` in your executable module's source code. + diff --git a/build.zig b/build.zig new file mode 100644 index 0000000..8be0019 --- /dev/null +++ b/build.zig @@ -0,0 +1,50 @@ +pub fn build(b: *std.Build) !void { + const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{}); + + const protocols_opt = b.option( + []std.Build.LazyPath, + "protocols", + "Paths to desired wayland protocols (e.g. wayland.xml, xdg-shell.xml, etc.)" + ); + + const debug_opt = b.option( + bool, + "debug", + "run generator with debug logging" + ) orelse false; + + const root = b.createModule(.{ + .root_source_file = b.path("src/generator.zig"), + .target = target, + .optimize = optimize + }); + + const generator = b.addExecutable(.{ + .name = "wayland-protocol-codegen", + .root_module = root, + }); + b.installArtifact(generator); + + if (protocols_opt) |protocols| { + const wl_generate_cmd = b.addRunArtifact(generator); + + for (protocols) |protocol| { + wl_generate_cmd.addFileArg(protocol); + } + + wl_generate_cmd.addArg("-o"); + + if (debug_opt) wl_generate_cmd.addArg("--debug"); + + const protocols_zig = wl_generate_cmd.addOutputFileArg("protocols.zig"); + + const protocols_zig_module = b.addModule("wayland-protocols", .{ + .root_source_file = protocols_zig, + }); + + _ = protocols_zig_module; + } +} + +const std = @import("std"); diff --git a/build.zig.zon b/build.zig.zon new file mode 100644 index 0000000..7a00174 --- /dev/null +++ b/build.zig.zon @@ -0,0 +1,10 @@ +.{ + .name = .wayland_protocol_codegen, + .version = "0.0.0", + .fingerprint = 0xe5b43b9ebc061f42, + .minimum_zig_version = "0.16.0-dev.2261+d6b3dd25a", + .paths = .{ + "build.zig", + "src", + }, +} diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..60fca9c --- /dev/null +++ b/flake.lock @@ -0,0 +1,113 @@ +{ + "nodes": { + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1705309234, + "narHash": "sha256-uNRRNRKmJyCRC/8y1RqBkqWBLM034y4qN7EprSdmgyA=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "1ef2e671c3b0c19053962c07dbda38332dcebf26", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1769170682, + "narHash": "sha256-oMmN1lVQU0F0W2k6OI3bgdzp2YOHWYUAw79qzDSjenU=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "c5296fdd05cfa2c187990dd909864da9658df755", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1708161998, + "narHash": "sha256-6KnemmUorCvlcAvGziFosAVkrlWZGIc6UNT9GUYr0jQ=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "84d981bae8b5e783b3b548de505b22880559515f", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-23.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "nixpkgs": "nixpkgs", + "zig": "zig" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "zig": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs_2" + }, + "locked": { + "lastModified": 1770165742, + "narHash": "sha256-q7nu9F2ZtmIU9BMarY8BLIEfiwPwrjXGXLvOdFSvQP4=", + "owner": "mitchellh", + "repo": "zig-overlay", + "rev": "364ec9318bc70fc1ec084cb41e80e90d6770dd2b", + "type": "github" + }, + "original": { + "owner": "mitchellh", + "repo": "zig-overlay", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..9b80c21 --- /dev/null +++ b/flake.nix @@ -0,0 +1,23 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + zig.url = "github:mitchellh/zig-overlay"; + }; + + outputs = { self, nixpkgs, zig }: + let + supportedSystems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ]; + forEachSupportedSystem = f: nixpkgs.lib.genAttrs supportedSystems (system: f { + pkgs = import nixpkgs { inherit system; }; + }); + in + { + devShells = forEachSupportedSystem({ pkgs }: { + default = pkgs.mkShell { + packages = with pkgs; [ + zig.packages.${system}."master-2026-02-03" + ]; + }; + }); + }; +} diff --git a/src/generator.zig b/src/generator.zig new file mode 100644 index 0000000..0e967ea --- /dev/null +++ b/src/generator.zig @@ -0,0 +1,1397 @@ +pub fn main(init: std.process.Init) !void { + const allocator = init.arena.allocator(); + var threaded: Io.Threaded = .init(allocator, .{ .environ = init.minimal.environ }); + defer threaded.deinit(); + const io = threaded.io(); + var args = try init.minimal.args.iterateAllocator(allocator); + + const program_name = args.next() orelse "wayland-protocol-codegen"; + const cwd = Io.Dir.cwd(); + + var debug = false; + var out_path_opt: ?[]const u8 = null; + var first_protocol_file_opt: ?*EntryNode = null; + var last_protocol_file_opt: ?*EntryNode = null; + var protocol_count: u32 = 0; + + var arg_count: u16 = 0; + while (args.next()) |arg| : (arg_count += 1) { + if (std.mem.eql(u8, arg, "--help") or std.mem.eql(u8, arg, "-h")) { + var backing: [1024]u8 = undefined; + var stdout = Io.File.stdout().writer(io, &backing); + try stdout.interface.print(UsageMsgFmt, .{program_name}); + } else if (std.mem.eql(u8, arg, "-o") or std.mem.eql(u8, arg, "--out")) { + out_path_opt = args.next(); + } else if (std.mem.eql(u8, arg, "--debug")) { + debug = true; + } else { + const protocol_path = try allocator.create(EntryNode); + protocol_path.* = .{ + .type = .file_name, + .name = arg, + }; + sll_push_end( + protocol_path, + &first_protocol_file_opt, + &last_protocol_file_opt, + &protocol_count, + ); + } + } + + if (arg_count == 0) { + var backing: [1024]u8 = undefined; + var stdout = Io.File.stdout().writer(io, &backing); + try stdout.interface.print(UsageMsgFmt, .{program_name}); + } + + var cur_protocol_file_opt: ?*EntryNode = first_protocol_file_opt; + var output: Output = .{}; + while (cur_protocol_file_opt) |cur_protocol_file| { + cur_protocol_file_opt = cur_protocol_file.next; + if (debug) std.debug.print( + "reading protocol file :: {s}\n", + .{ cur_protocol_file.name }, + ); + try generate_zig_code_tree( + io, + allocator, + &output, + cur_protocol_file.name, + ); + } + + if (debug) std.debug.print("found {} protocols!\n", .{output.protocol_count}); + + const out_file = if (out_path_opt) |out_path| + try cwd.createFile( + io, + out_path, + .{}, + ) + else return error.FileNotFound; + var buf: [1024]u8 = undefined; + var out = out_file.writerStreaming(io, &buf); + var out_writer = &out.interface; + + var protocol_opt: ?*EntryNode = output.protocol_first; + try out_writer.print(OutputBeginMsg, .{}); + _ = try out_writer.write(WaylandGeneralTypesCodePaste); + while (protocol_opt) |protocol| : (protocol_opt = protocol.next) { + if (debug) std.debug.print( + "found {} interfaces in protocol {s}!\n", + .{ + protocol.interface_count, + protocol.name, + }, + ); + + try out_writer.print( + ProtocolBeginFmt, + .{ protocol.name }, + ); + + var interface_opt: ?*EntryNode = protocol.interface_first; + while (interface_opt) |interface| : (interface_opt = interface.next) { + try write_wl_interface( + out_writer, + allocator, + interface, + ); + } + + try out_writer.print( + ProtocolEndString, + .{}, + ); + } + + // Write Combined Enum Union + { + _ = try out_writer.write(CombinedEnumBeginMsg); + protocol_opt = output.protocol_first; + while (protocol_opt) |protocol| : (protocol_opt = protocol.next) { + var interface_opt: ?*EntryNode = protocol.interface_first; + while (interface_opt) |interface| : (interface_opt = interface.next) { + var enum_opt: ?*EntryNode = interface.enum_first; + while (enum_opt) |@"enum"| : (enum_opt = @"enum".next) { + try out_writer.print( + CombinedEnumEntryFmt, + .{ interface.name, @"enum".name, interface.name, @"enum".identifier }, + ); + } + } + } + _ = try out_writer.write(CombinedEnumEndMsg); + } + + // Write Combined Event Union + { + _ = try out_writer.write(CombinedEventBeginMsg); + protocol_opt = output.protocol_first; + while (protocol_opt) |protocol| : (protocol_opt = protocol.next) { + var interface_opt: ?*EntryNode = protocol.interface_first; + while (interface_opt) |interface| : (interface_opt = interface.next) { + var event_opt: ?*EntryNode = interface.event_first; + while (event_opt) |event| : (event_opt = event.next) { + try out_writer.print( + CombinedEventEntryFmt, + .{interface.name, event.name, interface.name, event.identifier}, + ); + } + } + } + _ = try out_writer.write(CombinedEventEndMsg); + } + + // Write Combined Object Union + { + _ = try out_writer.write(CombinedInterfaceBeginMsg); + protocol_opt = output.protocol_first; + while (protocol_opt) |protocol| : (protocol_opt = protocol.next) { + var interface_opt: ?*EntryNode = protocol.interface_first; + while (interface_opt) |interface| : (interface_opt = interface.next) { + try out_writer.print( + CombinedInterfaceEntryFmt, + .{ interface.name, interface.identifier }, + ); + } + } + + _ = try out_writer.write( + \\ + \\ pub fn message_decode(o: Object, proxy: *Proxy, op: u16, data: []const u8) Event { + \\ return switch (o) { + \\ + ); + protocol_opt = output.protocol_first; + while (protocol_opt) |protocol| : (protocol_opt = protocol.next) { + var interface_opt: ?*EntryNode = protocol.interface_first; + while (interface_opt) |interface| : (interface_opt = interface.next) { + if (interface.event_count > 0) + try out_writer.print(" .{s} => |interface_t| @TypeOf(interface_t).message_decode(proxy, op, data),\n", .{ interface.name }); + } + } + _ = try out_writer.write( + \\ else => .invalid, + \\ }; + \\ } + \\ + ); + _ = try out_writer.write(CombinedInterfaceEndMsg); + } + + // add import std for scoped log + _ = try out_writer.write(LogPaste); + + try out.flush(); +} + +fn write_wl_interface( + writer: *Io.Writer, + allocator: std.mem.Allocator, + wl_interface: *EntryNode, +) !void { + try writer.print( + InterfaceBeginFmt, + .{ wl_interface.name, wl_interface.name, wl_interface.name, wl_interface.name, wl_interface.name }, + ); + + // write requests / events + { + try writer.print( + InterfaceBeginSectionFmt, + .{ wl_interface.name, "MESSAGES" }, + ); + + var message_opt: ?*EntryNode = wl_interface.request_first; + while (message_opt) |wl_request| : (message_opt = wl_request.next) { + try write_wl_message(writer, allocator, wl_interface, wl_request); + } + + + if (wl_interface.event_count > 0) { + + message_opt = wl_interface.event_first; + while (message_opt) |wl_event| : (message_opt = wl_event.next) { + try write_wl_message(writer, allocator, wl_interface, wl_event); + } + + // write event parse + _ = try writer.write(MessageDecodeBeginMsg); + message_opt = wl_interface.event_first; + var opcode: u16 = 0; + while (message_opt) |wl_event| : (message_opt = wl_event.next) { + defer opcode += 1; + try writer.print(" {d} => {{\n", .{opcode}); + try write_wl_message_decode(writer, allocator, wl_interface.name, wl_event); + try writer.print(" }},\n", .{}); + } + _ = try writer.write(MessageDecodeEndMsg); + } + + try writer.print( + InterfaceEndSectionFmt, + .{}, + ); + } + + // write enums / bitfields + if (wl_interface.enum_count > 0) { + try writer.print( + InterfaceBeginSectionFmt, + .{ wl_interface.name, "ENUMS" }, + ); + + var enum_opt: ?*EntryNode = wl_interface.enum_first; + while (enum_opt) |wl_enum| : (enum_opt = wl_enum.next) { + try write_wl_enum(writer, wl_enum); + } + + try writer.print( + InterfaceEndSectionFmt, + .{}, + ); + } + + try writer.print( + InterfaceEndFmt, + .{ wl_interface.name, wl_interface.version.? }, + ); +} + +fn write_wl_message( + writer: *Io.Writer, + arena: std.mem.Allocator, + wl_interface: *EntryNode, + wl_message: *EntryNode, +) !void { + _ = arena; + switch (wl_message.type) { + .request => { + try writer.print( + ClientRequestBeginFmt, + .{ wl_message.identifier, wl_interface.identifier }, + ); + if (wl_message.arg_count > 0) { + try write_wl_args( + writer, + wl_message, + ); + } + try writer.print( + ClientRequestArgsEndFmt, + .{ wl_message.arg_type.?, wl_message.opcode }, + ); + + // request body + { + if (!std.mem.eql(u8, wl_message.arg_type.?, "void")) { + try writer.print( + ClientRequestObjectCreateFmt, + .{ wl_message.arg_type.? }, + ); + if (std.mem.eql(u8, wl_message.arg_type.?, "InterfaceT")) { + _ = try writer.write(ClientRequestInterfaceBindVersionWarn); + } + try write_wl_message_encode( + writer, + wl_message, + ); + try writer.print( + ClientRequestObjectStoreFmt, + .{ } + ); + } else { + try write_wl_message_encode( + writer, + wl_message, + ); + } + } + + try writer.print( + ClientRequestEndFmt, + .{}, + ); + }, + .event => { + if (wl_message.arg_count > 0) { + try writer.print( + ClientEventBeginFmt, + .{wl_message.identifier}, + ); + try write_wl_args( + writer, + wl_message, + ); + try writer.print( + ClientEventEndFmt, + .{}, + ); + } else { + try writer.print( + ClientEventBeginEmptyFmt, + .{wl_message.identifier}, + ); + } + }, + else => return error.NotAWlMessage, + } +} + +fn write_wl_enum( + writer: *Io.Writer, + wl_enum: *EntryNode, +) !void { + if (wl_enum.type == .bitfield) + try writer.print(BitfieldBeginFmt, .{wl_enum.identifier}) + else if (wl_enum.type == .@"enum") + try writer.print(EnumBeginFmt, .{wl_enum.identifier}) + else + return error.NotAWlEnum; + + try write_wl_args(writer, wl_enum); + + if (wl_enum.type == .bitfield) + try writer.print( + BitfieldEndFmt, + .{ 32 - wl_enum.arg_count }, + ) + else if (wl_enum.type == .@"enum") + try writer.print( + EnumEndFmt, + .{ wl_enum.identifier, wl_enum.identifier }, + ); +} + +fn write_wl_args( + writer: *Io.Writer, + wl_interface_entry: *EntryNode, +) !void { + switch (wl_interface_entry.type) { + .bitfield => { + var bitfield_entry_opt: ?*EntryNode = wl_interface_entry.arg_first; + while (bitfield_entry_opt) |entry| : (bitfield_entry_opt = entry.next) { + try writer.print(BitfieldEntryFmt, .{entry.identifier}); + } + }, + .@"enum" => { + var enum_entry_opt: ?*EntryNode = wl_interface_entry.arg_first; + while (enum_entry_opt) |entry| : (enum_entry_opt = entry.next) { + if (entry.value) |entry_value| { + try writer.print(EnumEntryValueFmt, .{entry.identifier, entry_value}); + } else { + try writer.print(EnumEntryNoValueFmt, .{entry.identifier}); + } + } + }, + .event => { + var event_arg_opt: ?*EntryNode = wl_interface_entry.arg_first; + while (event_arg_opt) |event_arg| : (event_arg_opt = event_arg.next) { + try writer.print( + ClientEventEntryFmt, + .{ event_arg.identifier, event_arg.arg_type.? }, + ); + } + }, + .request => { + var request_arg_opt: ?*EntryNode = wl_interface_entry.arg_first; + while (request_arg_opt) |request_arg| : (request_arg_opt = request_arg.next) { + if (request_arg.data_type == .new_id and request_arg.interface == null) { + try writer.print( + ClientEventEntryFmt, + .{ "InterfaceT", "type" }, + ); + try writer.print( + ClientEventEntryFmt, + .{ "version", "u32" }, + ); + } else if (request_arg.data_type != .new_id) { + try writer.print( + ClientEventEntryFmt, + .{ request_arg.identifier, request_arg.arg_type.? }, + ); + } + } + }, + .invalid, .file_name, .protocol, .interface, .arg => return error.InvalidNodeType, + } +} + +fn write_wl_message_encode( + writer: *Io.Writer, + wl_message: *EntryNode, +) !void { + _ = try writer.write( + MessageEncodeBeginMsg, + ); + const is_bind_fn = std.mem.eql(u8, "InterfaceT", wl_message.arg_type.?); + var arg_opt: ?*EntryNode = wl_message.arg_first; + while (arg_opt) |arg| : (arg_opt = arg.next) { + switch (arg.data_type) { + .uint => { + if (arg.type == .@"enum") { + try writer.print( + \\ .{{ .@"enum" = .{{ .{s} = {s} }} }}, + \\ + , .{ arg.arg_type.?, arg.identifier } + ); + } else if (!std.mem.eql(u8, arg.arg_type.?, "u32")) { + try writer.print( + " .{{ .uint = {s}.toInt() }},\n", + .{ arg.identifier } + ); + } else { + try writer.print( + " .{{ .uint = {s} }},\n", + .{ arg.identifier } + ); + } + }, + .string => { + try writer.print( + \\ .{{ .string = {s} }}, + \\ + , .{ arg.identifier } + ); + }, + .object => { + try writer.print( + \\ .{{ .object = {s}.toInt() }}, + \\ + , .{ arg.identifier } + ); + }, + .array => { + try writer.print( + \\ .{{ .array = {s} }}, + \\ + , .{ arg.identifier } + ); + }, + .int => { + try writer.print( + \\ .{{ .int = {s} }}, + \\ + , .{ arg.identifier } + ); + }, + .new_id => { + if (!is_bind_fn) { + try writer.print( + \\ .{{ .new_id = result.toInt() }}, + \\ + , .{ } + ); + } else { + try writer.print( + \\ .{{ .string = InterfaceT.Name }}, + \\ .{{ .uint = selected_version }}, + \\ .{{ .new_id = result.toInt() }}, + \\ + , .{ } + ); + } + }, + .fd => { + try writer.print( + \\ .{{ .fd = {s} }}, + \\ + , .{ arg.identifier } + ); + }, + .fixed => { + try writer.print( + \\ .{{ .fixed = {s} }}, + \\ + , .{ arg.identifier } + ); + }, + .invalid, .destructor => {}, + } + } + _ = try writer.write(MessageEncodeEndMsg); +} + +fn write_wl_message_decode( + writer: *Io.Writer, + allocator: std.mem.Allocator, + wl_interface_name: []const u8, + wl_message: *EntryNode, +) !void { + _ = allocator; + var arg_opt: ?*EntryNode = wl_message.arg_first; + _ = try writer.write(MessageDecodeArgsBeginMsg); + while (arg_opt) |arg| : (arg_opt = arg.next) { + const arg_undef = switch (arg.type) { + else => "undefined", + }; + _ = try writer.print(MessageDecodeArgsEntryFmt, .{ @tagName(arg.data_type), arg_undef }); + } + _ = try writer.write(MessageDecodeArgsEndMsg); + _ = try writer.write(" proxy.message_decode(&args_in, data);\n"); + + _ = try writer.print(" break :event .{{\n", .{}); + _ = try writer.print(" .{s}_{s} = ", .{ wl_interface_name, wl_message.name }); + if (wl_message.arg_count == 0) { + _ = try writer.print("{{\n", .{}); + } else { + _ = try writer.print(".{{\n", .{}); + arg_opt = wl_message.arg_first; + var arg_no: u16 = 0; + while (arg_opt) |arg| : (arg_opt = arg.next) { + defer arg_no += 1; + if (arg.data_type == .uint and !std.mem.eql(u8, arg.arg_type.?, "u32")) { + _ = try writer.print(" .{s} = .fromInt(args_in[{d}].{s}),\n", .{ arg.name, arg_no, @tagName(arg.data_type) }); + } else if (arg.data_type == .object and arg.interface != null) { + _ = try writer.print(" .{s} = .fromInt(args_in[{d}].{s}),\n", .{ arg.name, arg_no, @tagName(arg.data_type) }); + } else { + _ = try writer.print(" .{s} = args_in[{d}].{s},\n", .{ arg.name, arg_no, @tagName(arg.data_type) }); + } + } + } + _ = try writer.print(" }}\n", .{}); + _ = try writer.print(" }};\n", .{}); +} + +const MessageDecodeBeginMsg = +\\ +\\ pub fn message_decode( +\\ proxy: *Proxy, +\\ opcode: u16, +\\ data: []const u8 +\\ ) Event { +\\ return event: { +\\ switch (opcode) { +\\ +; +const MessageDecodeArgsBeginMsg = +\\ var args_in = [_]MessageArg{ +\\ +; +const MessageDecodeArgsEntryFmt = +\\ .{{ .{s} = {s} }}, +\\ +; +const MessageDecodeArgsEndMsg = +\\ }; +\\ +; +const MessageDecodeEndMsg = +\\ else => @panic("Invalid Opcode"), +\\ } +\\ }; +\\ } +\\ +; + +fn generate_zig_code_tree( + io: Io, + arena: std.mem.Allocator, + output: *Output, + spec_filename: []const u8, +) !void { + var local_arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); + defer local_arena.deinit(); + const local_allocator = local_arena.allocator(); + + const xml_spec_data = try Io.Dir.cwd().readFileAlloc( + io, + spec_filename, + arena, + .unlimited + ); + const spec = try Xml.parse(local_allocator, xml_spec_data); + + const protocol = try arena.create(EntryNode); + try fetch_entry_data(arena, protocol, spec.root, .protocol); + + defer sll_push_end( + protocol, + &output.protocol_first, + &output.protocol_last, + &output.protocol_count, + ); + + var spec_interfaces = spec.root.findChildrenByTag("interface"); + while (spec_interfaces.next()) |spec_interface| { + const interface = try arena.create(EntryNode); + try fetch_entry_data(arena, interface, spec_interface, .interface); + defer sll_push_end( + interface, + &protocol.interface_first, + &protocol.interface_last, + &protocol.interface_count, + ); + + var spec_interface_enums = spec_interface.findChildrenByTag("enum"); + while (spec_interface_enums.next()) |spec_interface_enum| { + const @"enum" = try arena.create(EntryNode); + try fetch_entry_data(arena, @"enum", spec_interface_enum, .@"enum"); + defer sll_push_end( + @"enum", + &interface.enum_first, + &interface.enum_last, + &interface.enum_count, + ); + + var enum_entries = spec_interface_enum.findChildrenByTag("entry"); + while (enum_entries.next()) |enum_entry| { + const entry = try arena.create(EntryNode); + try fetch_entry_data(arena, entry, enum_entry, .@"arg"); + defer sll_push_end( + entry, + &@"enum".arg_first, + &@"enum".arg_last, + &@"enum".arg_count, + ); + } + } + + var spec_interface_events = spec_interface.findChildrenByTag("event"); + while (spec_interface_events.next()) |spec_interface_event| { + const event = try arena.create(EntryNode); + try fetch_entry_data(arena, event, spec_interface_event, .event); + defer sll_push_end( + event, + &interface.event_first, + &interface.event_last, + &interface.event_count, + ); + + var event_entries = spec_interface_event.findChildrenByTag("arg"); + while (event_entries.next()) |event_entry| { + const entry = try arena.create(EntryNode); + try fetch_entry_data(arena, entry, event_entry, .@"arg"); + defer sll_push_end( + entry, + &event.arg_first, + &event.arg_last, + &event.arg_count, + ); + } + } + + var request_opcode: u16 = 0; + var spec_interface_requests = spec_interface.findChildrenByTag("request"); + while (spec_interface_requests.next()) |spec_interface_request| : (request_opcode += 1) { + const request = try arena.create(EntryNode); + try fetch_entry_data(arena, request, spec_interface_request, .request); + request.opcode = request_opcode; + request.arg_type = "void"; + defer sll_push_end( + request, + &interface.request_first, + &interface.request_last, + &interface.request_count, + ); + + var request_args = spec_interface_request.findChildrenByTag("arg"); + while (request_args.next()) |request_arg| { + const arg = try arena.create(EntryNode); + try fetch_entry_data(arena, arg, request_arg, .@"arg"); + if (arg.data_type == .new_id) { + if (arg.interface == null) { + request.arg_type = "InterfaceT"; + } else { + request.arg_type = arg.interface; + } + } + sll_push_end( + arg, + &request.arg_first, + &request.arg_last, + &request.arg_count, + ); + } + } + } +} + +fn fetch_entry_data( + arena: std.mem.Allocator, + entry: *EntryNode, + element: *const Xml.Element, + @"type": EntryType, +) !void { + const entry_name = try arena.dupe(u8, element.getAttribute("name").?); + const entry_version = if (element.getAttribute("version")) |ver| + try arena.dupe(u8, ver) + else null; + const entry_summary = if (element.getAttribute("summary")) |sum| + try arena.dupe(u8, sum) + else null; + const entry_description = if (element.getCharData("description")) |desc| + try arena.dupe(u8, desc) + else null; + const entry_value = if (element.getAttribute("value")) |val| + try arena.dupe(u8, val) + else null; + const entry_interface = if (element.getAttribute("interface")) |int| + try arena.dupe(u8, int) + else null; + const entry_nullable = (element.getAttribute("nullable") != null); + const entry_type = + if (@"type" == .@"enum" and + element.getAttribute("bitfield") != null) + .bitfield + else + @"type"; + + const entry_identifier = try toIdentifier( + arena, + entry_name, + entry_type, + ); + + const entry_arg_type, const entry_data_type = arg_type: { + if (element.getAttribute("type")) |typ| { + const enum_t = element.getAttribute("enum"); + const bitfield = if (enum_t) |_| + if (element.getAttribute("bitfield")) |_| + true + else + false + else false; + + _ = bitfield; + + const data_type = std.meta.stringToEnum(DataType, typ).?; + break :arg_type + .{ + if (data_type == .object and entry_interface != null) + entry_interface + else if (data_type == .uint and enum_t != null) + try toIdentifier(arena, enum_t.?, .@"enum") + else + data_type.zigTypeString(), + data_type, + }; + } + + break :arg_type .{ null, .invalid }; + }; + + entry.* = .{ + .description = entry_description, + .version = entry_version, + .summary = entry_summary, + .nullable = entry_nullable, + .value = entry_value, + .arg_type = entry_arg_type, + .interface = entry_interface, + .name = entry_name, + .identifier = entry_identifier, + .data_type = entry_data_type, + .type = entry_type, + }; +} + +const WaylandGeneralTypesCodePaste = +\\ +\\pub const MessageArg = union(enum) { +\\ string: [:0]const u8, +\\ array: []const u8, +\\ @"enum": Enum, +\\ new_id: u32, +\\ object: u32, +\\ fixed: f32, +\\ uint: u32, +\\ int: i32, +\\ fd: i32, +\\}; +\\ +\\pub const Proxy = struct { +\\ ctx: *anyopaque, +\\ vtable: VTable, +\\ +\\ pub fn message_decode( +\\ noalias proxy: *Proxy, +\\ noalias args_out: []MessageArg, +\\ noalias message: []const u8, +\\ ) void { +\\ @call( +\\ .auto, +\\ proxy.vtable.message_decode, +\\ .{ proxy.ctx, args_out, message }, +\\ ); +\\ } +\\ +\\ pub fn message_encode( +\\ noalias proxy: *Proxy, +\\ id: u32, +\\ opcode: u16, +\\ noalias args: []const ?MessageArg, +\\ ) void { +\\ @call( +\\ .auto, +\\ proxy.vtable.message_encode, +\\ .{ proxy.ctx, id, opcode, args }, +\\ ); +\\ } +\\ +\\ pub fn get_id( +\\ proxy: *Proxy, +\\ ) u32 { +\\ return @call( +\\ .auto, +\\ proxy.vtable.get_id, +\\ .{ proxy.ctx }, +\\ ); +\\ } +\\ +\\ pub fn put_object( +\\ proxy: *Proxy, +\\ object: Object, +\\ ) void { +\\ @call( +\\ .auto, +\\ proxy.vtable.put_object, +\\ .{ proxy.ctx, object }, +\\ ); +\\ } +\\ +\\ pub fn destroy_object( +\\ proxy: *Proxy, +\\ object_id: u32, +\\ ) void { +\\ @call( +\\ .auto, +\\ proxy.vtable.destroy_object, +\\ .{ proxy.ctx, object_id }, +\\ ); +\\ } +\\ +\\ const VTable = struct { +\\ message_decode: MessageDecodeFn, +\\ message_encode: MessageEncodeFn, +\\ get_id: GetIdFn, +\\ put_object: PutObjectFn, +\\ destroy_object: DestroyObjectFn, +\\ }; +\\}; +\\ +\\pub const MessageDecodeFn = *const fn ( +\\ noalias ctx: *anyopaque, +\\ args_out: []MessageArg, +\\ noalias message: []const u8 +\\) void; +\\ +\\pub const MessageEncodeFn = *const fn ( +\\ noalias ctx: *anyopaque, +\\ id: u32, +\\ opcode: u16, +\\ noalias args: []const ?MessageArg, +\\) void; +\\ +\\pub const GetIdFn = *const fn ( +\\ noalias ctx: *anyopaque, +\\) u32; +\\ +\\pub const PutObjectFn = *const fn ( +\\ noalias ctx: *anyopaque, +\\ object: Object, +\\) void; +\\ +\\pub const DestroyObjectFn = *const fn ( +\\ noalias ctx: *anyopaque, +\\ object_id: u32, +\\) void; +\\ +\\pub fn BitfieldMixin(comptime T: type) type { +\\ const int_type = @typeInfo(T).@"struct".backing_integer.?; +\\ +\\ return struct { +\\ pub fn toInt(self: T) Int { +\\ return @bitCast(self); +\\ } +\\ pub fn fromInt(int: Int) T { +\\ return @bitCast(int); +\\ } +\\ pub fn not(self: T) T { +\\ return fromInt(~toInt(self)); +\\ } +\\ +\\ pub fn either(a: T, b: T) T { +\\ return fromInt(toInt(a) | toInt(b)); +\\ } +\\ pub fn both(a: T, b: T) T { +\\ return fromInt(toInt(a) & toInt(b)); +\\ } +\\ pub fn eql(a: T, b: T) bool { +\\ return fromInt(a) == fromInt(b); +\\ } +\\ pub fn contains(a: T, b: T) bool { +\\ return toInt(both(a, b)) == toInt(b); +\\ } +\\ pub const Int = int_type; +\\ }; +\\} +\\ +\\ +; + +const UsageMsgFmt = +\\Generate code for interacting with a set of specified Wayland protocols in +\\a less callback-heavy manner. +\\ +\\The core Wayland specification document can be obtained from +\\https://gitlab.freedesktop.org/wayland/wayland +\\and other protocols can be found at +\\https://gitlab.freedesktop.org/wayland/wayland-protocols +\\ +\\Usage: {s} [options] -o +\\Options: +\\ -h --help Show this message and exit. +\\ --debug Write unformatted source to STDOUT in error cases. +\\ -o --out Output file to write to +\\ +; + +const OutputBeginMsg = +\\// This file is generated from provided Wayland XML specifications by +\\// wayland-protocol-codegen and should NOT be edited manually. +\\ +; + +const ProtocolBeginFmt = +\\//----------------------------------------------------------------------------- +\\// BEGIN Protocol {s} +\\//----------------------------------------------------------------------------- +\\ +; + +const ProtocolEndString = +\\ +\\//----------------------------------------------------------------------------- +\\ +\\ +; + +const InterfaceBeginFmt = +\\ +\\pub const {s} = enum (u32) {{ +\\ _, +\\ +\\ pub fn object(self: {s}) Object {{ +\\ return .{{ .{s} = self }}; +\\ }} +\\ +\\ pub fn toInt(self: {s}) u32 {{ +\\ return @intFromEnum(self); +\\ }} +\\ +\\ pub fn fromInt(int: u32) {s} {{ +\\ return @enumFromInt(int); +\\ }} +\\ +; +const InterfaceBeginSectionFmt = +\\ +\\ //--------------------------------------------------------------------------- +\\ // BEGIN {s} {s} +\\ //--------------------------------------------------------------------------- +\\ +; +const InterfaceEndSectionFmt = +\\ //--------------------------------------------------------------------------- +\\ +; +const InterfaceEndFmt = +\\ +\\ pub const Name = "{s}"; +\\ pub const Version = {s}; +\\}}; +\\ +; + +const ClientRequestBeginFmt = +\\ +\\ pub fn {s}( +\\ noalias self: *const {s}, +\\ noalias proxy: *Proxy, +\\ +; +const ClientRequestArgEntryFmt = +\\ {s}: {s}, +\\ +; +const ClientRequestArgsEndFmt = +\\ ) {s} {{ +\\ const Opcode = {}; +\\ +; + +const ClientRequestObjectCreateFmt = +\\ const result: {s} = .fromInt(proxy.get_id()); +\\ +; +const ClientRequestInterfaceBindVersionWarn = +\\ +\\ if (InterfaceT.Version != version) { +\\ log.warn( +\\ "Interface {s} version mismatch :: Client expects v{} — Compositor has v{}", +\\ .{ +\\ InterfaceT.Name, +\\ InterfaceT.Version, +\\ version, +\\ }, +\\ ); +\\ } +\\ const selected_version = @min(InterfaceT.Version, version); +\\ +; +const ClientRequestObjectStoreFmt = +\\ +\\ proxy.put_object(result.object()); +\\ return result; +\\ +; +const ClientRequestEndFmt = +\\ }} +\\ +; + +const ClientEventBeginFmt = +\\ +\\ pub const {s} = struct {{ +\\ +; +const ClientEventBeginEmptyFmt = +\\ +\\ pub const {s} = void; +\\ +; +const ClientEventEntryFmt = +\\ {s}: {s}, +\\ +; +const ClientEventEndFmt = +\\ }}; +\\ +; + +const BitfieldBeginFmt = +\\ +\\ pub const {s} = packed struct (u32) {{ +\\ +; + +const BitfieldEntryFmt = +\\ {s}: bool = false, +\\ +; +const BitfieldEndFmt = +\\ +\\ __reserved_bits: u{} = 0, +\\ +\\ pub const toInt = Mixin.toInt; +\\ pub const fromInt = Mixin.fromInt; +\\ pub const not = Mixin.not; +\\ pub const either = Mixin.either; +\\ pub const both = Mixin.both; +\\ pub const eql = Mixin.eql; +\\ pub const contains = Mixin.contains; +\\ +\\ const Mixin = BitfieldMixin(@This()); +\\ }}; +\\ +; + +const EnumBeginFmt = +\\ +\\ pub const {s} = enum (u32) {{ +\\ +; +const EnumEntryValueFmt = +\\ {s} = {s}, +\\ +; +const EnumEntryNoValueFmt = +\\ {s}, +\\ +; +const EnumEndFmt = +\\ +\\ pub fn toInt(self: {s}) u32 {{ +\\ return @intFromEnum(self); +\\ }} +\\ pub fn fromInt(int: u32) {s} {{ +\\ return @enumFromInt(int); +\\ }} +\\ }}; +\\ +; + +const MessageEncodeBeginMsg = +\\ proxy.message_encode( +\\ self.toInt(), +\\ Opcode, +\\ &.{ +\\ +; +const MessageEncodeEndMsg = +\\ }, +\\ ); +\\ +; + +const CombinedInterfaceBeginMsg = +\\pub const Object = union (enum) { +\\ +; +const CombinedInterfaceEntryFmt = +\\ {s}: {s}, +\\ +; +const CombinedInterfaceEndMsg = +\\}; +\\ +; + +const CombinedEventBeginMsg = +\\pub const Event = union (enum) { +\\ invalid: void, +\\ +; +const CombinedEventEntryFmt = +\\ {s}_{s}: {s}.{s}, +\\ +; +const CombinedEventEndMsg = +\\}; +\\ +; + +const CombinedEnumBeginMsg = +\\pub const Enum = union (enum) { +\\ invalid: void, +\\ +; +const CombinedEnumEntryFmt = +\\ {s}_{s}: {s}.{s}, +\\ +; +const CombinedEnumEndMsg = +\\}; +\\ +; + +const LogPaste = +\\const log = @import("std").log.scoped(.WaylandProtocols); +\\ +; + +const DataType = enum (u32) { + invalid, + // u32 + uint, // uint may also be an enum/bitfield type + object, + new_id, + // i32 + int, + // c_int -- but sent/received in ancillary data + fd, + // fixed point float -> f32 + fixed, + // []u8 + array, + // [:0]u8 + string, + + // basically exclusive to .destroy() requests + destructor, + + pub fn zigType(comptime data_type: DataType) type { + return switch (data_type) { + .destructor, .invalid => void, + .uint, .object, .new_id => u32, + .int => i32, + .fd => c_int, + .fixed => f32, + .array => []const u8, + .string => [:0]const u8, + }; + } + pub fn zigTypeString(data_type: DataType) []const u8 { + return switch (data_type) { + .destructor, .invalid => @typeName(void), + .uint, .object, .new_id => @typeName(u32), + .int => @typeName(i32), + .fd => @typeName(c_int), + .fixed => @typeName(f32), + .array => @typeName([]const u8), + .string => @typeName([:0]const u8), + }; + } +}; + +const EntryType = enum (u32) { + invalid, + file_name, + protocol, + interface, + request, + event, + @"enum", + bitfield, + arg, +}; + +const EntryNode = struct { + next: ?*EntryNode = null, + interface_first: ?*EntryNode = null, + interface_last: ?*EntryNode = null, + interface_count: u32 = 0, + enum_first: ?*EntryNode = null, + enum_last: ?*EntryNode = null, + enum_count: u32 = 0, + event_first: ?*EntryNode = null, + event_last: ?*EntryNode = null, + event_count: u32 = 0, + request_first: ?*EntryNode = null, + request_last: ?*EntryNode = null, + request_count: u32 = 0, + arg_first: ?*EntryNode = null, + arg_last: ?*EntryNode = null, + arg_count: u32 = 0, + description: ?[]const u8 = null, + summary: ?[]const u8 = null, + interface: ?[]const u8 = null, + version: ?[]const u8 = null, + value: ?[]const u8 = null, + arg_type: ?[]const u8 = null, + name: []const u8, + opcode: u16 = 0, + identifier: []const u8 = undefined, + type: EntryType = .invalid, + data_type: DataType = .invalid, + nullable: bool = false, +}; + +const Output = struct { + protocol_first: ?*EntryNode = null, + protocol_last: ?*EntryNode = null, + protocol_count: u32 = 0, +}; + +fn sll_push_front(node: *EntryNode, first: *?*EntryNode, last: *?*EntryNode, count: *u32) void { + if (first.*) |first_old| { + first.* = node; + node.next = first_old; + } else { + first.* = node; + last.* = node; + } + count.* += 1; +} + +fn sll_push_end(node: *EntryNode, first: *?*EntryNode, last: *?*EntryNode, count: *u32) void { + if (last.*) |last_old| { + last_old.next = node; + last.* = node; + } else { + first.* = node; + last.* = node; + } + count.* += 1; + } + +fn sll_remove(node: *EntryNode, first: *?*EntryNode, last: *?*EntryNode, count: *u32) void { + if (first.* == null and last.* == null) return; + + if (first.* != null and first.*.? == node) { + first.* = node.next; + node.next = null; + count.* -= 1; + if (count.* == 0) last.* = null; + } else { + var iter: ?*EntryNode = first.*; + while (iter) |cur| : (iter = cur.next) { + if (cur.next) |next| if (next == node) { + cur.next = node.next; + if (last.*.? == node) last.* = cur; + node.next = null; + count.* -= 1; + }; + } + } +} + +fn is_digit(char: u8) bool { + return (char >= '0' and char <= '9'); +} + +pub fn toIdentifier( + allocator: std.mem.Allocator, + string: []const u8, + entry_type: EntryType, +) ![]const u8 { + std.debug.assert(string.len > 0); + + if (entry_type == .@"enum" or entry_type == .bitfield) { + const name = try snake_to_pascal(allocator, string); + return name; + } + + if (is_digit(string[0]) or Keywords.has(string)) { + return try std.fmt.allocPrint(allocator, "@\"{s}\"", .{ string }); + } + + return string; +} + +fn snake_to_pascal(allocator: std.mem.Allocator, str: []const u8) ![]const u8 { + + const underscore_count, const namespace_end = uc_ne: { + var n_end: usize = 0; + var count: u32 = 0; + for (str, 0..) |char, idx| { + if (char == '_') count += 1; + if (n_end == 0 and char == '.') { n_end = idx; count = 0; } + } + break :uc_ne .{ count, if (n_end > 0) n_end else null }; + }; + + var out_str = try allocator.alloc(u8, str.len - underscore_count); + var next_is_upper = true; + var out_idx: u32 = 0; + + for (str) |char| { + if (namespace_end) |ne| { + if (out_idx > ne) { + if (char == '_') { + next_is_upper = true; + } else { + defer out_idx += 1; + if (next_is_upper) { + next_is_upper = false; + out_str[out_idx] = std.ascii.toUpper(char); + } else out_str[out_idx] = char; + } + } else { + defer out_idx += 1; + out_str[out_idx] = char; + } + } else { + if (char == '_') { + next_is_upper = true; + } else { + defer out_idx += 1; + if (next_is_upper) { + next_is_upper = false; + out_str[out_idx] = std.ascii.toUpper(char); + } else out_str[out_idx] = char; + } + } + } + + return out_str; +} + +const Io = std.Io; +const Keywords = std.zig.Token.keywords; + +const Xml = @import("xml.zig"); +const std = @import("std"); diff --git a/src/xml.zig b/src/xml.zig new file mode 100644 index 0000000..e4b3584 --- /dev/null +++ b/src/xml.zig @@ -0,0 +1,640 @@ +// Snektron's xml.zig from his vulkan-zig project + +const std = @import("std"); +const mem = std.mem; +const testing = std.testing; +const Allocator = mem.Allocator; +const ArenaAllocator = std.heap.ArenaAllocator; + +pub const Attribute = struct { + name: []const u8, + value: []const u8, +}; + +pub const Content = union(enum) { + char_data: []const u8, + comment: []const u8, + element: *Element, +}; + +pub const Element = struct { + tag: []const u8, + attributes: []Attribute = &.{}, + children: []Content = &.{}, + + pub fn getAttribute(self: Element, attrib_name: []const u8) ?[]const u8 { + for (self.attributes) |child| { + if (mem.eql(u8, child.name, attrib_name)) { + return child.value; + } + } + + return null; + } + + pub fn getCharData(self: Element, child_tag: []const u8) ?[]const u8 { + const child = self.findChildByTag(child_tag) orelse return null; + if (child.children.len != 1) { + return null; + } + + return switch (child.children[0]) { + .char_data => |char_data| char_data, + else => null, + }; + } + + pub fn iterator(self: Element) ChildIterator { + return .{ + .items = self.children, + .i = 0, + }; + } + + pub fn elements(self: Element) ChildElementIterator { + return .{ + .inner = self.iterator(), + }; + } + + pub fn findChildByTag(self: Element, tag: []const u8) ?*Element { + var it = self.findChildrenByTag(tag); + return it.next(); + } + + pub fn findChildrenByTag(self: Element, tag: []const u8) FindChildrenByTagIterator { + return .{ + .inner = self.elements(), + .tag = tag, + }; + } + + pub const ChildIterator = struct { + items: []Content, + i: usize, + + pub fn next(self: *ChildIterator) ?*Content { + if (self.i < self.items.len) { + self.i += 1; + return &self.items[self.i - 1]; + } + + return null; + } + }; + + pub const ChildElementIterator = struct { + inner: ChildIterator, + + pub fn next(self: *ChildElementIterator) ?*Element { + while (self.inner.next()) |child| { + if (child.* != .element) { + continue; + } + + return child.*.element; + } + + return null; + } + }; + + pub const FindChildrenByTagIterator = struct { + inner: ChildElementIterator, + tag: []const u8, + + pub fn next(self: *FindChildrenByTagIterator) ?*Element { + while (self.inner.next()) |child| { + if (!mem.eql(u8, child.tag, self.tag)) { + continue; + } + + return child; + } + + return null; + } + }; +}; + +pub const Document = struct { + arena: ArenaAllocator, + xml_decl: ?*Element, + root: *Element, + + pub fn deinit(self: Document) void { + var arena = self.arena; // Copy to stack so self can be taken by value. + arena.deinit(); + } +}; + +const Parser = struct { + source: []const u8, + offset: usize, + line: usize, + column: usize, + + fn init(source: []const u8) Parser { + return .{ + .source = source, + .offset = 0, + .line = 0, + .column = 0, + }; + } + + fn peek(self: *Parser) ?u8 { + return if (self.offset < self.source.len) self.source[self.offset] else null; + } + + fn consume(self: *Parser) !u8 { + if (self.offset < self.source.len) { + return self.consumeNoEof(); + } + + return error.UnexpectedEof; + } + + fn consumeNoEof(self: *Parser) u8 { + std.debug.assert(self.offset < self.source.len); + const c = self.source[self.offset]; + self.offset += 1; + + if (c == '\n') { + self.line += 1; + self.column = 0; + } else { + self.column += 1; + } + + return c; + } + + fn eat(self: *Parser, char: u8) bool { + self.expect(char) catch return false; + return true; + } + + fn expect(self: *Parser, expected: u8) !void { + if (self.peek()) |actual| { + if (expected != actual) { + return error.UnexpectedCharacter; + } + + _ = self.consumeNoEof(); + return; + } + + return error.UnexpectedEof; + } + + fn eatStr(self: *Parser, text: []const u8) bool { + self.expectStr(text) catch return false; + return true; + } + + fn expectStr(self: *Parser, text: []const u8) !void { + if (self.source.len < self.offset + text.len) { + return error.UnexpectedEof; + } else if (mem.startsWith(u8, self.source[self.offset..], text)) { + var i: usize = 0; + while (i < text.len) : (i += 1) { + _ = self.consumeNoEof(); + } + + return; + } + + return error.UnexpectedCharacter; + } + + fn eatWs(self: *Parser) bool { + var ws = false; + + while (self.peek()) |ch| { + switch (ch) { + ' ', '\t', '\n', '\r' => { + ws = true; + _ = self.consumeNoEof(); + }, + else => break, + } + } + + return ws; + } + + fn expectWs(self: *Parser) !void { + if (!self.eatWs()) return error.UnexpectedCharacter; + } + + fn currentLine(self: Parser) []const u8 { + var begin: usize = 0; + if (mem.lastIndexOfScalar(u8, self.source[0..self.offset], '\n')) |prev_nl| { + begin = prev_nl + 1; + } + + const end = mem.indexOfScalarPos(u8, self.source, self.offset, '\n') orelse self.source.len; + return self.source[begin..end]; + } +}; + +test "xml: Parser" { + { + var parser = Parser.init("I like pythons"); + try testing.expectEqual(@as(?u8, 'I'), parser.peek()); + try testing.expectEqual(@as(u8, 'I'), parser.consumeNoEof()); + try testing.expectEqual(@as(?u8, ' '), parser.peek()); + try testing.expectEqual(@as(u8, ' '), try parser.consume()); + + try testing.expect(parser.eat('l')); + try testing.expectEqual(@as(?u8, 'i'), parser.peek()); + try testing.expectEqual(false, parser.eat('a')); + try testing.expectEqual(@as(?u8, 'i'), parser.peek()); + + try parser.expect('i'); + try testing.expectEqual(@as(?u8, 'k'), parser.peek()); + try testing.expectError(error.UnexpectedCharacter, parser.expect('a')); + try testing.expectEqual(@as(?u8, 'k'), parser.peek()); + + try testing.expect(parser.eatStr("ke")); + try testing.expectEqual(@as(?u8, ' '), parser.peek()); + + try testing.expect(parser.eatWs()); + try testing.expectEqual(@as(?u8, 'p'), parser.peek()); + try testing.expectEqual(false, parser.eatWs()); + try testing.expectEqual(@as(?u8, 'p'), parser.peek()); + + try testing.expectEqual(false, parser.eatStr("aaaaaaaaa")); + try testing.expectEqual(@as(?u8, 'p'), parser.peek()); + + try testing.expectError(error.UnexpectedEof, parser.expectStr("aaaaaaaaa")); + try testing.expectEqual(@as(?u8, 'p'), parser.peek()); + try testing.expectError(error.UnexpectedCharacter, parser.expectStr("pytn")); + try testing.expectEqual(@as(?u8, 'p'), parser.peek()); + try parser.expectStr("python"); + try testing.expectEqual(@as(?u8, 's'), parser.peek()); + } + + { + var parser = Parser.init(""); + try testing.expectEqual(parser.peek(), null); + try testing.expectError(error.UnexpectedEof, parser.consume()); + try testing.expectEqual(parser.eat('p'), false); + try testing.expectError(error.UnexpectedEof, parser.expect('p')); + } +} + +pub const ParseError = error{ + IllegalCharacter, + UnexpectedEof, + UnexpectedCharacter, + UnclosedValue, + UnclosedComment, + InvalidName, + InvalidEntity, + InvalidStandaloneValue, + NonMatchingClosingTag, + InvalidDocument, + OutOfMemory, +}; + +pub fn parse(backing_allocator: Allocator, source: []const u8) !Document { + var parser = Parser.init(source); + return try parseDocument(&parser, backing_allocator); +} + +fn parseDocument(parser: *Parser, backing_allocator: Allocator) !Document { + var doc = Document{ + .arena = ArenaAllocator.init(backing_allocator), + .xml_decl = null, + .root = undefined, + }; + + errdefer doc.deinit(); + + const allocator = doc.arena.allocator(); + + try skipComments(parser, allocator); + + doc.xml_decl = try parseElement(parser, allocator, .xml_decl); + _ = parser.eatWs(); + try skipComments(parser, allocator); + + doc.root = (try parseElement(parser, allocator, .element)) orelse return error.InvalidDocument; + _ = parser.eatWs(); + try skipComments(parser, allocator); + + if (parser.peek() != null) return error.InvalidDocument; + + return doc; +} + +fn parseAttrValue(parser: *Parser, alloc: Allocator) ![]const u8 { + const quote = try parser.consume(); + if (quote != '"' and quote != '\'') return error.UnexpectedCharacter; + + const begin = parser.offset; + + while (true) { + const c = parser.consume() catch return error.UnclosedValue; + if (c == quote) break; + } + + const end = parser.offset - 1; + + return try unescape(alloc, parser.source[begin..end]); +} + +fn parseEqAttrValue(parser: *Parser, alloc: Allocator) ![]const u8 { + _ = parser.eatWs(); + try parser.expect('='); + _ = parser.eatWs(); + + return try parseAttrValue(parser, alloc); +} + +fn parseNameNoDupe(parser: *Parser) ![]const u8 { + // XML's spec on names is very long, so to make this easier + // we just take any character that is not special and not whitespace + const begin = parser.offset; + + while (parser.peek()) |ch| { + switch (ch) { + ' ', '\t', '\n', '\r' => break, + '&', '"', '\'', '<', '>', '?', '=', '/' => break, + else => _ = parser.consumeNoEof(), + } + } + + const end = parser.offset; + if (begin == end) return error.InvalidName; + + return parser.source[begin..end]; +} + +fn parseCharData(parser: *Parser, alloc: Allocator) !?[]const u8 { + const begin = parser.offset; + + while (parser.peek()) |ch| { + switch (ch) { + '<' => break, + else => _ = parser.consumeNoEof(), + } + } + + const end = parser.offset; + if (begin == end) return null; + + return try unescape(alloc, parser.source[begin..end]); +} + +fn parseContent(parser: *Parser, alloc: Allocator) ParseError!Content { + if (try parseCharData(parser, alloc)) |cd| { + return Content{ .char_data = cd }; + } else if (try parseComment(parser, alloc)) |comment| { + return Content{ .comment = comment }; + } else if (try parseElement(parser, alloc, .element)) |elem| { + return Content{ .element = elem }; + } else { + return error.UnexpectedCharacter; + } +} + +fn parseAttr(parser: *Parser, alloc: Allocator) !?Attribute { + const name = parseNameNoDupe(parser) catch return null; + _ = parser.eatWs(); + try parser.expect('='); + _ = parser.eatWs(); + const value = try parseAttrValue(parser, alloc); + + const attr = Attribute{ + .name = try alloc.dupe(u8, name), + .value = value, + }; + return attr; +} + +const ElementKind = enum { + xml_decl, + element, +}; + +fn parseElement(parser: *Parser, alloc: Allocator, comptime kind: ElementKind) !?*Element { + const start = parser.offset; + + const tag = switch (kind) { + .xml_decl => blk: { + if (!parser.eatStr(" blk: { + if (!parser.eat('<')) return null; + const tag = parseNameNoDupe(parser) catch { + parser.offset = start; + return null; + }; + break :blk tag; + }, + }; + + var attributes = try std.ArrayList(Attribute).initCapacity(alloc, 256); + defer attributes.deinit(alloc); + + var children = try std.ArrayList(Content).initCapacity(alloc, 256); + defer children.deinit(alloc); + + while (parser.eatWs()) { + const attr = (try parseAttr(parser, alloc)) orelse break; + try attributes.append(alloc, attr); + } + + switch (kind) { + .xml_decl => try parser.expectStr("?>"), + .element => { + if (!parser.eatStr("/>")) { + try parser.expect('>'); + + while (true) { + if (parser.peek() == null) { + return error.UnexpectedEof; + } else if (parser.eatStr("'); + } + }, + } + + const element = try alloc.create(Element); + element.* = .{ + .tag = try alloc.dupe(u8, tag), + .attributes = try attributes.toOwnedSlice(alloc), + .children = try children.toOwnedSlice(alloc), + }; + return element; +} + +test "xml: parseElement" { + var arena = ArenaAllocator.init(testing.allocator); + defer arena.deinit(); + const alloc = arena.allocator(); + + { + var parser = Parser.init("<= a='b'/>"); + try testing.expectEqual(@as(?*Element, null), try parseElement(&parser, alloc, .element)); + try testing.expectEqual(@as(?u8, '<'), parser.peek()); + } + + { + var parser = Parser.init(""); + const elem = try parseElement(&parser, alloc, .element); + try testing.expectEqualSlices(u8, elem.?.tag, "python"); + + const size_attr = elem.?.attributes[0]; + try testing.expectEqualSlices(u8, size_attr.name, "size"); + try testing.expectEqualSlices(u8, size_attr.value, "15"); + + const color_attr = elem.?.attributes[1]; + try testing.expectEqualSlices(u8, color_attr.name, "color"); + try testing.expectEqualSlices(u8, color_attr.value, "green"); + } + + { + var parser = Parser.init("test"); + const elem = try parseElement(&parser, alloc, .element); + try testing.expectEqualSlices(u8, elem.?.tag, "python"); + try testing.expectEqualSlices(u8, elem.?.children[0].char_data, "test"); + } + + { + var parser = Parser.init("bdf"); + const elem = try parseElement(&parser, alloc, .element); + try testing.expectEqualSlices(u8, elem.?.tag, "a"); + try testing.expectEqualSlices(u8, elem.?.children[0].char_data, "b"); + try testing.expectEqualSlices(u8, elem.?.children[1].element.tag, "c"); + try testing.expectEqualSlices(u8, elem.?.children[2].char_data, "d"); + try testing.expectEqualSlices(u8, elem.?.children[3].element.tag, "e"); + try testing.expectEqualSlices(u8, elem.?.children[4].char_data, "f"); + try testing.expectEqualSlices(u8, elem.?.children[5].comment, "g"); + } +} + +test "xml: parse prolog" { + var arena = ArenaAllocator.init(testing.allocator); + defer arena.deinit(); + const a = arena.allocator(); + + { + var parser = Parser.init(""); + try testing.expectEqual(@as(?*Element, null), try parseElement(&parser, a, .xml_decl)); + try testing.expectEqual(@as(?u8, '<'), parser.peek()); + } + + { + var parser = Parser.init(""); + const decl = try parseElement(&parser, a, .xml_decl); + try testing.expectEqualSlices(u8, "aa", decl.?.getAttribute("version").?); + try testing.expectEqual(@as(?[]const u8, null), decl.?.getAttribute("encoding")); + try testing.expectEqual(@as(?[]const u8, null), decl.?.getAttribute("standalone")); + } + + { + var parser = Parser.init(""); + const decl = try parseElement(&parser, a, .xml_decl); + try testing.expectEqualSlices(u8, "ccc", decl.?.getAttribute("version").?); + try testing.expectEqualSlices(u8, "bbb", decl.?.getAttribute("encoding").?); + try testing.expectEqualSlices(u8, "yes", decl.?.getAttribute("standalone").?); + } +} + +fn skipComments(parser: *Parser, alloc: Allocator) !void { + while ((try parseComment(parser, alloc)) != null) { + _ = parser.eatWs(); + } +} + +fn parseComment(parser: *Parser, alloc: Allocator) !?[]const u8 { + if (!parser.eatStr("")) { + _ = parser.consume() catch return error.UnclosedComment; + } + + const end = parser.offset - "-->".len; + return try alloc.dupe(u8, parser.source[begin..end]); +} + +fn unescapeEntity(text: []const u8) !u8 { + const EntitySubstition = struct { text: []const u8, replacement: u8 }; + + const entities = [_]EntitySubstition{ + .{ .text = "<", .replacement = '<' }, + .{ .text = ">", .replacement = '>' }, + .{ .text = "&", .replacement = '&' }, + .{ .text = "'", .replacement = '\'' }, + .{ .text = """, .replacement = '"' }, + }; + + for (entities) |entity| { + if (mem.eql(u8, text, entity.text)) return entity.replacement; + } + + return error.InvalidEntity; +} + +fn unescape(arena: Allocator, text: []const u8) ![]const u8 { + const unescaped = try arena.alloc(u8, text.len); + + var j: usize = 0; + var i: usize = 0; + while (i < text.len) : (j += 1) { + if (text[i] == '&') { + const entity_end = 1 + (mem.indexOfScalarPos(u8, text, i, ';') orelse return error.InvalidEntity); + unescaped[j] = try unescapeEntity(text[i..entity_end]); + i = entity_end; + } else { + unescaped[j] = text[i]; + i += 1; + } + } + + return unescaped[0..j]; +} + +test "xml: unescape" { + var arena = ArenaAllocator.init(testing.allocator); + defer arena.deinit(); + const a = arena.allocator(); + + try testing.expectEqualSlices(u8, "test", try unescape(a, "test")); + try testing.expectEqualSlices(u8, "ad\"e'f<", try unescape(a, "a<b&c>d"e'f<")); + try testing.expectError(error.InvalidEntity, unescape(a, "python&")); + try testing.expectError(error.InvalidEntity, unescape(a, "python&&")); + try testing.expectError(error.InvalidEntity, unescape(a, "python&test;")); + try testing.expectError(error.InvalidEntity, unescape(a, "python&boa")); +} + +test "xml: top level comments" { + var arena = ArenaAllocator.init(testing.allocator); + defer arena.deinit(); + const a = arena.allocator(); + + const doc = try parse(a, ""); + try testing.expectEqualSlices(u8, "python", doc.root.tag); +}