""" Parse a knowledge base and generate an AST Can be run directly with program as standard input to view/debug AST usage: python AST.py knowledge_bases/simple.hmknf OR usage: python AST.py < knowledge_bases/simple.hmknf """ from itertools import count, product from string import ascii_lowercase from sys import argv, flags, stdin from antlr4 import ParseTreeVisitor, InputStream, CommonTokenStream from grammars.HMKNFLexer import HMKNFLexer from grammars.HMKNFParser import HMKNFParser from dataclasses import dataclass from more_itertools import partition from operator import itemgetter from functools import reduce, partial from enum import Enum atom = str Set = frozenset @dataclass class LRule: head: tuple[atom, ...] pbody: tuple[atom, ...] nbody: tuple[atom, ...] def text(self): head = ", ".join(self.head) nbody = tuple(f"not {atom}" for atom in self.nbody) body = ", ".join(self.pbody + nbody) return f"{head} :- {body}." @dataclass class OFormula: def text(self): if self is OConst.TRUE: return "TRUE" elif self is OConst.FALSE: return "FALSE" raise Exception("Unknown text repr") class OConst(Enum): TRUE = OFormula() FALSE = OFormula() def text(self): return OFormula.text(self) @dataclass class OBinary(OFormula): operator: str left: OFormula right: OFormula def text(self): return f"({self.left.text()} {self.operator} {self.right.text()})" @dataclass class OAtom(OFormula): name: str def text(self): return self.name @dataclass class ONegation(OFormula): of: OFormula def text(self): return f"-{self.of.text()}" @dataclass class KB: ont: OFormula rules: tuple[LRule] katoms: Set[atom] oatoms: Set[atom] def text(self): ont = f"{self.ont.text()}.\n" rules = "\n".join(map(LRule.text, self.rules)) return ont + rules class HMKNFVisitor(ParseTreeVisitor): def __init__(self) -> None: self.katoms = set() self.oatoms = set() def visitKb(self, ctx: HMKNFParser.KbContext): orules = (self.visit(orule) for orule in ctx.orule()) if ctx.orule() else (OConst.TRUE,) ont = reduce(partial(OBinary, "&"), orules) lrules = tuple(self.visit(lrule) for lrule in ctx.lrule()) return KB(ont, lrules, Set(self.katoms), Set(self.oatoms)) def visitLrule(self, ctx: HMKNFParser.LruleContext): head = self.visit(ctx.head()) if ctx.head() else () pbody, nbody = self.visit(ctx.body()) if ctx.body() else ((), ()) return LRule(head, pbody, nbody) def visitHead(self, ctx: HMKNFParser.HeadContext): heads = tuple(str(atom) for atom in ctx.IDENTIFIER()) self.katoms.update(heads) return heads def visitBody(self, ctx: HMKNFParser.BodyContext): nbody, pbody = partition( itemgetter(1), (self.visit(body_atom) for body_atom in ctx.katom()) ) return tuple(tuple(item[0] for item in part) for part in (pbody, nbody)) def visitOrule(self, ctx: HMKNFParser.OruleContext): return self.visit(ctx.oformula()) def visitDisjOrConj(self, ctx: HMKNFParser.DisjOrConjContext): return OBinary( ctx.operator.text, self.visit(ctx.oformula(0)), self.visit(ctx.oformula(1)) ) def visitParenth(self, ctx: HMKNFParser.ParenthContext): return self.visit(ctx.oformula()) def visitImp(self, ctx: HMKNFParser.ImpContext): return OBinary("->", self.visit(ctx.oformula(0)), self.visit(ctx.oformula(1))) def visitLiteral(self, ctx: HMKNFParser.LiteralContext): return self.visit(ctx.oatom()) def visitKatom(self, ctx: HMKNFParser.KatomContext): name, sign = self.visit(ctx.patom() if ctx.patom() else ctx.natom()) self.katoms.add(name) return name, sign def visitOatom(self, ctx: HMKNFParser.OatomContext): name = str(ctx.IDENTIFIER()) self.oatoms.add(name) return OAtom(name) def 
visitPatom(self, ctx: HMKNFParser.PatomContext): return str(ctx.IDENTIFIER()), True def visitNatom(self, ctx: HMKNFParser.NatomContext): return str(ctx.IDENTIFIER()), False def visitNegation(self, ctx: HMKNFParser.NegationContext): return ONegation(self.visit(ctx.oformula())) def check_syntax_constraints(ast: KB): for rule in ast.rules: if len(rule.head) > 1: raise Exception("No rule may have more than 1 atom in its head") return ast def unique_atom_namer(prefix: str): for i in count(1): for suffix in product(ascii_lowercase, repeat=i): suffix = "".join(suffix) yield prefix + suffix def desugar_constraints(ast: KB): used_names = list() def add_name(name: str): used_names.append(name) return name names = map(add_name, unique_atom_namer("constraint_")) rules = tuple( LRule( (name := next(names),), rule.pbody, rule.nbody + (name,), ) if not rule.head else rule for rule in ast.rules ) return KB(ast.ont, rules, ast.katoms.union(used_names), ast.oatoms) def loadProgram(fileContents) -> KB: input_stream = InputStream(fileContents) lexer = HMKNFLexer(input_stream) stream = CommonTokenStream(lexer) parser = HMKNFParser(stream) tree = parser.kb() visitor = HMKNFVisitor() ast = visitor.visit(tree) return desugar_constraints(check_syntax_constraints(ast)) def main(): if len(argv) > 1: in_file = open(argv[1], "rt", encoding="utf8") else: in_file = stdin print(loadProgram(in_file.read()).text()) if __name__ == "__main__" and not flags.interactive: main()
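

# Illustrative programmatic use (a sketch, not part of this module's CLI; the
# sample file path is the one from the usage notes in the module docstring):
#
#     from AST import loadProgram
#
#     with open("knowledge_bases/simple.hmknf", encoding="utf8") as f:
#         kb = loadProgram(f.read())
#     print(kb.text())    # ontology formula followed by the desugared rules
#     print(kb.katoms)    # frozenset of atoms occurring in rules
#     print(kb.oatoms)    # frozenset of atoms occurring in the ontology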