import argparse
import os


class TestResult(object):
    """Line-oriented parser for a captured test-run log.

    The recognised format is: a ``running ...`` line that opens the result
    list, one ``test NAME ... OUTCOME`` line per test, a ``failures:`` line
    that opens the captured-output section (``---- NAME stdout ----`` headers
    followed by message lines), and a second ``failures:`` line that closes
    it.  NOTE(review): this matches the output of Rust's libtest harness
    (``cargo test``) -- confirm against the producer of the logs.

    Only the first captured-output section is parsed; once the second
    ``failures:`` line is seen, every later line is ignored.

    Attributes:
        tests_results: maps test name -> ``"ok"``, ``"FAILED"`` or
            ``"ignored"``.
        failed_tests: maps test name -> list of stripped, non-empty output
            lines captured for that test.
        tests: number of ``test NAME ...`` lines seen (the ``test result:``
            summary line is not counted).
        file: the path given to the constructor.
        section, test_name, messages: parser state used while reading.
    """

    _TEST_PREFIX = "test "
    _SUMMARY_PREFIX = "test result: "
    _STDOUT_PREFIX = "---- "
    _STDOUT_SUFFIX = " stdout ----"
    _PENDING_SUFFIX = " ..."
    _OUTCOMES = ("ok", "FAILED", "ignored")

    def __init__(self, file):
        """Parse *file* immediately and populate the result attributes.

        Raises:
            OSError: if *file* cannot be opened.
        """
        self.tests_results = {}
        self.failed_tests = {}
        self.section = None
        self.test_name = None
        self.messages = []
        self.tests = 0
        self.file = file
        self.parse_file(file)

    def parse_file(self, file):
        """Feed every line of *file* to :meth:`parse_line`."""
        # Test output may contain arbitrary bytes; decode as UTF-8 and replace
        # anything undecodable instead of depending on the platform default
        # encoding or aborting the whole parse.
        with open(file, "r", encoding="utf-8", errors="replace") as f:
            for line in f:
                self.parse_line(line)

    def parse_line(self, line):
        """Advance the parser state machine by one raw line of the log."""
        if self.section is None:
            # Everything before the first "running ..." line is ignored.
            if line.startswith("running"):
                self.section = "test_results"
        elif self.section == "test_results":
            self._parse_result_line(line)
        elif self.section == "failed_tests":
            self._parse_failure_line(line)

    def _parse_result_line(self, line):
        """Handle one line of the per-test result list."""
        if line.startswith("failures:"):
            self.section = "failed_tests"
        elif line.startswith(self._SUMMARY_PREFIX):
            # "test result: ..." is the run summary, not a test; it must not
            # be counted in ``self.tests``.
            return
        elif line.startswith(self._TEST_PREFIX):
            self.tests += 1
            body = line[len(self._TEST_PREFIX):].strip()
            segments = body.split(' ... ')
            if len(segments) >= 2:
                self.test_name = segments[0]
                for keyword in self._OUTCOMES:
                    if segments[1].startswith(keyword):
                        self.tests_results[self.test_name] = keyword
                        self.test_name = None
                        break
                # No keyword matched: the test printed on the same line, so
                # the name stays pending until a bare outcome line shows up.
            elif body.endswith(self._PENDING_SUFFIX):
                # "test NAME ... " with nothing after the separator: stripping
                # removed the trailing space, so the split above found no
                # separator.  Keep the name pending for a later outcome line.
                self.test_name = body[:-len(self._PENDING_SUFFIX)]
        elif self.test_name is not None and line.strip() in ("ok", "FAILED"):
            # Outcome printed on its own line after the test's output.  It is
            # only recorded when a test is actually pending, so a stray
            # "ok" line can never create a ``None`` key.
            self.tests_results[self.test_name] = line.strip()
            self.test_name = None

    def _parse_failure_line(self, line):
        """Handle one line of the captured-output section."""
        if line.startswith("failures:"):
            self.section = "failed_tests_summary"
            self.on_failures_done()
        elif line.startswith(self._STDOUT_PREFIX):
            # A new header ends the previous test's captured output.
            if self.test_name and len(self.messages) > 0:
                self.failed_tests[self.test_name] = self.messages
            # Strip first so the line terminator is not counted when the
            # " stdout ----" suffix is removed.
            name = line.strip()[len(self._STDOUT_PREFIX):]
            if name.endswith(self._STDOUT_SUFFIX):
                name = name[:-len(self._STDOUT_SUFFIX)]
            self.test_name = name
            self.messages = []
        elif len(line.strip()) > 0:
            self.messages.append(line.strip())

    def on_failures_done(self):
        """Store the last test's captured output and reset the parser state."""
        if self.test_name and len(self.messages) > 0:
            self.failed_tests[self.test_name] = self.messages
        self.test_name = None
        self.messages = []


