154 lines
4.3 KiB
Python
154 lines
4.3 KiB
Python
|
from AST import loadProgram, OFormula, LRule, OBinary, OLiteral, OConst, KB
|
||
|
from itertools import combinations
|
||
|
from sys import flags
|
||
|
|
||
|
# Expects total assignments
|
||
|
def oformula_sat(o: OFormula, total_T: set[str]):
|
||
|
match o:
|
||
|
case OConst.TRUE:
|
||
|
return True
|
||
|
case OConst.FALSE:
|
||
|
return False
|
||
|
case OBinary("|", left, right):
|
||
|
return oformula_sat(left, total_T) or oformula_sat(right, total_T)
|
||
|
case OBinary("&", left, right):
|
||
|
return oformula_sat(left, total_T) and oformula_sat(right, total_T)
|
||
|
case OBinary("->", left, right):
|
||
|
return (not oformula_sat(left, total_T)) or oformula_sat(right, total_T)
|
||
|
case OLiteral(name, is_positive):
|
||
|
return name in total_T if is_positive else name not in total_T
|
||
|
case _:
|
||
|
raise Exception("Ontology error")
|
||
|
|
||
|
|
||
|
def objective_knowledge(
    kb: KB, T: set[str], F: set[str]
) -> tuple[bool, set[str], set[str]]:
    """Compute what the ontology ``kb.ont`` entails given true atoms ``T`` and false atoms ``F``.

    Every total interpretation ``T ∪ I`` is enumerated, where ``I`` ranges over
    all subsets of the atoms in neither ``T`` nor ``F``. This is exponential in
    the number of such free atoms. Interpretations satisfying ``kb.ont`` (per
    ``oformula_sat``) count as models.

    Returns:
        ``(consistent, T', F')``:
          * ``consistent`` is True iff at least one model was found;
          * ``T'`` is ``T`` plus every atom true in all models;
          * ``F'`` is ``F`` plus every atom false in all models.
        When no model exists, both ``T'`` and ``F'`` contain all of ``kb.atoms``.
    """
    free_atoms = kb.atoms.difference(T.union(F))

    def candidate_models():
        # Every way of additionally making some free atoms true,
        # smallest extensions first.
        for size in range(len(free_atoms) + 1):
            for extra in combinations(free_atoms, size):
                yield T.union(extra)

    true_in_all = kb.atoms
    false_in_all = kb.atoms
    found_model = False
    for model in candidate_models():
        if not oformula_sat(kb.ont, model):
            continue
        found_model = True
        true_in_all = true_in_all.intersection(model)
        false_in_all = false_in_all.intersection(kb.atoms.difference(model))
    return found_model, T.union(true_in_all), F.union(false_in_all)
|
||
|
|
||
|
|
||
|
def Gamma_T(kb: KB, T: set[str], P: set[str]):
    """Apply one step of the lower-bound operator to ``T`` relative to ``P``.

    First closes ``T`` under the ontology with ``objective_knowledge``, which
    adds every atom true in all ontology models extending ``T``. It then scans
    ``kb.rules`` in order and fires each rule whose positive body is contained
    in the current ``T`` and whose negative body is disjoint from ``P``,
    adding the rule's head atom to ``T``.

    ``T`` grows during the scan, so a head derived by an earlier rule can
    enable a later rule within the same call.

    Args:
        kb: knowledge base providing ``rules`` (and, via
            ``objective_knowledge``, ``atoms`` and ``ont``).
        T: the current set of true atoms.
        P: the set of atoms a rule's negative body must avoid for the rule
            to fire.

    Returns:
        ``(T', P)``: the enlarged true set, with ``P`` returned unchanged.

    Raises:
        ValueError: if a rule that would fire does not have exactly one head
            atom (disjunctive rules are not supported).
    """
    _, T, _ = objective_knowledge(kb, T, set())
    for rule in kb.rules:
        if not T.issuperset(rule.pbody) or P.intersection(rule.nbody):
            continue
        if len(rule.head) != 1:
            # Explicit raise instead of ``assert``: this validates the input
            # program and must not be stripped by ``python -O``.
            raise ValueError("operator does not support disjunctive rules!")
        T = T.union({rule.head[0]})
    return T, P
