219 lines
5.7 KiB
Python
219 lines
5.7 KiB
Python
"""
|
|
Parse a knowledge base and generate an AST
|
|
Can be run directly with program as standard input to view/debug AST
|
|
|
|
usage: python AST.py knowledge_bases/simple.hmknf
|
|
OR
|
|
usage: python AST.py < knowledge_bases/simple.hmknf
|
|
"""
|
|
|
|
|
|
from itertools import count, product
|
|
from string import ascii_lowercase
|
|
from sys import argv, flags, stdin
|
|
from antlr4 import ParseTreeVisitor, InputStream, CommonTokenStream
|
|
from grammars.HMKNFLexer import HMKNFLexer
|
|
from grammars.HMKNFParser import HMKNFParser
|
|
from dataclasses import dataclass
|
|
from more_itertools import partition
|
|
from operator import itemgetter
|
|
from functools import reduce, partial
|
|
from enum import Enum
|
|
|
|
# Atoms are represented by their plain string names.
atom = str

# Sets stored on the AST are immutable, so `Set` is an alias for frozenset.
Set = frozenset
|
|
|
|
|
|
@dataclass
class LRule:
    """A rule made of head atoms, positive body atoms and negated body atoms."""

    head: tuple[atom, ...]
    pbody: tuple[atom, ...]
    nbody: tuple[atom, ...]

    def text(self):
        """Render the rule as ``<head> :- <body>.``.

        Head atoms and body atoms are each comma-separated; every atom of
        ``nbody`` is prefixed with ``not`` and listed after the positive atoms.
        """
        negated = tuple(f"not {name}" for name in self.nbody)
        lhs = ", ".join(self.head)
        rhs = ", ".join(self.pbody + negated)
        return f"{lhs} :- {rhs}."
|
|
|
|
|
|
@dataclass(eq=False)
class OFormula:
    """Base class for ontology formulas.

    ``eq=False`` keeps identity-based equality (and hashing) for the bare base
    class. With the default dataclass ``__eq__`` every field-less
    ``OFormula()`` compares equal to every other, which makes ``Enum`` treat
    ``OConst.FALSE`` as an alias of ``OConst.TRUE``. Subclasses decorated with
    ``@dataclass`` still receive their own field-based ``__eq__``.
    """

    def text(self):
        """Return the textual form of a constant formula.

        Returns:
            ``"TRUE"`` or ``"FALSE"`` when ``self`` is the matching ``OConst``
            member.

        Raises:
            Exception: if ``self`` is neither constant; subclasses are
                expected to override this method.
        """
        if self is OConst.TRUE:
            return "TRUE"
        elif self is OConst.FALSE:
            return "FALSE"
        raise Exception("Unknown text repr")


class OConst(Enum):
    """The two constant ontology formulas."""

    # The two values must not compare equal, otherwise Enum would silently
    # turn FALSE into an alias of TRUE.
    TRUE = OFormula()
    FALSE = OFormula()

    def text(self):
        """Return ``"TRUE"`` or ``"FALSE"`` for this member."""
        return OFormula.text(self)
|
|
|
|
|
|
@dataclass
class OBinary(OFormula):
    """A formula joining two sub-formulas with a binary operator."""

    operator: str
    left: OFormula
    right: OFormula

    def text(self):
        """Render as ``(<left> <operator> <right>)``, always parenthesised."""
        lhs = self.left.text()
        rhs = self.right.text()
        return f"({lhs} {self.operator} {rhs})"
|
|
|
|
|
|
@dataclass
class OAtom(OFormula):
    """An atomic ontology formula identified by its name."""

    name: str

    def text(self):
        """Return the atom's name unchanged."""
        atom_name = self.name
        return atom_name
|
|
|
|
|
|
@dataclass
class ONegation(OFormula):
    """The negation of another ontology formula."""

    of: OFormula

    def text(self):
        """Render as ``-`` immediately followed by the negated formula's text."""
        negated = self.of.text()
        return f"-{negated}"