def compare_test_results(test_result, expected_result):
    """Diff two parsed runs by test name and outcome.

    Both arguments only need a ``tests_results`` mapping of test name to
    outcome.

    Returns:
        A ``(missing_tests, extra_tests, different_tests)`` tuple of lists:
        names only in *expected_result*, names only in *test_result*, and
        names present in both whose outcomes differ.  Each list keeps the
        iteration order of the mapping it was drawn from.
    """
    actual = test_result.tests_results
    wanted = expected_result.tests_results

    extra_tests = [name for name in actual if name not in wanted]
    different_tests = [
        name
        for name in actual
        if name in wanted and actual[name] != wanted[name]
    ]
    missing_tests = [name for name in wanted if name not in actual]
    return missing_tests, extra_tests, different_tests

def count_results(test_result):
    """Tally the outcomes recorded in ``test_result.tests_results``.

    Returns:
        A ``(passed, failed, ignored)`` tuple counting the entries whose
        outcome is ``"ok"``, ``"FAILED"`` and ``"ignored"`` respectively.
        Any other outcome value is not counted.
    """
    outcomes = list(test_result.tests_results.values())
    return tuple(
        sum(1 for outcome in outcomes if outcome == label)
        for label in ("ok", "FAILED", "ignored")
    )

def process_test_result(file, expected, test_name):
    """Build a Markdown report comparing a test log with its expected log.

    Args:
        file: path of the log produced by the run under review.
        expected: path of the log holding the expected outcomes.
        test_name: label used in the report heading.

    Returns:
        ``(output, success)``: the report as a list of lines, and ``True``
        only when no test is missing, extra or different from the expected
        log.
    """
    actual = TestResult(file)
    baseline = TestResult(expected)
    missing, extra, different = compare_test_results(actual, baseline)
    passed, failed, ignored = count_results(actual)

    report = [
        f"## Test results for {test_name}:",
        "",
        "| Passed   | Failed   | Ignored  |",
        "| -------- | -------- | -------- |",
        f"| {str(passed).ljust(7)} | {str(failed).ljust(7)} | {str(ignored).ljust(7)} |",
    ]

    if not (missing or extra or different):
        report.append("### ✅ All tests matched expected results!")
        return report, True

    report.append("### ❌ Tests did not match expected results!")
    if missing:
        report.append(f"#### Missing tests: {len(missing)}")
        report.extend(f"  - {name}" for name in missing)
    if extra:
        report.append(f"#### Extra tests: {len(extra)}")
        report.extend(f"  - {name}" for name in extra)
    if different:
        report.append(f"#### Different tests: {len(different)}")
        report.extend(
            f"  - {name}: {actual.tests_results[name]}, expected: {baseline.tests_results[name]}"
            for name in different
        )
    return report, False


def run():
    """Command-line entry point: compare a test log with its expected log.

    Options:
        -f/--file       log of the run under review (required).
        -e/--expected   log holding the expected outcomes (required).
        -t/--test-name  label used in the report heading (required).
        -d/--dump       print the report to stdout instead of appending it
                        to the file named by ``GITHUB_STEP_SUMMARY``.

    Raises:
        SystemExit: always; status 0 when the results match the expected
            log, 1 otherwise (argparse exits with its own status on bad
            arguments).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--file", type=str, required=True)
    parser.add_argument("-d", "--dump", action="store_true", default=False)
    # The expected log is always opened, so a missing value must be a usage
    # error rather than a TypeError from open(None).
    parser.add_argument("-e", "--expected", type=str, required=True)
    parser.add_argument("-t", "--test-name", type=str, required=True)
    args = parser.parse_args()
    output, success = process_test_result(args.file, args.expected, args.test_name)
    if args.dump:
        for line in output:
            print(line)
    else:
        step_summary = os.environ.get("GITHUB_STEP_SUMMARY")
        if step_summary:
            # The report contains non-ASCII characters; write UTF-8 explicitly
            # so the platform default encoding cannot make this fail.
            with open(step_summary, "a", encoding="utf-8") as f:
                f.writelines(line + "\n" for line in output)
        else:
            print("No step summary found")
    # ``exit()`` is a helper injected by the ``site`` module and is not
    # guaranteed to exist; raising SystemExit is the equivalent that is.
    raise SystemExit(0 if success else 1)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    run()