|
||
|
|
||
|
|
||
|
def block(kb: KB, T: set[str], P: set[str]):
|
||
|
_, _, F = objective_knowledge(kb, T, set())
|
||
|
blocked = set()
|
||
|
for rule in kb.rules:
|
||
|
match rule:
|
||
|
case LRule((h,), pbody, nbody):
|
||
|
pass
|
||
|
case _:
|
||
|
continue
|
||
|
if h not in F:
|
||
|
continue
|
||
|
if P.intersection(nbody):
|
||
|
continue
|
||
|
for atom in pbody:
|
||
|
if T.issuperset(set(pbody).difference({atom})):
|
||
|
break
|
||
|
else:
|
||
|
continue
|
||
|
|
||
|
blocked.add(atom)
|
||
|
return F | blocked
|
||
|
|
||
|
|
||
|
def Gamma_P(kb: KB, T: set[str], P: set[str]):
    """Apply one step of the upper-bound operator to ``P`` relative to ``T``.

    Runs ``Gamma_T`` with the two sets swapped. That closes ``P`` under the
    ontology and fires rules whose positive body lies within ``P`` and whose
    negative body avoids ``T``. Every atom that ``block(kb, T, P)`` reports as
    false is then removed from the result.

    Returns:
        ``(T, P')``, with ``T`` returned unchanged.
    """
    derivable, _ = Gamma_T(kb, P, T)
    return T, derivable.difference(block(kb, T, P))
|
||
|
|
||
|
|
||
|
def fixpoint(op, start):
    """Repeatedly apply ``op`` from ``start`` until the value stops changing.

    Returns the first value ``x`` with ``op(x) == x``. The returned object is
    that ``x`` itself, not the result of the final ``op`` call. There is no
    iteration bound, so ``op`` must eventually stabilize.
    """
    current = start
    while True:
        candidate = op(current)
        if candidate != current:
            current = candidate
            continue
        return current
|
||
|
|
||
|
|
||
|
def fprint_partition(TP):
    """Print a ``(T, P)`` pair to stdout in the form ``({t1, t2}, {p1, p2})``.

    Atoms are printed in sorted order so the output is identical across runs.
    Iterating a ``set`` of strings directly follows hash order, and Python
    randomizes string hashes per process (``PYTHONHASHSEED``).
    """
    true_atoms = ", ".join(sorted(TP[0]))
    possible_atoms = ", ".join(sorted(TP[1]))
    print(f"({{{true_atoms}}}, {{{possible_atoms}}})")
|
||
|
|
||
|
|
||
|
def stable_revision(kb: KB, TP: tuple[set[str], set[str]]):
    """Apply the stable revision operator to the pair ``TP = (T, P)``.

    Prints the incoming pair as a trace, then computes from it:

      * the new first component: the least fixpoint of
        ``t -> Gamma_T(kb, t, P)[0]``, iterated from the empty set;
      * the new second component: the least fixpoint of
        ``p -> Gamma_P(kb, T, p)[1]``, iterated from the empty set.

    Both components are computed from the *incoming* ``T`` and ``P``. The new
    first component is not used when computing the second one.
    """
    T, P = TP
    fprint_partition(TP)
    new_T = fixpoint(lambda lower: Gamma_T(kb, lower, P)[0], set())
    new_P = fixpoint(lambda upper: Gamma_P(kb, T, upper)[1], set())
    return new_T, new_P
