aft-may25-2023/AST.py

219 lines
5.7 KiB
Python
Raw Permalink Normal View History

2023-05-24 19:05:16 -06:00
"""
Parse a knowledge base and generate an AST
Can be run directly, with the program given as a file argument or on standard input, to view/debug the AST
usage: python AST.py knowledge_bases/simple.hmknf
OR
usage: python AST.py < knowledge_bases/simple.hmknf
"""
from itertools import count, product
from string import ascii_lowercase
from sys import argv, flags, stdin
from antlr4 import ParseTreeVisitor, InputStream, CommonTokenStream
from grammars.HMKNFLexer import HMKNFLexer
from grammars.HMKNFParser import HMKNFParser
from dataclasses import dataclass
from more_itertools import partition
from operator import itemgetter
from functools import reduce, partial
from enum import Enum
# Atoms are referred to by name, so an atom is just a string.
atom = str
# Alias used for the atom collections stored on a KB (immutable sets).
Set = frozenset
@dataclass
class LRule:
    """A rule with a head, a positive body and a negative body."""

    head: tuple[atom, ...]
    pbody: tuple[atom, ...]
    nbody: tuple[atom, ...]

    def text(self):
        """Render the rule as ``head :- body.``.

        Head atoms and body atoms are comma separated; every atom of the
        negative body is prefixed with ``not`` and listed after the positive
        body.
        """
        negated = [f"not {name}" for name in self.nbody]
        body_text = ", ".join([*self.pbody, *negated])
        head_text = ", ".join(self.head)
        return f"{head_text} :- {body_text}."
@dataclass(eq=False)
class OFormula:
    """Base class of all ontology formula nodes.

    Generated equality is disabled on purpose.  This class has no fields, so
    a dataclass ``__eq__`` would make every bare ``OFormula()`` compare equal
    (and be unhashable); ``Enum`` would then treat ``OConst.FALSE`` as an
    alias of ``OConst.TRUE``.  Subclasses decorated with ``@dataclass`` still
    get their own field-based equality.
    """

    def text(self):
        """Return the textual form of a constant formula.

        Returns:
            ``"TRUE"`` or ``"FALSE"`` when ``self`` is the matching
            ``OConst`` member.

        Raises:
            Exception: if ``self`` is neither ``OConst`` member.
        """
        if self is OConst.TRUE:
            return "TRUE"
        if self is OConst.FALSE:
            return "FALSE"
        raise Exception("Unknown text repr")


class OConst(Enum):
    """The two constant formulas, ``TRUE`` and ``FALSE``."""

    # Each member needs its own, non-equal value; equal values would make
    # the second member an alias of the first.
    TRUE = OFormula()
    FALSE = OFormula()

    def text(self):
        """Return ``"TRUE"`` or ``"FALSE"`` for this member."""
        return OFormula.text(self)
@dataclass
class OBinary(OFormula):
    """A formula joining two sub-formulas with a binary operator."""

    operator: str
    left: OFormula
    right: OFormula

    def text(self):
        """Render as ``(<left> <operator> <right>)``, always parenthesised."""
        lhs = self.left.text()
        rhs = self.right.text()
        return f"({lhs} {self.operator} {rhs})"
@dataclass
class OAtom(OFormula):
    """A formula consisting of a single named atom."""

    name: str

    def text(self):
        """Return the atom's name unchanged."""
        atom_name = self.name
        return atom_name
@dataclass
class ONegation(OFormula):
    """The negation of another formula."""

    of: OFormula

    def text(self):
        """Render as ``-`` followed by the text of the negated formula."""
        negated = self.of.text()
        return f"-{negated}"