|
|
|
|
|
|
@dataclass
class KB:
    """A knowledge base: one ontology formula plus a sequence of rules.

    ``katoms`` and ``oatoms`` are sets of atom names associated with the rules
    and with the ontology respectively.
    """

    ont: OFormula
    # Variable-length tuple of rules (``tuple[LRule]`` would mean exactly one).
    rules: tuple[LRule, ...]
    katoms: Set[atom]
    oatoms: Set[atom]

    def text(self):
        """Render the knowledge base as text.

        The ontology comes first, terminated by ``.`` and a newline; the rules
        follow, one per line, with no trailing newline.
        """
        ontology_line = f"{self.ont.text()}.\n"
        rule_lines = "\n".join(LRule.text(rule) for rule in self.rules)
        return ontology_line + rule_lines
|
|
|
|
|
|
class HMKNFVisitor(ParseTreeVisitor):
    """Parse-tree visitor that builds the ``KB`` AST.

    While visiting, the names of atoms appearing in rule heads and bodies are
    accumulated in ``katoms`` and the names of ontology atoms in ``oatoms``.
    """

    def __init__(self) -> None:
        # Mutable accumulators; visitKb freezes them into the resulting KB.
        self.oatoms = set()
        self.katoms = set()

    def visitKb(self, ctx: HMKNFParser.KbContext):
        """Build the KB from all ontology rules and all logic rules."""
        orule_ctxs = ctx.orule()
        if orule_ctxs:
            formulas = (self.visit(orule_ctx) for orule_ctx in orule_ctxs)
        else:
            # Without any ontology rule the ontology is the constant TRUE.
            formulas = (OConst.TRUE,)
        # Left-fold the formulas into one conjunction.
        conjoin = partial(OBinary, "&")
        ontology = reduce(conjoin, formulas)
        logic_rules = tuple(self.visit(lrule_ctx) for lrule_ctx in ctx.lrule())
        return KB(ontology, logic_rules, Set(self.katoms), Set(self.oatoms))

    def visitLrule(self, ctx: HMKNFParser.LruleContext):
        """Build an LRule; a missing head or body becomes empty tuples."""
        head = ()
        if ctx.head():
            head = self.visit(ctx.head())
        pbody, nbody = (), ()
        if ctx.body():
            pbody, nbody = self.visit(ctx.body())
        return LRule(head, pbody, nbody)

    def visitHead(self, ctx: HMKNFParser.HeadContext):
        """Return the head identifiers as a tuple of names, recording them."""
        names = tuple(map(str, ctx.IDENTIFIER()))
        self.katoms.update(names)
        return names

    def visitBody(self, ctx: HMKNFParser.BodyContext):
        """Return ``(positive_names, negative_names)`` for the body atoms."""
        positives = []
        negatives = []
        for katom_ctx in ctx.katom():
            # Each body atom is visited as a (name, sign) pair.
            name, is_positive = self.visit(katom_ctx)
            if is_positive:
                positives.append(name)
            else:
                negatives.append(name)
        return tuple(positives), tuple(negatives)

    def visitOrule(self, ctx: HMKNFParser.OruleContext):
        """An ontology rule is just its formula."""
        formula_ctx = ctx.oformula()
        return self.visit(formula_ctx)

    def visitDisjOrConj(self, ctx: HMKNFParser.DisjOrConjContext):
        """Build a binary formula using the operator text from the parse tree."""
        lhs = self.visit(ctx.oformula(0))
        rhs = self.visit(ctx.oformula(1))
        return OBinary(ctx.operator.text, lhs, rhs)

    def visitParenth(self, ctx: HMKNFParser.ParenthContext):
        """Parentheses add no node; return the inner formula."""
        inner_ctx = ctx.oformula()
        return self.visit(inner_ctx)

    def visitImp(self, ctx: HMKNFParser.ImpContext):
        """Build an implication as a binary formula with operator ``->``."""
        antecedent = self.visit(ctx.oformula(0))
        consequent = self.visit(ctx.oformula(1))
        return OBinary("->", antecedent, consequent)

    def visitLiteral(self, ctx: HMKNFParser.LiteralContext):
        """A literal is represented by its ontology atom."""
        atom_ctx = ctx.oatom()
        return self.visit(atom_ctx)

    def visitKatom(self, ctx: HMKNFParser.KatomContext):
        """Return ``(name, sign)`` for a body atom and record its name."""
        # Use the patom child when present, otherwise the natom child.
        signed_ctx = ctx.patom() or ctx.natom()
        name, sign = self.visit(signed_ctx)
        self.katoms.add(name)
        return name, sign

    def visitOatom(self, ctx: HMKNFParser.OatomContext):
        """Build an OAtom from the identifier and record its name."""
        name = str(ctx.IDENTIFIER())
        self.oatoms.add(name)
        return OAtom(name)

    def visitPatom(self, ctx: HMKNFParser.PatomContext):
        """Return the identifier paired with ``True``."""
        name = str(ctx.IDENTIFIER())
        return name, True

    def visitNatom(self, ctx: HMKNFParser.NatomContext):
        """Return the identifier paired with ``False``."""
        name = str(ctx.IDENTIFIER())
        return name, False

    def visitNegation(self, ctx: HMKNFParser.NegationContext):
        """Wrap the visited sub-formula in an ONegation."""
        negated = self.visit(ctx.oformula())
        return ONegation(negated)
