# clingo-hmknf-test/propagator.py
#
# (file-viewer header: 286 lines, 9.7 KiB, Python)

from functools import partial
from itertools import chain
from sys import stderr
from typing import Iterable, Iterator, List, Optional, Set, Tuple

import clingo
from clingo import (
    PropagateControl,
    PropagateInit,
    PropagateControl,
    PropagatorCheckMode,
    Assignment,
    Symbol,
)
from more_itertools import partition, unique_everseen
def eprint(*args, **kwargs):
    """Print to stderr; an explicit ``file=`` keyword still takes precedence."""
    kwargs.setdefault("file", stderr)
    print(*args, **kwargs)
"""
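API notes:
add_clause is disjunctive: at least one of its literals must hold.
add_nogood is conjunctive: its literals must not all hold at the same time.
Defined atoms are those that appear in a rule head.
Is there no way to restrict atoms except by adding nogoods/clauses?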
"""
import ontology as O
class OntologyPropagator:
    """Clingo propagator linking ``&o{...}`` theory atoms to the ontology module ``O``.

    Solver literals are translated to atom names so that ``O.propagate`` and
    ``O.check`` can reason over them; their answers are turned back into
    nogoods. Debug output is written to stderr via ``eprint``.
    """

    def init(self, init: PropagateInit) -> None:
        """Build the literal/name lookup tables and register watches.

        Called once by clingo before solving.
        """
        init.check_mode = PropagatorCheckMode.Total
        # Watched atom (absolute solver literal) -> True if assigned true.
        self.assignment = dict()
        # Solver literal -> atom text.
        # NOTE(review): clingo can map several program atoms (e.g. all facts)
        # to one solver literal; only the last name seen survives here — confirm.
        self.symbolic_atoms = {
            init.solver_literal(atom.literal): str(atom.symbol)
            for atom in init.symbolic_atoms
        }
        # Clingo may reuse solver literals between theory and symbolic atoms,
        # so the two tables are kept separate.
        # Solver literal -> name of the first term of the atom's first element.
        self.theory_atoms = {
            init.solver_literal(atom.literal): atom.elements[0].terms[0].name
            for atom in init.theory_atoms
        }
        self.symbolic_atoms_inv = {v: k for k, v in self.symbolic_atoms.items()}
        self.theory_atoms_inv = {v: k for k, v in self.theory_atoms.items()}
        # Make &o{false} always false. The atom is missing when the program had
        # no atoms to generate dummy rules for, or when clingo dropped it, so
        # look it up with .get() instead of failing with a KeyError.
        false_lit = self.theory_atoms_inv.get("false")
        if false_lit is not None:
            init.add_clause([-false_lit])
        # Watching everything is simplest; watching only theory atoms or only
        # symbolic atoms may turn out to be enough.
        # NOTE(review): add_watch watches a literal in a single phase, so
        # propagate() is presumably only told when these literals become true
        # — confirm that ignoring false assignments is intended.
        for lit in chain(self.symbolic_atoms, self.theory_atoms):
            init.add_watch(lit)
        # An ontology atom should be true whenever its regular counterpart is;
        # the opposite direction is already enforced by the generated rules.
        # Enforcing it with extra rules might change the semantics of the O
        # atoms, so the clause (-symbolic_lit, theory_lit) is deliberately not
        # added here.
        assert len(set(self.symbolic_atoms) & set(self.theory_atoms)) == 0

    def truthy_atoms_text(self) -> Iterator[str]:
        """Yield the text of every watched atom currently assigned true."""
        return (
            str(self.lookup_solver_lit(atom)[0])
            for atom, value in self.assignment.items()
            if value
        )

    def atom_text_to_dl_atom(self, atoms: Iterable[str]) -> Iterable[int]:
        """Map atom names to theory-atom solver literals, skipping unknown names."""
        return (
            lit
            for atom in atoms
            if (lit := self.theory_atoms_inv.get(atom)) is not None
        )

    def falsey_atoms_text(self) -> Iterator[str]:
        """Yield the text of every watched atom currently assigned false."""
        return (
            str(self.lookup_solver_lit(atom)[0])
            for atom, value in self.assignment.items()
            if not value
        )

    def print_nogood(self, nogood: Tuple[int, ...]) -> None:
        """Debug-print *nogood* to stderr as ``(name)`` / ``(not name)`` terms."""
        eprint("adding nogood: ", end="")
        eprint(
            " ".join(
                f"(not {name})" if lit < 0 else f"({name})"
                for lit, name in zip(nogood, self.lits_to_text(nogood))
            )
        )

    def assign_nogood(self, pcontrol: PropagateControl, lits: Iterable[int]) -> bool:
        """Add the nogood ``current assignment AND NOT lits``.

        Under the current assignment, at least one literal of *lits* must hold;
        with a single literal this forces it.

        Returns:
            The result of ``pcontrol.add_nogood``: False means propagation must
            stop and the caller has to return immediately.
        """
        not_lits = set(-lit for lit in lits)
        assert len(not_lits)
        assignment = set(self.assignment_lits())
        # The forced literals must not already be part of the assignment.
        a = set(map(abs, not_lits))
        b = set(map(abs, assignment))
        assert len(a | b) == (len(a) + len(b))
        nogood = tuple(chain(assignment, not_lits))
        self.print_nogood(nogood)
        return pcontrol.add_nogood(nogood)

    def assignment_lits(self) -> Iterator[int]:
        """Yield the current assignment as signed solver literals."""
        return (lit if is_pos else -lit for lit, is_pos in self.assignment.items())

    def conflict(self, pcontrol: PropagateControl) -> bool:
        """Forbid the whole current assignment.

        Returns:
            The result of ``pcontrol.add_nogood`` (False when propagation must
            stop, which is the normal outcome for a violated nogood).
        """
        eprint("conflict: ", end="")
        self.print_assignment()
        return pcontrol.add_nogood(tuple(self.assignment_lits()))

    def propagate(self, pcontrol: PropagateControl, changes) -> None:
        """Record newly assigned watched literals and propagate ontology consequences.

        ``O.propagate`` is asked which atoms follow from the true ones. Derived
        ``&o`` atoms that are still unassigned are forced true with a nogood.
        If a derived atom is already assigned false, a conflict nogood is added.
        """
        for change in changes:
            atom = abs(change)
            assert atom not in self.assignment
            self.assignment[atom] = change >= 0
        in_atoms = set(self.truthy_atoms_text())
        eprint("made truthy: ", " ".join(in_atoms))
        out_atoms = set(O.propagate(in_atoms))
        # Regular atoms do not propagate their ontology counterparts here; only
        # atoms newly derived by the ontology are mapped to &o literals.
        new_atoms = set(self.atom_text_to_dl_atom(out_atoms - in_atoms))
        # partition() yields (predicate false, predicate true): unassigned first.
        new_atoms, prev_new_atoms = map(
            tuple, partition(lambda lit: lit in self.assignment, new_atoms)
        )
        if any(not self.assignment[atom] for atom in prev_new_atoms):
            if not self.conflict(pcontrol):
                # Clingo requires returning as soon as add_nogood reports False.
                return
        elif new_atoms:
            if not self.assign_nogood(pcontrol, new_atoms):
                return
        else:
            return
        if not pcontrol.propagate():
            return
        eprint("propagate: ", end="")
        self.print_assignment()

    def undo(self, thread_id: int, assignment: Assignment, changes: List[int]) -> None:
        """Forget the assignments of *changes* when clingo backtracks."""
        for change in changes:
            del self.assignment[abs(change)]

    def lits_to_text(self, lits: Iterable[int]) -> Iterable[Optional[str]]:
        """Map literals to atom names (symbolic first, then theory; None if unknown)."""
        return (
            self.symbolic_atoms.get(abs(lit), self.theory_atoms.get(abs(lit)))
            for lit in lits
        )

    def check(self, pcontrol: PropagateControl) -> None:
        """Validate a total assignment against the ontology.

        ``O.check`` returning None means the true atoms are inconsistent, and
        the assignment is rejected. Otherwise the returned atom names must not
        be true. If any of them is true, the assignment is rejected. If none
        is, a nogood forbidding each of them is added.
        """
        eprint("check: ", end="")
        self.print_assignment()
        in_atoms = set(self.truthy_atoms_text())
        shrink = O.check(in_atoms)
        if shrink is None:
            self.conflict(pcontrol)
            return
        # A theory atom may be missing if clingo removed it. Test membership
        # in the name -> literal map: theory_atoms is keyed by int literals, so
        # testing names against it would silently drop every atom.
        shrink = tuple(
            self.theory_atoms_inv[atom]
            for atom in shrink
            if atom in self.theory_atoms_inv
        )
        if any(self.assignment.get(abs(lit)) for lit in shrink):
            self.conflict(pcontrol)
            return
        eprint("shrink with: ", " ".join(self.lits_to_text(shrink)))
        for lit in shrink:
            if not self.assign_nogood(pcontrol, (-lit,)):
                return

    def print_assignment(self) -> None:
        """Debug-print the current assignment to stderr on one line."""
        eprint("assignment: ", end="")
        for atom, value in self.assignment.items():
            eprint(self.solver_lit_text(atom if value else -atom), end=" ")
        eprint()

    def is_theory_atom(self, lit: int) -> bool:
        """Return True if *lit* belongs to a theory atom (and not a symbolic one)."""
        _, _, a = self.lookup_solver_lit(lit)
        return a

    def solver_lit_text(self, lit: int) -> Optional[str]:
        """Render *lit* as ``([O: ][not ]name)``, or None if it is unknown."""
        symbol, is_neg, is_theory = self.lookup_solver_lit(lit)
        if symbol is None:
            return None
        theory = "O: " if is_theory else ""
        neg = "not " if is_neg else ""
        return f"({theory}{neg}{symbol})"

    def lookup_solver_lit(self, lit: int) -> Tuple[Optional[str], bool, bool]:
        """Return ``(name, is_negative, is_theory)`` for *lit*.

        Symbolic atoms are preferred over theory atoms. Returns
        ``(None, False, False)`` for literals found in neither table.
        """
        atom = abs(lit)
        if (atom_symb := self.symbolic_atoms.get(atom, None)) is not None:
            return atom_symb, lit < 0, False
        if (theo_symbol := self.theory_atoms.get(atom, None)) is not None:
            return theo_symbol, lit < 0, True
        return None, False, False
# The program must be grounded before its symbolic atoms can be looked up.
# Grounding in two waves may have been causing odd effects, so we parse and
# ground in a separate control, return the result as text, and reparse it later.
def add_external_atoms(program: str) -> str:
    """Return *program* extended with the ``&o`` theory grammar and linking rules.

    *program* is grounded in a separate ``clingo.Control`` only to discover its
    atom names. The returned text is the theory grammar, then *program*, then
    a newline, then one line ``a :- &o{a}. a :- &o{false}.`` per distinct
    signature name ``a``.

    NOTE(review): only the signature *name* is used (arity and sign are
    ignored), so this presumably assumes a propositional program — confirm.
    """
    control = clingo.Control(["0"])
    # Ground the program exactly as written. Stripping newlines first would let
    # a `%` line comment swallow the rest of the program and could glue tokens
    # together (e.g. `not\nb` -> `notb`); the program solved later keeps its
    # newlines, so both groundings must see the same text.
    control.add("base", [], program)
    control.ground([("base", [])])
    theory_grammar = """
#theory o {
kterm {- : 0, unary };
&o/0 : kterm, any
}.
"""
    # Use .signatures because .symbols is unreliable: e.g. a program whose only
    # rule is `a :- a.` yields no symbol for `a`.
    # The dummy rules `atom :- &o{false}.` give every atom a rule head; without
    # them clingo may unify atoms that appear in no head with theory atoms.
    atom_names = unique_everseen(
        name for name, _, _ in control.symbolic_atoms.signatures
    )
    external_atoms = "\n".join(
        f"{atom} :- &o{{{atom}}}. {atom} :- &o{{false}}." for atom in atom_names
    )
    return theory_grammar + program + "\n" + external_atoms
def solve(program: str, O_alphabet: Iterable[str], O_models: Iterable[Iterable[str]]):
    """Compute and print every answer set of *program* under the given ontology.

    The ontology (*O_alphabet*, *O_models*) is installed via ``O.set_models``,
    the program is extended by ``add_external_atoms``, and the distinct answer
    sets are printed to stdout. Diagnostic output goes to stderr.
    """
    O.set_models(O_alphabet, O_models)
    full_program = add_external_atoms(program)
    for item in (
        "USING ONTOLOGY WITH ALPHABET AND MODELS:",
        O_alphabet,
        O_models,
        "USING PROGRAM:",
        full_program,
    ):
        eprint(item)

    ctl = clingo.Control(["0"])
    ctl.register_propagator(OntologyPropagator())
    ctl.add("base", [], full_program)
    ctl.ground([("base", [])])

    found: List[str] = []
    with ctl.solve(yield_=True) as handle:
        for model in handle:
            eprint("answer set:", model)
            found.append(str(model))
            eprint()

    if not found:
        print("ALL DONE, NO ANSWER SETS")
        return
    print("ALL FINISHED, ALL ANSWER SETS:")
    print(*unique_everseen(found), sep="\n")
def main():
    """Command-line entry point: ``python propagator.py PROGRAM.lp``.

    Reads the logic program from PROGRAM.lp and the ontology formula from the
    sibling file PROGRAM.ont (same path, ``.ont`` extension). It then calls
    ``solve`` with the formula's alphabet and models.

    Raises:
        SystemExit: if not exactly one command-line argument is given.
    """
    from sys import argv
    from os.path import splitext

    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if len(argv) != 2:
        raise SystemExit("Please provide an .lp file as an argument")

    from logic import logic

    lp_filename = argv[1]
    models_filename = splitext(lp_filename)[0] + ".ont"
    # Read both inputs up front so neither file stays open while solving.
    with open(lp_filename, "rt", encoding="utf8") as lp_fo:
        program = lp_fo.read()
    with open(models_filename, "rt", encoding="utf8") as models_fo:
        models_text = models_fo.read()
    formula = logic.parse(models_text)
    models = tuple(formula.models())
    solve(program, formula.alphabet, models)


if __name__ == "__main__":
    main()