wfs-operator-experiment/hmknf.py

154 lines
4.3 KiB
Python

from AST import loadProgram, OFormula, LRule, OBinary, OLiteral, OConst, KB
from itertools import combinations
from sys import flags
# Expects total assignments
def oformula_sat(o: OFormula, total_T: set[str]):
match o:
case OConst.TRUE:
return True
case OConst.FALSE:
return False
case OBinary("|", left, right):
return oformula_sat(left, total_T) or oformula_sat(right, total_T)
case OBinary("&", left, right):
return oformula_sat(left, total_T) and oformula_sat(right, total_T)
case OBinary("->", left, right):
return (not oformula_sat(left, total_T)) or oformula_sat(right, total_T)
case OLiteral(name, is_positive):
return name in total_T if is_positive else name not in total_T
case _:
raise Exception("Ontology error")
def objective_knowledge(
    kb: KB, T: set[str], F: set[str]
) -> tuple[bool, set[str], set[str]]:
    """Check consistency and collect what the ontology entails, given T true and F false.

    Every total interpretation is enumerated that makes the atoms of ``T`` true,
    leaves the atoms of ``F`` (outside ``T``) false, and picks any subset of the
    remaining ``kb.atoms`` to be true. Each one is tested against ``kb.ont``.
    The cost is exponential in the number of free atoms.

    Returns:
        ``(consistent, T2, F2)``:

        * ``consistent`` is whether at least one interpretation satisfies ``kb.ont``.
        * ``T2`` is ``T`` plus the atoms true in every satisfying interpretation.
        * ``F2`` is ``F`` plus the atoms false in every satisfying interpretation.

        When nothing satisfies the ontology, both ``T2`` and ``F2`` absorb all of
        ``kb.atoms``.
    """
    free = kb.atoms.difference(T.union(F))
    # Start from "everything" so that the intersections below shrink these
    # down to what holds in all satisfying interpretations.
    always_true = kb.atoms
    always_false = kb.atoms
    found_model = False
    worlds = (
        T.union(chosen)
        for size in range(len(free) + 1)
        for chosen in combinations(free, size)
    )
    for world in worlds:
        if not oformula_sat(kb.ont, world):
            continue
        found_model = True
        always_true = always_true.intersection(world)
        always_false = always_false.intersection(kb.atoms.difference(world))
    return found_model, T.union(always_true), F.union(always_false)
def Gamma_T(kb: KB, T: set[str], P: set[str]) -> tuple[set[str], set[str]]:
    """Apply one round of the operator that derives true atoms.

    ``T`` is first closed under the ontology (``objective_knowledge`` with no
    atoms forced false). Each rule is then visited in order. A rule fires when
    its positive body is contained in the atoms derived so far and none of its
    negative-body atoms is in ``P``; firing adds its single head atom.

    Returns:
        ``(derived, P)``, where ``P`` is returned unchanged.

    Raises:
        AssertionError: if a firing rule does not have exactly one head atom.
    """
    _, derived, _ = objective_knowledge(kb, T, set())
    for r in kb.rules:
        applicable = derived.issuperset(r.pbody) and not P.intersection(r.nbody)
        if not applicable:
            continue
        assert len(r.head) == 1, "operator does not support disjunctive rules!"
        derived = derived.union({r.head[0]})
    return derived, P
def block(kb: KB, T: set[str], P: set[str]):
_, _, F = objective_knowledge(kb, T, set())
blocked = set()
for rule in kb.rules:
match rule:
case LRule((h,), pbody, nbody):
pass
case _:
continue
if h not in F:
continue
if P.intersection(nbody):
continue
for atom in pbody:
if T.issuperset(set(pbody).difference({atom})):
break
else:
continue
blocked.add(atom)
return F | blocked
def Gamma_P(kb: KB, T: set[str], P: set[str]):
    """Apply one round of the operator that derives possibly-true atoms.

    ``Gamma_T`` is reused with the roles swapped: ``P`` is the set being grown
    and negative bodies are checked against ``T``. Everything reported by
    ``block(kb, T, P)`` is then removed.

    Returns:
        ``(T, P2)``, where ``T`` is returned unchanged.
    """
    grown = Gamma_T(kb, P, T)[0]
    ruled_out = block(kb, T, P)
    return T, grown.difference(ruled_out)
def fixpoint(op, start):
    """Iterate ``op`` from ``start`` until a value ``x`` with ``op(x) == x`` is reached.

    Returns that ``x`` (the value fed to ``op``, not the result of ``op``).
    Does not terminate if the iteration never stabilises.
    """
    current = start
    while True:
        following = op(current)
        if following == current:
            return current
        current = following
def fprint_partition(TP: tuple[set[str], set[str]]) -> None:
    """Print a ``(true, possibly-true)`` pair of atom sets as ``({a, b}, {c, d})``.

    Atoms are printed in sorted order. Iterating a ``set`` of strings follows
    hash order, which changes from run to run under hash randomisation; sorting
    keeps the output reproducible and diffable.
    """
    true_part = ", ".join(sorted(TP[0]))
    possible_part = ", ".join(sorted(TP[1]))
    print("({" + true_part + "}, {" + possible_part + "})")
