1const std = @import("std");
2const io = std.io;
3const builtin = @import("builtin");
4
/// I/O mode of this test binary, taken from `builtin.test_io_mode`.
/// In `main`, async tests are run only when this is `.evented` and are
/// reported as skipped when it is `.blocking`.
/// NOTE(review): presumably also read by std through the root module's
/// `io_mode` declaration — confirm against the targeted std version.
pub const io_mode: io.Mode = builtin.test_io_mode;

/// Number of messages logged through `log` at level `err` or at any level with
/// a smaller integer value. A non-zero count makes `main` exit with status 1.
var log_err_count: usize = 0;

// Fixed backing storage for the argument list allocated in `processArgs`:
// room for one maximum-length path plus one page. If `argsAlloc` cannot fit
// the arguments into this buffer, `processArgs` panics.
var args_buffer: [std.fs.MAX_PATH_BYTES + std.mem.page_size]u8 = undefined;
var args_allocator = std.heap.FixedBufferAllocator.init(&args_buffer);
11
/// Read the runner's command line and store the zig executable path it
/// carries in `std.testing.zig_exe_path`.
///
/// Exactly one argument is expected after the program name. Any other count
/// prints a usage line to stderr and panics. Argument storage comes from the
/// fixed-size `args_allocator`; running out of that space also panics.
fn processArgs() void {
    const argv = std.process.argsAlloc(args_allocator.allocator()) catch
        @panic("Too many bytes passed over the CLI to the test runner");

    // Happy path: `<runner> <path/to/zig>`.
    if (argv.len == 2) {
        std.testing.zig_exe_path = argv[1];
        return;
    }

    const on_windows = builtin.os.tag == .windows;
    // Fall back to a platform-appropriate name if argv[0] is missing.
    const program_name = if (argv.len >= 1) argv[0] else if (on_windows) "test.exe" else "test";
    std.debug.print("Usage: {s} path/to/zig{s}\n", .{
        program_name,
        if (on_windows) ".exe" else "",
    });
    @panic("Wrong number of command line arguments");
}
24
/// Entry point of the test binary.
///
/// Runs every function in `builtin.test_functions` in order and shows
/// progress via `std.Progress`. It prints a pass/skip/fail summary to stderr
/// and exits with status 1 if any of the following happened:
///   - a test failed;
///   - a test leaked memory through `std.testing.allocator_instance`;
///   - an error-level message was logged (counted by `log`).
///
/// When compiled by stage2 (`builtin.zig_is_stage2`), it delegates to the
/// simpler `main2` and panics with "test failure" if that returns an error.
pub fn main() void {
    if (builtin.zig_is_stage2) {
        return main2() catch @panic("test failure");
    }
    // Stores the zig executable path from argv in `std.testing.zig_exe_path`.
    processArgs();
    const test_fn_list = builtin.test_functions;
    var ok_count: usize = 0;
    var skip_count: usize = 0;
    var fail_count: usize = 0;
    var progress = std.Progress{
        .dont_print_on_dumb = true,
    };
    const root_node = progress.start("Test", test_fn_list.len) catch |err| switch (err) {
        // TODO still run tests in this case
        error.TimerUnsupported => @panic("timer unsupported"),
    };
    // When this is false, the runner prints its own plain "i/N name... RESULT"
    // lines (the `!have_tty` branches below). NOTE(review): presumably the
    // progress bar itself is suppressed in that case via `.dont_print_on_dumb`
    // — confirm against std.Progress.
    const have_tty = progress.terminal != null and progress.supports_ansi_escape_codes;

    // Frame memory reused by every async test. It starts empty and is replaced
    // with a larger page allocation whenever a test needs a bigger frame.
    // NOTE(review): the last allocation is never freed; it lives until exit.
    var async_frame_buffer: []align(std.Target.stack_align) u8 = undefined;
    // TODO this is on the next line (using `undefined` above) because otherwise zig incorrectly
    // ignores the alignment of the slice.
    async_frame_buffer = &[_]u8{};

    // Number of tests whose testing allocator reported a leak on deinit.
    var leaks: usize = 0;
    for (test_fn_list) |test_fn, i| {
        // Each test gets a fresh testing allocator. The deferred deinit() runs at
        // the end of every iteration, including the `continue` taken for skipped
        // async tests; a `true` result is counted as a leak.
        std.testing.allocator_instance = .{};
        defer {
            if (std.testing.allocator_instance.deinit()) {
                leaks += 1;
            }
        }
        // Reset before every test, presumably so a level changed by one test
        // does not carry over to the next.
        std.testing.log_level = .warn;

        var test_node = root_node.start(test_fn.name, 0);
        test_node.activate();
        progress.refresh();
        if (!have_tty) {
            std.debug.print("{d}/{d} {s}... ", .{ i + 1, test_fn_list.len, test_fn.name });
        }
        // A non-null `async_frame_size` marks an async test. Such tests are
        // launched with @asyncCall in evented mode and skipped in blocking mode.
        // All other tests are called directly.
        const result = if (test_fn.async_frame_size) |size| switch (io_mode) {
            .evented => blk: {
                // Grow the shared frame buffer if this test's frame does not fit.
                if (async_frame_buffer.len < size) {
                    std.heap.page_allocator.free(async_frame_buffer);
                    async_frame_buffer = std.heap.page_allocator.alignedAlloc(u8, std.Target.stack_align, size) catch @panic("out of memory");
                }
                const casted_fn = @ptrCast(fn () callconv(.Async) anyerror!void, test_fn.func);
                break :blk await @asyncCall(async_frame_buffer, {}, casted_fn, .{});
            },
            .blocking => {
                skip_count += 1;
                test_node.end();
                progress.log("{s}... SKIP (async test)\n", .{test_fn.name});
                if (!have_tty) std.debug.print("SKIP (async test)\n", .{});
                // Move to the next test; the deferred leak check still runs.
                continue;
            },
        } else test_fn.func();
        if (result) |_| {
            ok_count += 1;
            test_node.end();
            if (!have_tty) std.debug.print("OK\n", .{});
        } else |err| switch (err) {
            // A test that returns error.SkipZigTest counts as skipped, not failed.
            error.SkipZigTest => {
                skip_count += 1;
                test_node.end();
                progress.log("{s}... SKIP\n", .{test_fn.name});
                if (!have_tty) std.debug.print("SKIP\n", .{});
            },
            else => {
                fail_count += 1;
                test_node.end();
                progress.log("{s}... FAIL ({s})\n", .{ test_fn.name, @errorName(err) });
                if (!have_tty) std.debug.print("FAIL ({s})\n", .{@errorName(err)});
                // Show where the error came from, when a trace is available.
                if (@errorReturnTrace()) |trace| {
                    std.debug.dumpStackTrace(trace.*);
                }
            },
        }
    }
    root_node.end();
    if (ok_count == test_fn_list.len) {
        std.debug.print("All {d} tests passed.\n", .{ok_count});
    } else {
        std.debug.print("{d} passed; {d} skipped; {d} failed.\n", .{ ok_count, skip_count, fail_count });
    }
    if (log_err_count != 0) {
        std.debug.print("{d} errors were logged.\n", .{log_err_count});
    }
    if (leaks != 0) {
        std.debug.print("{d} tests leaked memory.\n", .{leaks});
    }
    // Any leak, logged error, or failed test makes the run fail.
    if (leaks != 0 or log_err_count != 0 or fail_count != 0) {
        std.process.exit(1);
    }
}
119
/// Logging hook for the test binary. NOTE(review): presumably picked up by
/// std.log as the root module's `log` override — confirm.
///
/// - Counting: messages at level `err`, or at any level with a smaller integer
///   value, increment `log_err_count`; `main` fails the run when it is non-zero.
/// - Printing: a message goes to stderr as "[scope] (level): ..." only when
///   its level's integer value does not exceed that of `std.testing.log_level`.
pub fn log(
    comptime message_level: std.log.Level,
    comptime scope: @Type(.EnumLiteral),
    comptime format: []const u8,
    args: anytype,
) void {
    const level_rank = @enumToInt(message_level);

    if (level_rank <= @enumToInt(std.log.Level.err)) log_err_count += 1;

    // Drop anything ranked past the currently configured test log level.
    if (level_rank > @enumToInt(std.testing.log_level)) return;

    const prefixed_format = "[{s}] ({s}): " ++ format ++ "\n";
    std.debug.print(prefixed_format, .{ @tagName(scope), @tagName(message_level) } ++ args);
}
133
/// Minimal test loop, called by `main` when the binary is compiled by stage2.
///
/// Calls every function in `builtin.test_functions` in order. A test returning
/// `error.SkipZigTest` is ignored; any other error marks the run as failed.
/// Returns `error.TestsFailed` if at least one test failed.
/// Unlike `main`, it prints no progress or summary. It also does not check
/// for allocator leaks or logged errors.
pub fn main2() anyerror!void {
    // Set once any test fails; the loop keeps running the remaining tests.
    var bad = false;
    // Simpler main(), exercising fewer language features, so that stage2 can handle it.
    for (builtin.test_functions) |test_fn| {
        test_fn.func() catch |err| {
            if (err != error.SkipZigTest) {
                bad = true;
            }
        };
    }
    if (bad) {
        return error.TestsFailed;
    }
}
148