|
||
|
|
||
|
|
||
|
def ext_block(kb: KB, TP: tuple[set[str], set[str]]):
    """Yield subsets of the non-possible atoms that are consistent with ``T``.

    ``F`` is the set of atoms outside ``P``. Candidates are explored
    breadth-first, starting from ``F`` itself. Each candidate ``B`` is checked
    with ``objective_knowledge(kb, T, B)``:

      * if some ontology model extends ``T`` with every atom of ``B`` false,
        ``B`` is yielded (as a fresh ``set``) and not shrunk further;
      * otherwise, every subset of ``B`` with one atom removed is queued.

    Each subset is examined at most once. A consistent subset may still be
    yielded even if a larger yielded subset contains it. If ``T`` alone is
    inconsistent with the ontology, nothing is yielded.
    """
    from collections import deque

    T, P = TP
    F = kb.atoms.difference(P)
    start = frozenset(F)
    worklist = deque([start])
    queued = {start}  # every candidate ever enqueued, to avoid re-examination
    while worklist:
        B = worklist.popleft()
        # Test the current candidate B (not the full F).
        consistent, _, _ = objective_knowledge(kb, T, set(B))
        if consistent:
            # Yield a real set: objective_knowledge calls F.union(...) on it.
            yield set(B)
            continue
        # Drop one atom at a time. An empty inconsistent B yields no smaller
        # candidates, so the search simply ends there.
        for atom in B:
            smaller = B - {atom}
            if smaller not in queued:
                queued.add(smaller)
                worklist.append(smaller)
|
||
|
|
||
|
def ext_op(kb: KB, TP: tuple[set[str], set[str]]) -> tuple[set[str], set[str]]:
    """Shrink the possible atoms using the consistent blocks from ``ext_block``.

    For each block ``B`` yielded by ``ext_block(kb, TP)``, this computes the
    atoms false in every ontology model that extends ``T`` with ``B`` false.
    Those atoms are removed from the possible set.

    Returns:
        ``(T, P')``, with ``T`` unchanged and ``P'`` the reduced possible set.

    Raises:
        AssertionError: (unless run with ``-O``) if a yielded block is not
            consistent with ``T``.
    """
    T, Pprime = TP
    for B in ext_block(kb, TP):
        # Normalize to a set: objective_knowledge calls ``F.union(...)``, which
        # a tuple (e.g. one produced by itertools.combinations) does not have.
        consistent, _, Fprime = objective_knowledge(kb, T, set(B))
        assert consistent, "ext_block yielded a block inconsistent with T"
        Pprime = Pprime.difference(Fprime)
    return T, Pprime
|
||
|
|
||
|
|
||
|
def extended_stable_revision(kb: KB, ex_op, TP: tuple[set[str], set[str]]):
    """Apply ``stable_revision`` to ``TP``, then post-process the result with ``ex_op(kb, ·)``."""
    revised = stable_revision(kb, TP)
    return ex_op(kb, revised)
|
||
|
|
||
|
|
||
|
def least_stable_fixedpoint(kb: KB):
    """Iterate ``stable_revision`` to a fixpoint, starting from ``(set(), kb.atoms)``."""
    def revise(TP):
        return stable_revision(kb, TP)

    bottom = (set(), kb.atoms)
    return fixpoint(revise, bottom)
|
||
|
|
||
|
|
||
|
def least_ext_stable_fixedpoint(kb: KB, ex_op):
    """Iterate ``extended_stable_revision`` with ``ex_op`` to a fixpoint, starting from ``(set(), kb.atoms)``."""
    def revise(TP):
        return extended_stable_revision(kb, ex_op, TP)

    bottom = (set(), kb.atoms)
    return fixpoint(revise, bottom)
|
||
|
|
||
|
|
||
|
def main():
    """Read a program from stdin, compute its least stable fixpoint and print it."""
    from sys import stdin

    program_text = stdin.read()
    kb = loadProgram(program_text)
    fprint_partition(least_stable_fixedpoint(kb))
|
||
|
|
||
|
|
||
|
# Run only as a script, and not when started with `python -i`
# (sys.flags.interactive is set by -i).
if not flags.interactive and __name__ == "__main__":
    main()
|