@dataclass
class KB:
    """A knowledge base: one ontology formula plus a sequence of rules."""

    ont: OFormula
    # Any number of rules (a bare ``tuple[LRule]`` would mean exactly one).
    rules: tuple[LRule, ...]
    katoms: Set[atom]
    oatoms: Set[atom]

    def text(self):
        """Render the knowledge base as text.

        Returns:
            The ontology formula followed by ``.`` and a newline, then the
            text of every rule, one per line (no trailing newline).
        """
        ontology_line = f"{self.ont.text()}.\n"
        rule_lines = "\n".join(rule.text() for rule in self.rules)
        return ontology_line + rule_lines
class HMKNFVisitor(ParseTreeVisitor):
    """Parse-tree visitor that builds the KB AST.

    While visiting, the names of atoms met in rule heads and bodies are
    collected in ``katoms`` and the names of atoms met in ontology formulas
    are collected in ``oatoms``.
    """

    def __init__(self) -> None:
        self.katoms = set()
        self.oatoms = set()

    def visitKb(self, ctx: HMKNFParser.KbContext):
        """Build the KB: conjoin all ontology rules and collect the rules."""
        orule_ctxs = ctx.orule()
        if orule_ctxs:
            formulas = [self.visit(orule) for orule in orule_ctxs]
        else:
            # Without any ontology rule the ontology is the constant TRUE.
            formulas = [OConst.TRUE]
        conjoin = partial(OBinary, "&")
        ont = reduce(conjoin, formulas)
        lrules = tuple(self.visit(lrule) for lrule in ctx.lrule())
        return KB(ont, lrules, Set(self.katoms), Set(self.oatoms))

    def visitLrule(self, ctx: HMKNFParser.LruleContext):
        """Build an LRule; a missing head or body becomes empty tuples."""
        head_ctx = ctx.head()
        if head_ctx:
            head = self.visit(head_ctx)
        else:
            head = ()
        body_ctx = ctx.body()
        if body_ctx:
            pbody, nbody = self.visit(body_ctx)
        else:
            pbody, nbody = (), ()
        return LRule(head, pbody, nbody)

    def visitHead(self, ctx: HMKNFParser.HeadContext):
        """Return the head atom names as a tuple and record them in katoms."""
        names = tuple(str(identifier) for identifier in ctx.IDENTIFIER())
        self.katoms.update(names)
        return names

    def visitBody(self, ctx: HMKNFParser.BodyContext):
        """Return ``(pbody, nbody)``: atom names split by their sign flag."""
        positive = []
        negative = []
        for body_atom in ctx.katom():
            visited = self.visit(body_atom)
            # visited[1] is the flag: truthy -> positive body, else negative.
            if visited[1]:
                positive.append(visited[0])
            else:
                negative.append(visited[0])
        return (tuple(positive), tuple(negative))

    def visitOrule(self, ctx: HMKNFParser.OruleContext):
        """An ontology rule is represented by its formula."""
        formula_ctx = ctx.oformula()
        return self.visit(formula_ctx)

    def visitDisjOrConj(self, ctx: HMKNFParser.DisjOrConjContext):
        """Build an OBinary labelled with the operator token's text."""
        symbol = ctx.operator.text
        lhs = self.visit(ctx.oformula(0))
        rhs = self.visit(ctx.oformula(1))
        return OBinary(symbol, lhs, rhs)

    def visitParenth(self, ctx: HMKNFParser.ParenthContext):
        """Parentheses add no node; return the enclosed formula."""
        inner_ctx = ctx.oformula()
        return self.visit(inner_ctx)

    def visitImp(self, ctx: HMKNFParser.ImpContext):
        """Build an OBinary with operator ``->``."""
        antecedent = self.visit(ctx.oformula(0))
        consequent = self.visit(ctx.oformula(1))
        return OBinary("->", antecedent, consequent)

    def visitLiteral(self, ctx: HMKNFParser.LiteralContext):
        """A literal is represented by its ontology atom."""
        atom_ctx = ctx.oatom()
        return self.visit(atom_ctx)

    def visitKatom(self, ctx: HMKNFParser.KatomContext):
        """Return ``(name, sign)`` of a body atom and record it in katoms."""
        # Exactly one of the two alternatives is expected to be present.
        child = ctx.patom() or ctx.natom()
        name, sign = self.visit(child)
        self.katoms.add(name)
        return name, sign

    def visitOatom(self, ctx: HMKNFParser.OatomContext):
        """Build an OAtom and record its name in oatoms."""
        name = str(ctx.IDENTIFIER())
        self.oatoms.add(name)
        return OAtom(name)

    def visitPatom(self, ctx: HMKNFParser.PatomContext):
        """Return ``(name, True)``; the True flag puts it in the positive body."""
        return (str(ctx.IDENTIFIER()), True)

    def visitNatom(self, ctx: HMKNFParser.NatomContext):
        """Return ``(name, False)``; the False flag puts it in the negative body."""
        return (str(ctx.IDENTIFIER()), False)

    def visitNegation(self, ctx: HMKNFParser.NegationContext):
        """Wrap the visited sub-formula in an ONegation."""
        negated = self.visit(ctx.oformula())
        return ONegation(negated)
