clingo-hmknf-test/propagator.py

186 lines
5.7 KiB
Python

from copy import deepcopy
from itertools import chain
from operator import getitem
from typing import Iterable, List, Tuple
from functools import partial
import clingo
from clingo import (
PropagateControl,
PropagateInit,
PropagateControl,
PropagatorCheckMode,
Assignment,
Symbol,
Control,
TheoryAtom,
)
"""
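API notes:
add_clause is disjunctive: at least one of the clause's literals must be true.
add_nogood is conjunctive: its literals must not all be true at the same time.
Defined atoms are those that appear in a rule head.
No way to set (restrict) atoms except by adding nogoods/clauses?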
"""
import ontology as O
class Ontology:
    """clingo propagator that lets the external ontology module ``O`` force atoms true.

    The grounded program is expected to pair each program atom ``a`` with a
    theory atom ``&o{a}`` (``add_external_atoms`` emits ``a :- &o{a}.``).
    Whenever watched literals become true, the texts of the currently true
    atoms are passed to ``O.propagate``. The ``&o`` literals of the atom texts
    it returns are then forced true with nogoods.

    Attributes (set in ``init``):
        assignment: watched solver literal -> True while that literal is
            assigned true (entries are removed again in ``undo``).
        symbolic_atoms / theory_atoms: solver literal -> atom text. clingo may
            map several atoms to one solver literal; only one text is kept per
            literal here (used for display and ``lookup_solver_lit``).
        symbolic_atoms_inv / theory_atoms_inv: atom text -> solver literal,
            complete even when solver literals are shared.
    """

    def init(self, init: PropagateInit) -> None:
        """Build the literal/text tables, watch all atoms and link a -> &o{a}."""
        # Total: clingo calls check() only once the assignment is total.
        init.check_mode = PropagatorCheckMode.Total
        self.assignment = dict()
        # Build text -> literal directly. Inverting a literal -> text dict would
        # drop atoms that share a solver literal, and the lookups in the
        # linking loop below would then raise KeyError.
        self.symbolic_atoms_inv = {
            str(atom.symbol): init.solver_literal(atom.literal)
            for atom in init.symbolic_atoms
        }
        # str() of the term (not .name) keeps arguments, so &o{p(1)} maps to
        # "p(1)", the same text str(atom.symbol) produces for p(1).
        self.theory_atoms_inv = {
            str(atom.elements[0].terms[0]): init.solver_literal(atom.literal)
            for atom in init.theory_atoms
        }
        self.symbolic_atoms = {lit: text for text, lit in self.symbolic_atoms_inv.items()}
        self.theory_atoms = {lit: text for text, lit in self.theory_atoms_inv.items()}
        # Might only need to watch just theory atoms / just symbolic atoms, but
        # for now watching everything is easier. The set avoids registering the
        # same watch twice when literals are shared. A watch is for the given
        # phase, so propagate() is told when these literals become true.
        for lit in set(chain(self.symbolic_atoms, self.theory_atoms)):
            init.add_watch(lit)
        # Could add these with additional rules, but that might change the
        # semantics of the &o atoms. An ontology atom must be true if its
        # regular counterpart is true (clause: not a OR &o{a}). The opposite
        # direction is already enforced by the rules a :- &o{a}.
        for text, theory_lit in self.theory_atoms_inv.items():
            symbolic_lit = self.symbolic_atoms_inv[text]
            init.add_clause((-symbolic_lit, theory_lit))

    def truthy_atoms_text(self) -> Iterable[str]:
        """Yield the text of every atom whose solver literal is recorded as true.

        Walks the text -> literal maps so that all atoms sharing a solver literal
        are reported. A text may be yielded twice (symbolic and theory).
        """
        return (
            text
            for text, lit in chain(
                self.symbolic_atoms_inv.items(), self.theory_atoms_inv.items()
            )
            if self.assignment.get(lit) is True
        )

    def atom_text_to_dl_atom(self, atoms: Iterable[str]) -> Iterable[int]:
        """Map atom texts to their ``&o`` theory literals, skipping unknown texts."""
        return (
            lit
            for atom in atoms
            if (lit := self.theory_atoms_inv.get(atom)) is not None
        )

    def falsey_atoms_text(self) -> Iterable[str]:
        """Yield texts of atoms whose literal is recorded as False.

        propagate() only ever records True, so this is currently always empty.
        """
        return (
            text
            for text, lit in chain(
                self.symbolic_atoms_inv.items(), self.theory_atoms_inv.items()
            )
            if self.assignment.get(lit) is False
        )

    def assign_nogood_true(self, pcontrol: PropagateControl, lits: Iterable[int]) -> bool:
        """Force each literal in ``lits`` true under the current assignment.

        For every ``lit``, adds the nogood "recorded literals + [-lit]". In other
        words, the recorded assignment together with ``lit`` being false is
        forbidden. One nogood per literal is required; a single nogood holding
        all the negations would only demand that one of them be true.

        Returns:
            False as soon as clingo reports that propagation must stop (the
            caller must then return immediately), True otherwise.
        """
        reason = [
            recorded if is_true else -recorded
            for recorded, is_true in self.assignment.items()
        ]
        for lit in lits:
            if not pcontrol.add_nogood(reason + [-lit]):
                return False
        return True

    def propagate(self, pcontrol: PropagateControl, changes) -> None:
        """Record newly true watched literals and force atoms entailed by ``O``."""
        print("propagate: ", end="")
        self.print_assignment()
        for lit in changes:
            # Key by the watched literal itself, not abs(lit). solver_literal()
            # may hand out negative literals, and the tables are keyed by them.
            assert lit not in self.assignment
            self.assignment[lit] = True
        in_atoms = set(self.truthy_atoms_text())
        # Take the difference on atom texts. O.propagate works on texts, while
        # atom_text_to_dl_atom yields solver literals.
        new_texts = set(O.propagate(in_atoms)) - in_atoms
        new_lits = list(self.atom_text_to_dl_atom(new_texts))
        if not new_lits:
            # Nothing to force. A nogood made of the bare assignment would be
            # an immediate, spurious conflict.
            return
        if not self.assign_nogood_true(pcontrol, new_lits):
            return
        pcontrol.propagate()

    def undo(self, thread_id: int, assignment: Assignment, changes: List[int]):
        """Forget the literals recorded by the matching propagate() call."""
        for lit in changes:
            del self.assignment[lit]

    def check(self, pcontrol: PropagateControl) -> None:
        """Debug hook on total assignments: print the recorded assignment."""
        print("check: ", end="")
        self.print_assignment()

    def print_assignment(self):
        """Print every recorded literal as readable atom text on one line."""
        print("assignment: ", end="")
        for lit in self.assignment:
            lit = -lit if not self.assignment[lit] else lit
            print(self.solver_lit_text(lit), end=" ")
        print()

    def is_theory_atom(self, lit: int):
        """Return True if ``lit`` (or its negation) belongs to an ``&o`` atom."""
        _, _, a = self.lookup_solver_lit(lit)
        return a

    def solver_lit_text(self, lit: int):
        """Render ``lit`` as e.g. ``(O: not a)``; None for unknown literals."""
        symbol, is_neg, is_theory = self.lookup_solver_lit(lit)
        if symbol is None:
            return None
        theory = "O: " if is_theory else ""
        neg = "not " if is_neg else ""
        return f"({theory}{neg}{symbol})"

    def lookup_solver_lit(self, lit: int) -> Tuple[str, bool, bool]:
        """Return ``(text, is_negated, is_theory)`` for a solver literal.

        An exact key match (in either table) wins over a match of ``-lit``, so
        negative solver literals used as keys are found directly. Returns
        ``(None, False, False)`` when neither ``lit`` nor ``-lit`` is known.
        """
        for key, is_neg in ((lit, False), (-lit, True)):
            if (text := self.symbolic_atoms.get(key)) is not None:
                return text, is_neg, False
            if (text := self.theory_atoms.get(key)) is not None:
                return text, is_neg, True
        return None, False, False
# Minimal test program: two rules, each deriving one atom from the absence of
# the other.
program = (
    "\n"
    "a :- not b.\n"
    "b :- not a.\n"
)
# Need to ground the program before we can look up its symbolic atoms, and
# having two waves of grounding might be causing some weird effects. So we
# parse and ground here, turn the atoms back into text, and let the caller
# reparse the result.
def add_external_atoms(program: str) -> str:
    """Return ``program`` extended with an ``&o`` theory atom for every atom.

    ``program`` is grounded in a throwaway ``clingo.Control``. For each symbolic
    atom ``a`` found there, the rule ``a :- &o{a}.`` is appended, and the
    ``#theory o`` grammar that declares ``&o`` is prepended.

    Args:
        program: ASP program text for the ``base`` part.

    Returns:
        The theory grammar, the unchanged program and the generated rules,
        joined by newlines.
    """
    control = clingo.Control(["0"])
    # Ground the text as given. Removing newlines would turn everything after
    # a ``%`` line comment into part of that comment.
    control.add("base", [], program)
    control.ground([("base", [])])
    theory_grammar = """
#theory o {
    kterm {- : 0, unary };
    &o/0 : kterm, any
}.
"""
    atom_texts = (str(atom.symbol) for atom in control.symbolic_atoms)
    external_atoms = "\n".join(f"{text} :- &o{{{text}}}." for text in atom_texts)
    # Newline separators keep the generated rules off the program's last line,
    # which may end in a comment or lack a trailing newline.
    return "\n".join((theory_grammar, program, external_atoms))
# Build the augmented program, then solve it with the Ontology propagator
# registered, so it can force entailed &o atoms during search.
program = add_external_atoms(program)
print(program)
# "0": enumerate all answer sets.
control = clingo.Control(["0"])
propagator = Ontology()
control.register_propagator(propagator)
control.add("base", [], program)
control.ground([("base", [])])
with control.solve(yield_=True) as solve_handle:
    for model in solve_handle:
        print("answer set:", model)