"""Parse HMKNF knowledge-base text into a small AST.

The text is tokenized and parsed with the ANTLR-generated ``HMKNFLexer`` /
``HMKNFParser``. ``HMKNFVisitor`` then converts the parse tree into a ``KB``.
A ``KB`` holds an ontology formula, the rules, and the set of atom names seen.
"""

from dataclasses import dataclass
from enum import Enum
from functools import partial, reduce
from operator import itemgetter

from antlr4 import CommonTokenStream, InputStream, ParseTreeVisitor
from more_itertools import partition

from grammars.HMKNFLexer import HMKNFLexer
from grammars.HMKNFParser import HMKNFParser


@dataclass
class LRule:
    """A rule: its head atoms, positive body atoms and negative body atoms."""

    head: tuple[str, ...]
    pbody: tuple[str, ...]
    nbody: tuple[str, ...]


@dataclass
class OFormula:
    """Base class for ontology formula nodes."""


class OConst(Enum):
    """Propositional constants usable where an ontology formula is expected."""

    # Values must be distinct. Enum makes a member whose value equals an
    # earlier member's value an alias of that member. Field-less OFormula()
    # instances all compare equal, so they cannot serve as values here.
    TRUE = True
    FALSE = False


@dataclass
class OBinary(OFormula):
    """A binary connective applied to two subformulas.

    The visitor produces ``operator`` values ``"&"``, ``"|"`` and ``"->"``.
    """

    operator: str
    # May also be an OConst member: HMKNFVisitor.visitKb seeds its
    # conjunction chain with OConst.TRUE.
    left: OFormula
    right: OFormula


@dataclass
class OLiteral(OFormula):
    """An ontology atom; ``sign_positive`` is False for the negated form."""

    name: str
    sign_positive: bool


@dataclass
class KB:
    """A parsed knowledge base."""

    # Conjunction of all ontology rules, or OConst.TRUE when there are none.
    ont: OFormula
    # Rules in source order (the visitor stores a tuple).
    rules: tuple[LRule, ...]
    # Names of every atom encountered in rule heads, rule bodies and the
    # ontology.
    atoms: set[str]


class HMKNFVisitor(ParseTreeVisitor):
    """Convert an ``HMKNFParser`` parse tree into ``KB``/``LRule``/``OFormula``.

    Every atom name visited is recorded in ``self.atoms``. That same set object
    is handed to the returned ``KB``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.atoms: set[str] = set()

    def visitKb(self, ctx: HMKNFParser.KbContext) -> KB:
        # The generated ``orule()``/``lrule()`` accessors return a (possibly
        # empty) list, so no emptiness guard is needed. ``reduce`` consumes
        # the generator eagerly, so ontology atoms are recorded before the
        # rules are visited.
        orules = (self.visit(orule) for orule in ctx.orule())
        ont = reduce(partial(OBinary, "&"), orules, OConst.TRUE)
        lrules = tuple(self.visit(lrule) for lrule in ctx.lrule())
        return KB(ont, lrules, self.atoms)

    def visitLrule(self, ctx: HMKNFParser.LruleContext) -> LRule:
        # Either part may be absent; an absent part becomes empty tuples.
        head = self.visit(ctx.head()) if ctx.head() else ()
        pbody, nbody = self.visit(ctx.body()) if ctx.body() else ((), ())
        return LRule(head, pbody, nbody)

    def visitHead(self, ctx: HMKNFParser.HeadContext) -> tuple[str, ...]:
        """Return the head's atom names and record them in ``self.atoms``."""
        head = tuple(str(atom) for atom in ctx.IDENTIFIER())
        # Without this, an atom appearing only in rule heads (e.g. in a fact)
        # would be missing from KB.atoms, unlike body and ontology atoms.
        self.atoms.update(head)
        return head

    def visitBody(
        self, ctx: HMKNFParser.BodyContext
    ) -> tuple[tuple[str, ...], tuple[str, ...]]:
        """Return ``(positive_names, negative_names)`` for the body's katoms."""
        # partition(pred, it) yields (items where pred is false, items where
        # pred is true). The predicate is the sign flag, so negatives come first.
        nbody, pbody = partition(
            itemgetter(1), (self.visit(body_atom) for body_atom in ctx.katom())
        )
        return tuple(tuple(name for name, _ in part) for part in (pbody, nbody))

    def visitOrule(self, ctx: HMKNFParser.OruleContext) -> OFormula:
        return self.visit(ctx.oformula())

    def visitDisj(self, ctx: HMKNFParser.DisjContext) -> OBinary:
        return OBinary("|", self.visit(ctx.oformula(0)), self.visit(ctx.oformula(1)))

    def visitConj(self, ctx: HMKNFParser.ConjContext) -> OBinary:
        return OBinary("&", self.visit(ctx.oformula(0)), self.visit(ctx.oformula(1)))

    def visitParenth(self, ctx: HMKNFParser.ParenthContext) -> OFormula:
        return self.visit(ctx.oformula())

    def visitImp(self, ctx: HMKNFParser.ImpContext) -> OBinary:
        return OBinary("->", self.visit(ctx.oformula(0)), self.visit(ctx.oformula(1)))

    def visitLiteral(self, ctx: HMKNFParser.LiteralContext) -> OLiteral:
        return self.visit(ctx.oatom())

    def visitKatom(self, ctx: HMKNFParser.KatomContext) -> tuple[str, bool]:
        """Return ``(name, is_positive)`` and record the name in ``self.atoms``."""
        name, sign = self.visit(ctx.patom() if ctx.patom() else ctx.natom())
        self.atoms.add(name)
        return name, sign

    def visitOatom(self, ctx: HMKNFParser.OatomContext) -> OLiteral:
        """Build an ``OLiteral`` and record its name in ``self.atoms``."""
        name, sign_positive = self.visit(ctx.patom() if ctx.patom() else ctx.fatom())
        self.atoms.add(name)
        return OLiteral(name, sign_positive)

    def visitPatom(self, ctx: HMKNFParser.PatomContext) -> tuple[str, bool]:
        return str(ctx.IDENTIFIER()), True

    def visitNatom(self, ctx: HMKNFParser.NatomContext) -> tuple[str, bool]:
        return str(ctx.IDENTIFIER()), False

    def visitFatom(self, ctx: HMKNFParser.FatomContext) -> tuple[str, bool]:
        return str(ctx.IDENTIFIER()), False


def loadProgram(fileContents: str) -> KB:
    """Parse ``fileContents`` as an HMKNF program and return the resulting KB."""
    # NOTE(review): no custom error listener is installed, so syntax errors
    # go through ANTLR's default error reporting/recovery.
    input_stream = InputStream(fileContents)
    lexer = HMKNFLexer(input_stream)
    stream = CommonTokenStream(lexer)
    parser = HMKNFParser(stream)
    tree = parser.kb()
    visitor = HMKNFVisitor()
    return visitor.visit(tree)


def main():
    """Read a program from standard input and print the parsed KB."""
    from sys import stdin

    print(loadProgram(stdin.read()))


if __name__ == "__main__":
    main()