# Source: wfs-operator-experiment/AST.py
# (file-viewer listing: 125 lines, 3.6 KiB, Python)

from antlr4 import ParseTreeVisitor, InputStream, CommonTokenStream
from grammars.HMKNFLexer import HMKNFLexer
from grammars.HMKNFParser import HMKNFParser
from dataclasses import dataclass
from more_itertools import partition
from operator import itemgetter
from functools import reduce, partial
from enum import Enum
@dataclass
class LRule:
    """One logic-program rule as produced by ``HMKNFVisitor.visitLrule``.

    All three fields hold atom names; a missing head or body is represented
    by an empty tuple.
    """

    # Identifiers listed in the rule head.
    head: tuple[str, ...]
    # Body atoms that were parsed with a positive sign.
    pbody: tuple[str, ...]
    # Body atoms that were parsed with a negative sign.
    nbody: tuple[str, ...]
@dataclass(eq=False)
class OFormula:
    """Base class for ontology formulas; declares no fields of its own.

    Equality is intentionally left as object identity (``eq=False``). A
    field-less dataclass with a generated ``__eq__`` compares equal to every
    other instance of the class, which made ``Enum`` treat ``OConst.FALSE``
    as an alias of ``OConst.TRUE``. Subclasses decorated with a plain
    ``@dataclass`` still receive their own field-wise ``__eq__``.
    """


class OConst(Enum):
    """The two constant ontology formulas, ``TRUE`` and ``FALSE``."""

    # Enum merges names whose values compare equal into a single member, so
    # the two values must be distinguishable; identity-compared OFormula
    # instances guarantee that.
    TRUE = OFormula()
    FALSE = OFormula()
@dataclass
class OBinary(OFormula):
    """A binary connective applied to two ontology sub-formulas."""

    # Connective symbol; the visitor in this file emits "&", "|" and "->".
    operator: str
    # Left-hand operand.
    left: OFormula
    # Right-hand operand.
    right: OFormula
@dataclass
class OLiteral(OFormula):
    """An atom occurring in an ontology formula, together with its sign."""

    # Identifier text of the atom.
    name: str
    # True when the atom was parsed as a positive atom, False otherwise.
    sign_positive: bool
@dataclass
class KB:
    """A parsed knowledge base: an ontology formula plus logic rules."""

    # Ontology part, as a single formula.
    ont: OFormula
    # Logic-program rules.
    # NOTE(review): HMKNFVisitor.visitKb passes a tuple here, not a list.
    rules: list[LRule]
    # Atom names collected by the visitor while building the KB.
    atoms: set[str]
