from AST import loadProgram, OFormula, LRule, OBinary, OLiteral, OConst, KB from itertools import combinations from sys import flags # TODO TODO TODO # separate O atoms from KA atoms? do we care though? # Expects total assignments def oformula_sat(o: OFormula, total_T: set[str]): match o: case OConst.TRUE: return True case OConst.FALSE: return False case OBinary("|", left, right): return oformula_sat(left, total_T) or oformula_sat(right, total_T) case OBinary("&", left, right): return oformula_sat(left, total_T) and oformula_sat(right, total_T) case OBinary("->", left, right): return (not oformula_sat(left, total_T)) or oformula_sat(right, total_T) case OLiteral(name, is_positive): return name in total_T if is_positive else name not in total_T case _: raise Exception("Ontology error") def objective_knowledge( kb: KB, T: set[str], F: set[str] ) -> tuple[bool, set[str], set[str]]: undefined_atoms = kb.atoms.difference(T.union(F)) entailed = kb.atoms entailed_false = kb.atoms consistent = False for num_true in range(len(undefined_atoms) + 1): for I in combinations(undefined_atoms, num_true): total = T.union(I) if oformula_sat(kb.ont, total): consistent = True entailed = entailed.intersection(total) entailed_false = entailed_false.intersection(kb.atoms.difference(total)) return consistent, T.union(entailed), F.union(entailed_false) def Gamma_T(kb: KB, T: set[str], P: set[str]): _, T, _ = objective_knowledge(kb, T, set()) for rule in kb.rules: if not T.issuperset(rule.pbody): continue if P.intersection(rule.nbody): continue assert len(rule.head) == 1, "operator does not support disjunctive rules!" T = T.union({rule.head[0]}) return T, P def block(kb: KB, T: set[str], P: set[str]): _, _, F = objective_knowledge(kb, T, set()) blocked = set() for rule in kb.rules: match rule: case LRule((h,), pbody, nbody): pass case _: continue if h not in F: continue # Can't be P or T? TODO if P.intersection(nbody): continue for atom in pbody: if T.issuperset(set(pbody).difference({atom})): break else: continue blocked.add(atom) return F | blocked def Gamma_P(kb: KB, T: set[str], P: set[str]): Pprime, _ = Gamma_T(kb, P, T) blocking = block(kb, T, P) return T, Pprime.difference(blocking) def fixpoint(op, start): while (succ := op(start)) != start: start = succ return start def fprint_partition(TP): print("({", end="") print(", ".join(TP[0]), end="") print("}, {", end="") print(", ".join(TP[1]), end="") print("})") def stable_revision(kb: KB, TP: tuple[set[str], set[str]]): T, P = TP fprint_partition(TP) T, P = fixpoint(lambda T: Gamma_T(kb, T, P)[0], set()), fixpoint( lambda P: Gamma_P(kb, T, P)[1], set() ) return T, P def ext_block(kb: KB, TP: tuple[set[str], set[str]]): T, P = TP F = kb.atoms.difference(P) wl = [F] while len(wl): B = wl.pop(0) consistent, _, _ = objective_knowledge(kb, T, F) if consistent: yield B continue wl.extend( Bsmaller for Bsmaller in combinations(B, len(B) - 1) if Bsmaller not in wl ) def ext_op(kb: KB, TP: tuple[set[str], set[str]]) -> tuple[set[str], set[str]]: T, Pprime = TP for B in ext_block(kb, TP): consistent, _, Fprime = objective_knowledge(kb, T, B) assert consistent Pprime = Pprime.difference(Fprime) return T, Pprime def extended_stable_revision(kb: KB, ex_op, TP: tuple[set[str], set[str]]): return ex_op(kb, stable_revision(kb, TP)) def least_stable_fixedpoint(kb: KB): return fixpoint(lambda TP: stable_revision(kb, TP), (set(), kb.atoms)) def least_ext_stable_fixedpoint(kb: KB, ex_op): return fixpoint( lambda TP: extended_stable_revision(kb, ex_op, TP), (set(), kb.atoms) ) def main(): from sys import stdin kb = loadProgram(stdin.read()) wfm = least_stable_fixedpoint(kb) fprint_partition(wfm) kb = loadProgram(open("tests/uh_oh2.in").read()) least_ext_stable_fixedpoint(kb, ext_op) # if __name__ == "__main__" and not flags.interactive: # main()