def check_syntax_constraints(ast: KB):
    """Validate the rules of ``ast`` and return ``ast`` unchanged.

    Raises:
        Exception: if any rule has more than one atom in its head.
    """
    if any(len(rule.head) > 1 for rule in ast.rules):
        raise Exception("No rule may have more than 1 atom in its head")
    return ast
def unique_atom_namer(prefix: str):
    """Yield an endless stream of distinct names starting with ``prefix``.

    Suffixes are lowercase ASCII strings, enumerated by increasing length:
    ``a`` .. ``z``, then ``aa``, ``ab``, and so on.
    """
    length = 1
    while True:
        for letters in product(ascii_lowercase, repeat=length):
            yield prefix + "".join(letters)
        length += 1
def desugar_constraints(ast: KB):
    """Rewrite every headless rule (constraint) into a rule with a head.

    A rule with an empty head gets a freshly generated ``constraint_*`` atom
    as its head, and the same atom is appended to its negative body.  Rules
    that already have a head are kept as they are.

    Generated names that already occur among the knowledge base's atoms are
    skipped, so a synthetic atom never coincides with one written by the
    user.

    Returns:
        A new KB with the rewritten rules; the generated atoms are added to
        ``katoms``, while ``ont`` and ``oatoms`` are carried over unchanged.
    """
    reserved = set(ast.katoms).union(ast.oatoms)
    fresh_names = (
        name for name in unique_atom_namer("constraint_") if name not in reserved
    )
    introduced = []
    rules = []
    for rule in ast.rules:
        if rule.head:
            rules.append(rule)
            continue
        name = next(fresh_names)
        introduced.append(name)
        rules.append(LRule((name,), rule.pbody, rule.nbody + (name,)))
    return KB(ast.ont, tuple(rules), ast.katoms.union(introduced), ast.oatoms)
def loadProgram(fileContents) -> KB:
    """Parse program text into a KB.

    The text is lexed and parsed with the HMKNF grammar, the parse tree is
    turned into an AST by HMKNFVisitor, and the result is passed through
    check_syntax_constraints and then desugar_constraints.
    """
    lexer = HMKNFLexer(InputStream(fileContents))
    parser = HMKNFParser(CommonTokenStream(lexer))
    parse_tree = parser.kb()
    parsed_kb = HMKNFVisitor().visit(parse_tree)
    checked_kb = check_syntax_constraints(parsed_kb)
    return desugar_constraints(checked_kb)
def main():
    """Load a program and print the text of the resulting knowledge base.

    The program is read from the file named by the first command-line
    argument if there is one, otherwise from standard input.
    """
    if len(argv) > 1:
        # Close the file as soon as it has been read.
        with open(argv[1], "rt", encoding="utf8") as in_file:
            source = in_file.read()
    else:
        source = stdin.read()
    print(loadProgram(source).text())
# Run only when executed as a script, and not when Python was started with -i.
if not flags.interactive and __name__ == "__main__":
    main()