def stable_revision(kb: KB, TP: tuple[set[str], set[str]]):
    """Compute one stable-revision step of the pair ``TP = (T, P)``.

    The incoming pair is printed first (a trace of the iteration). Then:

    * the new ``T`` is the fixpoint of ``Gamma_T(kb, ., P)`` from the empty set;
    * the new ``P`` is the fixpoint of ``Gamma_P(kb, T, .)`` from the empty set.

    Both fixpoints use the *incoming* ``T`` and ``P``: the new ``T`` does not
    feed into the computation of the new ``P``.
    """
    old_true, old_possible = TP
    fprint_partition(TP)
    new_true = fixpoint(lambda cur: Gamma_T(kb, cur, old_possible)[0], set())
    new_possible = fixpoint(lambda cur: Gamma_P(kb, old_true, cur)[1], set())
    return new_true, new_possible
def ext_block(kb: KB, TP: tuple[set[str], set[str]]):
    """Yield subsets of the not-possibly-true atoms that are consistent with T.

    Let ``F = kb.atoms - P``. Subsets of ``F`` are explored breadth-first,
    largest first, dropping one atom at a time. A candidate ``B`` is yielded
    when ``objective_knowledge(kb, T, B)`` reports consistency, meaning some
    interpretation with ``T`` true and ``B`` false satisfies the ontology.
    Subsets of a yielded candidate are not expanded from it. Each distinct
    candidate is tested at most once.

    Args:
        kb: the knowledge base (``kb.atoms`` is read here).
        TP: the ``(T, P)`` pair of true and possibly-true atoms.

    Yields:
        set[str]: a consistent candidate ``B``.
    """
    from collections import deque

    T, P = TP
    F = frozenset(kb.atoms.difference(P))
    queue = deque([F])
    seen = {F}
    while queue:
        B = queue.popleft()
        # Test the candidate itself; objective_knowledge needs a real set
        # because it calls F.union(...) on its third argument.
        consistent, _, _ = objective_knowledge(kb, T, set(B))
        if consistent:
            yield set(B)
            continue
        # Enqueue every subset with one atom fewer. An empty B has none, so
        # the search simply ends instead of asking for size -1 subsets.
        for atom in B:
            smaller = B - {atom}
            if smaller not in seen:
                seen.add(smaller)
                queue.append(smaller)
def ext_op(kb: KB, TP: tuple[set[str], set[str]]) -> tuple[set[str], set[str]]:
    """Shrink P using every consistent block produced by ``ext_block``.

    For each block ``B`` from ``ext_block(kb, TP)``:

    * compute the atoms false in every ontology model with ``T`` true and
      ``B`` false;
    * remove those atoms from ``P``.

    Returns:
        ``(T, P2)``, where ``T`` is returned unchanged.

    Raises:
        AssertionError: if a yielded block turns out to be inconsistent.
    """
    T, Pprime = TP
    for B in ext_block(kb, TP):
        # Normalise to a set: candidates may arrive as tuples (e.g. from
        # itertools.combinations), and objective_knowledge calls F.union(...).
        consistent, _, Fprime = objective_knowledge(kb, T, set(B))
        # ext_block is expected to yield only consistent blocks.
        assert consistent
        Pprime = Pprime.difference(Fprime)
    return T, Pprime
def extended_stable_revision(kb: KB, ex_op, TP: tuple[set[str], set[str]]):
    """Run one ``stable_revision`` step on ``TP``, then post-process with ``ex_op(kb, .)``."""
    revised = stable_revision(kb, TP)
    return ex_op(kb, revised)
def least_stable_fixedpoint(kb: KB):
    """Iterate ``stable_revision`` from ``(set(), kb.atoms)`` until it stabilises.

    The start pair has nothing known true and every atom possibly true.
    """

    def revise(TP):
        return stable_revision(kb, TP)

    bottom = (set(), kb.atoms)
    return fixpoint(revise, bottom)
def least_ext_stable_fixedpoint(kb: KB, ex_op):
    """Iterate ``extended_stable_revision`` (with ``ex_op``) from ``(set(), kb.atoms)``.

    Returns the first pair that the extended revision maps to itself.
    """

    def revise(TP):
        return extended_stable_revision(kb, ex_op, TP)

    bottom = (set(), kb.atoms)
    return fixpoint(revise, bottom)
def main():
    """Read a program from stdin, then print its least stable fixpoint.

    The input is parsed with ``loadProgram``. ``stable_revision`` also prints
    each intermediate pair before the final result is printed.
    """
    from sys import stdin

    program_text = stdin.read()
    kb = loadProgram(program_text)
    result = least_stable_fixedpoint(kb)
    fprint_partition(result)
# Run the CLI only when executed as a script, and not under `python -i`,
# so the definitions stay available for interactive experimentation.
if __name__ == "__main__":
    if not flags.interactive:
        main()