"""Implementation of DPLL algorithm Features: - Clause learning - Watch literal scheme - VSIDS heuristic References: - https://en.wikipedia.org/wiki/DPLL_algorithm """ from collections import defaultdict from heapq import heappush, heappop from sympy.core.sorting import ordered from sympy.assumptions.cnf import EncodedCNF def dpll_satisfiable(expr, all_models=False): """ Check satisfiability of a propositional sentence. It returns a model rather than True when it succeeds. Returns a generator of all models if all_models is True. Examples ======== >>> from sympy.abc import A, B >>> from sympy.logic.algorithms.dpll2 import dpll_satisfiable >>> dpll_satisfiable(A & ~B) {A: True, B: False} >>> dpll_satisfiable(A & ~A) False """ if not isinstance(expr, EncodedCNF): exprs = EncodedCNF() exprs.add_prop(expr) expr = exprs # Return UNSAT when False (encoded as 0) is present in the CNF if {0} in expr.data: if all_models: return (f for f in [False]) return False solver = SATSolver(expr.data, expr.variables, set(), expr.symbols) models = solver._find_model() if all_models: return _all_models(models) try: return next(models) except StopIteration: return False # Uncomment to confirm the solution is valid (hitting set for the clauses) #else: #for cls in clauses_int_repr: #assert solver.var_settings.intersection(cls) def _all_models(models): satisfiable = False try: while True: yield next(models) satisfiable = True except StopIteration: if not satisfiable: yield False class SATSolver: """ Class for representing a SAT solver capable of finding a model to a boolean theory in conjunctive normal form. """ def __init__(self, clauses, variables, var_settings, symbols=None, heuristic='vsids', clause_learning='none', INTERVAL=500): self.var_settings = var_settings self.heuristic = heuristic self.is_unsatisfied = False self._unit_prop_queue = [] self.update_functions = [] self.INTERVAL = INTERVAL if symbols is None: self.symbols = list(ordered(variables)) else: self.symbols = symbols self._initialize_variables(variables) self._initialize_clauses(clauses) if 'vsids' == heuristic: self._vsids_init() self.heur_calculate = self._vsids_calculate self.heur_lit_assigned = self._vsids_lit_assigned self.heur_lit_unset = self._vsids_lit_unset self.heur_clause_added = self._vsids_clause_added # Note: Uncomment this if/when clause learning is enabled #self.update_functions.append(self._vsids_decay) else: raise NotImplementedError if 'simple' == clause_learning: self.add_learned_clause = self._simple_add_learned_clause self.compute_conflict = self.simple_compute_conflict self.update_functions.append(self.simple_clean_clauses) elif 'none' == clause_learning: self.add_learned_clause = lambda x: None self.compute_conflict = lambda: None else: raise NotImplementedError # Create the base level self.levels = [Level(0)] self._current_level.varsettings = var_settings # Keep stats self.num_decisions = 0 self.num_learned_clauses = 0 self.original_num_clauses = len(self.clauses) def _initialize_variables(self, variables): """Set up the variable data structures needed.""" self.sentinels = defaultdict(set) self.occurrence_count = defaultdict(int) self.variable_set = [False] * (len(variables) + 1) def _initialize_clauses(self, clauses): """Set up the clause data structures needed. For each clause, the following changes are made: - Unit clauses are queued for propagation right away. - Non-unit clauses have their first and last literals set as sentinels. - The number of clauses a literal appears in is computed. """ self.clauses = [list(clause) for clause in clauses] for i, clause in enumerate(self.clauses): # Handle the unit clauses if 1 == len(clause): self._unit_prop_queue.append(clause[0]) continue self.sentinels[clause[0]].add(i) self.sentinels[clause[-1]].add(i) for lit in clause: self.occurrence_count[lit] += 1 def _find_model(self): """ Main DPLL loop. Returns a generator of models. Variables are chosen successively, and assigned to be either True or False. If a solution is not found with this setting, the opposite is chosen and the search continues. The solver halts when every variable has a setting. Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set()) >>> list(l._find_model()) [{1: True, 2: False, 3: False}, {1: True, 2: True, 3: True}] >>> from sympy.abc import A, B, C >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set(), [A, B, C]) >>> list(l._find_model()) [{A: True, B: False, C: False}, {A: True, B: True, C: True}] """ # We use this variable to keep track of if we should flip a # variable setting in successive rounds flip_var = False # Check if unit prop says the theory is unsat right off the bat self._simplify() if self.is_unsatisfied: return # While the theory still has clauses remaining while True: # Perform cleanup / fixup at regular intervals if self.num_decisions % self.INTERVAL == 0: for func in self.update_functions: func() if flip_var: # We have just backtracked and we are trying to opposite literal flip_var = False lit = self._current_level.decision else: # Pick a literal to set lit = self.heur_calculate() self.num_decisions += 1 # Stopping condition for a satisfying theory if 0 == lit: yield {self.symbols[abs(lit) - 1]: lit > 0 for lit in self.var_settings} while self._current_level.flipped: self._undo() if len(self.levels) == 1: return flip_lit = -self._current_level.decision self._undo() self.levels.append(Level(flip_lit, flipped=True)) flip_var = True continue # Start the new decision level self.levels.append(Level(lit)) # Assign the literal, updating the clauses it satisfies self._assign_literal(lit) # _simplify the theory self._simplify() # Check if we've made the theory unsat if self.is_unsatisfied: self.is_unsatisfied = False # We unroll all of the decisions until we can flip a literal while self._current_level.flipped: self._undo() # If we've unrolled all the way, the theory is unsat if 1 == len(self.levels): return # Detect and add a learned clause self.add_learned_clause(self.compute_conflict()) # Try the opposite setting of the most recent decision flip_lit = -self._current_level.decision self._undo() self.levels.append(Level(flip_lit, flipped=True)) flip_var = True ######################## # Helper Methods # ######################## @property def _current_level(self): """The current decision level data structure Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{1}, {2}], {1, 2}, set()) >>> next(l._find_model()) {1: True, 2: True} >>> l._current_level.decision 0 >>> l._current_level.flipped False >>> l._current_level.var_settings {1, 2} """ return self.levels[-1] def _clause_sat(self, cls): """Check if a clause is satisfied by the current variable setting. Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{1}, {-1}], {1}, set()) >>> try: ... next(l._find_model()) ... except StopIteration: ... pass >>> l._clause_sat(0) False >>> l._clause_sat(1) True """ for lit in self.clauses[cls]: if lit in self.var_settings: return True return False def _is_sentinel(self, lit, cls): """Check if a literal is a sentinel of a given clause. Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set()) >>> next(l._find_model()) {1: True, 2: False, 3: False} >>> l._is_sentinel(2, 3) True >>> l._is_sentinel(-3, 1) False """ return cls in self.sentinels[lit] def _assign_literal(self, lit): """Make a literal assignment. The literal assignment must be recorded as part of the current decision level. Additionally, if the literal is marked as a sentinel of any clause, then a new sentinel must be chosen. If this is not possible, then unit propagation is triggered and another literal is added to the queue to be set in the future. Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set()) >>> next(l._find_model()) {1: True, 2: False, 3: False} >>> l.var_settings {-3, -2, 1} >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set()) >>> l._assign_literal(-1) >>> try: ... next(l._find_model()) ... except StopIteration: ... pass >>> l.var_settings {-1} """ self.var_settings.add(lit) self._current_level.var_settings.add(lit) self.variable_set[abs(lit)] = True self.heur_lit_assigned(lit) sentinel_list = list(self.sentinels[-lit]) for cls in sentinel_list: if not self._clause_sat(cls): other_sentinel = None for newlit in self.clauses[cls]: if newlit != -lit: if self._is_sentinel(newlit, cls): other_sentinel = newlit elif not self.variable_set[abs(newlit)]: self.sentinels[-lit].remove(cls) self.sentinels[newlit].add(cls) other_sentinel = None break # Check if no sentinel update exists if other_sentinel: self._unit_prop_queue.append(other_sentinel) def _undo(self): """ _undo the changes of the most recent decision level. Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set()) >>> next(l._find_model()) {1: True, 2: False, 3: False} >>> level = l._current_level >>> level.decision, level.var_settings, level.flipped (-3, {-3, -2}, False) >>> l._undo() >>> level = l._current_level >>> level.decision, level.var_settings, level.flipped (0, {1}, False) """ # Undo the variable settings for lit in self._current_level.var_settings: self.var_settings.remove(lit) self.heur_lit_unset(lit) self.variable_set[abs(lit)] = False # Pop the level off the stack self.levels.pop() ######################### # Propagation # ######################### """ Propagation methods should attempt to soundly simplify the boolean theory, and return True if any simplification occurred and False otherwise. """ def _simplify(self): """Iterate over the various forms of propagation to simplify the theory. Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set()) >>> l.variable_set [False, False, False, False] >>> l.sentinels {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}} >>> l._simplify() >>> l.variable_set [False, True, False, False] >>> l.sentinels {-3: {0, 2}, -2: {3, 4}, -1: set(), 2: {0, 3}, ...3: {2, 4}} """ changed = True while changed: changed = False changed |= self._unit_prop() changed |= self._pure_literal() def _unit_prop(self): """Perform unit propagation on the current theory.""" result = len(self._unit_prop_queue) > 0 while self._unit_prop_queue: next_lit = self._unit_prop_queue.pop() if -next_lit in self.var_settings: self.is_unsatisfied = True self._unit_prop_queue = [] return False else: self._assign_literal(next_lit) return result def _pure_literal(self): """Look for pure literals and assign them when found.""" return False ######################### # Heuristics # ######################### def _vsids_init(self): """Initialize the data structures needed for the VSIDS heuristic.""" self.lit_heap = [] self.lit_scores = {} for var in range(1, len(self.variable_set)): self.lit_scores[var] = float(-self.occurrence_count[var]) self.lit_scores[-var] = float(-self.occurrence_count[-var]) heappush(self.lit_heap, (self.lit_scores[var], var)) heappush(self.lit_heap, (self.lit_scores[-var], -var)) def _vsids_decay(self): """Decay the VSIDS scores for every literal. Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set()) >>> l.lit_scores {-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0} >>> l._vsids_decay() >>> l.lit_scores {-3: -1.0, -2: -1.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -1.0} """ # We divide every literal score by 2 for a decay factor # Note: This doesn't change the heap property for lit in self.lit_scores.keys(): self.lit_scores[lit] /= 2.0 def _vsids_calculate(self): """ VSIDS Heuristic Calculation Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set()) >>> l.lit_heap [(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)] >>> l._vsids_calculate() -3 >>> l.lit_heap [(-2.0, -2), (-2.0, 2), (0.0, -1), (0.0, 1), (-2.0, 3)] """ if len(self.lit_heap) == 0: return 0 # Clean out the front of the heap as long the variables are set while self.variable_set[abs(self.lit_heap[0][1])]: heappop(self.lit_heap) if len(self.lit_heap) == 0: return 0 return heappop(self.lit_heap)[1] def _vsids_lit_assigned(self, lit): """Handle the assignment of a literal for the VSIDS heuristic.""" pass def _vsids_lit_unset(self, lit): """Handle the unsetting of a literal for the VSIDS heuristic. Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set()) >>> l.lit_heap [(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)] >>> l._vsids_lit_unset(2) >>> l.lit_heap [(-2.0, -3), (-2.0, -2), (-2.0, -2), (-2.0, 2), (-2.0, 3), (0.0, -1), ...(-2.0, 2), (0.0, 1)] """ var = abs(lit) heappush(self.lit_heap, (self.lit_scores[var], var)) heappush(self.lit_heap, (self.lit_scores[-var], -var)) def _vsids_clause_added(self, cls): """Handle the addition of a new clause for the VSIDS heuristic. Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set()) >>> l.num_learned_clauses 0 >>> l.lit_scores {-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0} >>> l._vsids_clause_added({2, -3}) >>> l.num_learned_clauses 1 >>> l.lit_scores {-3: -1.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -2.0} """ self.num_learned_clauses += 1 for lit in cls: self.lit_scores[lit] += 1 ######################## # Clause Learning # ######################## def _simple_add_learned_clause(self, cls): """Add a new clause to the theory. Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set()) >>> l.num_learned_clauses 0 >>> l.clauses [[2, -3], [1], [3, -3], [2, -2], [3, -2]] >>> l.sentinels {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}} >>> l._simple_add_learned_clause([3]) >>> l.clauses [[2, -3], [1], [3, -3], [2, -2], [3, -2], [3]] >>> l.sentinels {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4, 5}} """ cls_num = len(self.clauses) self.clauses.append(cls) for lit in cls: self.occurrence_count[lit] += 1 self.sentinels[cls[0]].add(cls_num) self.sentinels[cls[-1]].add(cls_num) self.heur_clause_added(cls) def _simple_compute_conflict(self): """ Build a clause representing the fact that at least one decision made so far is wrong. Examples ======== >>> from sympy.logic.algorithms.dpll2 import SATSolver >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, ... {3, -2}], {1, 2, 3}, set()) >>> next(l._find_model()) {1: True, 2: False, 3: False} >>> l._simple_compute_conflict() [3] """ return [-(level.decision) for level in self.levels[1:]] def _simple_clean_clauses(self): """Clean up learned clauses.""" pass class Level: """ Represents a single level in the DPLL algorithm, and contains enough information for a sound backtracking procedure. """ def __init__(self, decision, flipped=False): self.decision = decision self.var_settings = set() self.flipped = flipped