class HMKNFVisitor(ParseTreeVisitor):
    """Walk an HMKNF parse tree and build the AST objects of this module.

    Atom names met in rule bodies and in ontology formulas are gathered in
    ``self.atoms`` as a side effect of visiting.
    """

    def __init__(self) -> None:
        # Names of every body / ontology atom seen so far by this visitor.
        self.atoms = set()

    def visitKb(self, ctx: HMKNFParser.KbContext):
        """Return a ``KB`` built from all ontology rules and logic rules."""
        # Left fold of the ontology rules under "&", seeded with TRUE.
        ont = OConst.TRUE
        for orule_ctx in ctx.orule() or ():
            ont = OBinary("&", ont, self.visit(orule_ctx))

        lrules = tuple(map(self.visit, ctx.lrule() or ()))
        # The visitor's own atom set is handed to the KB (not a copy).
        return KB(ont, lrules, self.atoms)

    def visitLrule(self, ctx: HMKNFParser.LruleContext):
        """Return an ``LRule``; an absent head or body becomes empty tuples."""
        head_ctx = ctx.head()
        head = self.visit(head_ctx) if head_ctx else ()

        body_ctx = ctx.body()
        if body_ctx:
            pbody, nbody = self.visit(body_ctx)
        else:
            pbody, nbody = (), ()
        return LRule(head, pbody, nbody)

    def visitHead(self, ctx: HMKNFParser.HeadContext):
        """Return the head identifiers as a tuple of strings."""
        # NOTE(review): head identifiers are not added to self.atoms --
        # confirm whether that is intended.
        return tuple(map(str, ctx.IDENTIFIER()))

    def visitBody(self, ctx: HMKNFParser.BodyContext):
        """Return ``(positive_names, negative_names)`` for the body atoms."""
        positive_names = []
        negative_names = []
        for katom_ctx in ctx.katom():
            entry = self.visit(katom_ctx)  # (name, sign) pair
            bucket = positive_names if entry[1] else negative_names
            bucket.append(entry[0])
        return tuple(positive_names), tuple(negative_names)

    def visitOrule(self, ctx: HMKNFParser.OruleContext):
        """An ontology rule is represented by its formula."""
        formula_ctx = ctx.oformula()
        return self.visit(formula_ctx)

    def _visit_binary(self, operator, ctx):
        """Build an ``OBinary`` from the two sub-formulas of ``ctx``."""
        left = self.visit(ctx.oformula(0))
        right = self.visit(ctx.oformula(1))
        return OBinary(operator, left, right)

    def visitDisj(self, ctx: HMKNFParser.DisjContext):
        """Disjunction of two sub-formulas."""
        return self._visit_binary("|", ctx)

    def visitConj(self, ctx: HMKNFParser.ConjContext):
        """Conjunction of two sub-formulas."""
        return self._visit_binary("&", ctx)

    def visitParenth(self, ctx: HMKNFParser.ParenthContext):
        """Parentheses add no node of their own; return the inner formula."""
        inner_ctx = ctx.oformula()
        return self.visit(inner_ctx)

    def visitImp(self, ctx: HMKNFParser.ImpContext):
        """Implication between two sub-formulas."""
        return self._visit_binary("->", ctx)

    def visitLiteral(self, ctx: HMKNFParser.LiteralContext):
        """A literal is represented by its ontology atom."""
        atom_ctx = ctx.oatom()
        return self.visit(atom_ctx)

    def visitKatom(self, ctx: HMKNFParser.KatomContext):
        """Return ``(name, sign)`` for a body atom and record the name."""
        atom_ctx = ctx.patom() or ctx.natom()
        name, sign = self.visit(atom_ctx)
        self.atoms.add(name)
        return name, sign

    def visitOatom(self, ctx: HMKNFParser.OatomContext):
        """Return an ``OLiteral`` for an ontology atom and record the name."""
        atom_ctx = ctx.patom() or ctx.fatom()
        name, sign_positive = self.visit(atom_ctx)
        self.atoms.add(name)
        return OLiteral(name, sign_positive)

    def visitPatom(self, ctx: HMKNFParser.PatomContext):
        """Positive atom: identifier text paired with ``True``."""
        name = str(ctx.IDENTIFIER())
        return name, True

    def visitNatom(self, ctx: HMKNFParser.NatomContext):
        """Atom parsed via the ``natom`` rule: identifier paired with ``False``."""
        name = str(ctx.IDENTIFIER())
        return name, False

    def visitFatom(self, ctx: HMKNFParser.FatomContext):
        """Atom parsed via the ``fatom`` rule: identifier paired with ``False``."""
        name = str(ctx.IDENTIFIER())
        return name, False
def loadProgram(fileContents):
    """Parse HMKNF source text and return what ``HMKNFVisitor`` builds for it.

    ``fileContents`` is the program text; it is lexed, parsed starting at
    the ``kb`` rule, and the resulting tree is visited by a fresh
    ``HMKNFVisitor``.
    """
    tokens = CommonTokenStream(HMKNFLexer(InputStream(fileContents)))
    parse_tree = HMKNFParser(tokens).kb()
    return HMKNFVisitor().visit(parse_tree)
def main():
    """Read a program from standard input and print its parsed form."""
    import sys

    program_text = sys.stdin.read()
    print(loadProgram(program_text))


if __name__ == "__main__":
    main()