|
|
|
|
|
|
def check_syntax_constraints(ast: KB):
    """Validate the knowledge base and return it unchanged.

    Raises:
        Exception: if any rule has more than one atom in its head.
    """
    if any(len(rule.head) > 1 for rule in ast.rules):
        raise Exception("No rule may have more than 1 atom in its head")
    return ast
|
|
|
|
|
|
def unique_atom_namer(prefix: str):
    """Yield an endless stream of distinct names starting with ``prefix``.

    Suffixes are lowercase-letter strings enumerated by increasing length:
    ``a`` .. ``z``, then ``aa``, ``ab``, and so on.
    """
    length = 1
    while True:
        for letters in product(ascii_lowercase, repeat=length):
            yield prefix + "".join(letters)
        length += 1
|
|
|
|
|
|
def desugar_constraints(ast: KB):
    """Rewrite every headless rule (constraint) into a rule with a fresh head.

    A rule with an empty head and body ``B`` becomes ``c :- B, not c`` where
    ``c`` is a newly generated ``constraint_*`` atom. Rules that already have a
    head are kept as they are.

    Generated names that already occur among the knowledge base's ``katoms``
    or ``oatoms`` are skipped, so a fresh atom can never coincide with an atom
    the input already uses.

    Returns:
        A new KB with the rewritten rules and with the generated names added
        to ``katoms``; ``ont`` and ``oatoms`` are carried over unchanged.
    """
    taken = set(ast.katoms).union(ast.oatoms)
    fresh_names = (
        name for name in unique_atom_namer("constraint_") if name not in taken
    )

    introduced = []
    rules = []
    for rule in ast.rules:
        if rule.head:
            rules.append(rule)
            continue
        name = next(fresh_names)
        introduced.append(name)
        # The fresh atom is both the head and an extra negated body atom.
        rules.append(LRule((name,), rule.pbody, rule.nbody + (name,)))

    return KB(ast.ont, tuple(rules), ast.katoms.union(introduced), ast.oatoms)
|
|
|
|
|
|
def loadProgram(fileContents) -> KB:
    """Parse knowledge-base source text into a KB.

    The text is lexed and parsed starting at the ``kb`` rule, converted to an
    AST by ``HMKNFVisitor``, passed through ``check_syntax_constraints`` and
    finally through ``desugar_constraints``.
    """
    tokens = CommonTokenStream(HMKNFLexer(InputStream(fileContents)))
    parse_tree = HMKNFParser(tokens).kb()
    knowledge_base = HMKNFVisitor().visit(parse_tree)
    checked = check_syntax_constraints(knowledge_base)
    return desugar_constraints(checked)
|
|
|
|
|
|
def main():
    """Read a program from ``argv[1]`` (or standard input) and print its AST text."""
    if len(argv) > 1:
        # Close the file as soon as it has been read instead of leaking it.
        with open(argv[1], "rt", encoding="utf8") as in_file:
            source = in_file.read()
    else:
        source = stdin.read()

    print(loadProgram(source).text())
|
|
|
|
|
|
# Run only when executed as a script, and not when the interpreter was
# started with the interactive flag set (sys.flags.interactive).
if not flags.interactive and __name__ == "__main__":
    main()
|