Source code for boofun.core.query_model

"""
Query Model for Boolean Functions.

This module formalizes the distinction between:
1. EXPLICIT functions - we have the full truth table in memory
2. QUERY-ACCESS functions - we can only evaluate f(x) on demand

This is CRITICAL for production safety. A user should be able to:
    f = bf.create(massive_neural_network, n=1000)
    f.is_linear(num_tests=100)  # SAFE: only 300 queries

Without the library trying to compute 2^1000 entries.

Design Philosophy:
    "Never assume we can enumerate the domain"

    The library should work correctly even if the function represents:
    - A billion-parameter neural network
    - A database lookup
    - An external API call
    - A physical measurement device

    For such functions, only QUERY-BASED algorithms are valid.
"""

from __future__ import annotations

import math
import warnings
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Dict, Optional

if TYPE_CHECKING:
    from .base import BooleanFunction

__all__ = [
    "QueryModel",
    "AccessType",
    "get_access_type",
    "check_query_safety",
    "QuerySafetyWarning",
    "ExplicitEnumerationError",
    "QUERY_COMPLEXITY",
]


class AccessType(Enum):
    """Ways in which a Boolean function's values can be obtained."""

    # The full truth table is held in memory, so the domain can be enumerated.
    EXPLICIT = auto()
    # Only point evaluations f(x) are available; enumeration is not safe.
    QUERY = auto()
    # Values can be iterated over once, but not accessed at random.
    STREAMING = auto()
    # A formula is available, but evaluating it everywhere may be expensive.
    SYMBOLIC = auto()
class QuerySafetyWarning(UserWarning):
    """Warning category for operations that may be unsafe on query-access functions."""
class ExplicitEnumerationError(RuntimeError):
    """
    Error raised on an attempt to enumerate a query-access function.

    It guards users against accidentally requesting 2^n evaluations
    of a huge function.
    """
# Query complexity for each operation.
#
# Every entry maps an operation name to a dict with three keys:
#   "safe":        True for query-based operations, False for ones that
#                  enumerate the domain
#   "queries":     callable (n, k) -> estimated number of evaluations of f,
#                  where n is the number of variables and k the
#                  user-specified number of queries/samples
#   "description": short human-readable note about the cost
QUERY_COMPLEXITY: Dict[str, Dict[str, Any]] = {
    name: {"safe": safe, "queries": cost, "description": note}
    for name, safe, cost, note in (
        # SAFE operations - O(k) queries where k is user-specified
        ("is_linear", True, lambda n, k: 3 * k, "BLR test"),
        ("is_monotone", True, lambda n, k: 2 * k, "Edge test"),
        ("is_symmetric", True, lambda n, k: 2 * k, "Permutation test"),
        ("is_balanced_approx", True, lambda n, k: k, "Sample mean"),
        ("evaluate", True, lambda n, k: 1, "Single query"),
        ("estimate_fourier", True, lambda n, k: k, "Sample estimator"),
        ("goldreich_levin", True, lambda n, k: k, "GL algorithm"),
        # UNSAFE operations - O(2^n) or worse
        ("fourier", False, lambda n, k: 2**n, "Full WHT"),
        ("influences", False, lambda n, k: n * 2**n, "All inputs, all flips"),
        ("degree", False, lambda n, k: 2**n, "Uses fourier"),
        ("total_influence", False, lambda n, k: n * 2**n, "Uses influences"),
        ("W", False, lambda n, k: 2**n, "Uses fourier"),
        ("W_leq", False, lambda n, k: 2**n, "Uses fourier"),
        ("sparsity", False, lambda n, k: 2**n, "Uses fourier"),
        ("is_balanced", False, lambda n, k: 2**n, "Count all"),
        ("is_junta", False, lambda n, k: n * 2**n, "Uses influences"),
        ("fix", False, lambda n, k: 2 ** (n - 1), "Builds new table"),
        ("derivative", False, lambda n, k: 2**n, "Builds new table"),
        ("constant_test", False, lambda n, k: 2**n, "Exhaustive check"),
        ("decision_tree_depth", False, lambda n, k: 3**n, "DP over subcubes"),
        ("get_representation:truth_table", False, lambda n, k: 2**n, "Materialize"),
    )
}
def get_access_type(f: "BooleanFunction") -> AccessType:
    """
    Classify how the values of ``f`` can be accessed.

    The decision is based on the names of the representations stored on
    ``f``; the first matching rule below wins.

    Args:
        f: BooleanFunction to check

    Returns:
        AccessType describing the access pattern. ``AccessType.QUERY``
        (the most restrictive) is returned when ``f`` is None, when its
        ``n_vars`` is None, or when no known representation is present.
    """
    if f is None or f.n_vars is None:
        return AccessType.QUERY

    available = set(f.representations.keys())

    # Ordered by precedence: earlier rows shadow later ones.
    precedence = (
        (("truth_table",), AccessType.EXPLICIT),  # table already materialized
        (("symbolic", "anf"), AccessType.SYMBOLIC),  # formula-based
        (("function",), AccessType.QUERY),  # callable only
        (("fourier_expansion",), AccessType.EXPLICIT),  # reconstructible from coefficients
    )
    for rep_names, access in precedence:
        if not available.isdisjoint(rep_names):
            return access

    return AccessType.QUERY
def _format_query_count(count: Any) -> str:
    """Render a query count for a message, tolerating astronomically large ints."""
    try:
        return f"{count:,}"
    except ValueError:
        # Python 3.11+ refuses decimal conversion of ints beyond the
        # interpreter's digit limit (4300 by default); 2**n exceeds it for
        # large n. Report the order of magnitude instead of crashing.
        return f"10^{int(math.log10(count))}"


def check_query_safety(
    f: "BooleanFunction",
    operation: str,
    max_safe_n: int = 20,
    num_queries: int = 100,
    strict: bool = False,
) -> bool:
    """
    Check if an operation is safe to perform on this function.

    Operations marked safe in ``QUERY_COMPLEXITY`` are always allowed, and
    unsafe operations are allowed whenever ``n <= max_safe_n``. Beyond that
    the outcome depends on the access type of ``f``:

    - EXPLICIT: a warning is issued and True is returned.
    - QUERY: a warning is issued and False is returned.
    - anything else: True is returned without a warning.

    Operations missing from ``QUERY_COMPLEXITY`` are treated as unsafe with
    a cost of 2^n queries.

    Args:
        f: BooleanFunction to check
        operation: Name of operation (e.g., "fourier", "is_linear")
        max_safe_n: Maximum n for which we allow unsafe operations
        num_queries: Number of queries for safe operations
        strict: If True, raise error instead of warning

    Returns:
        True if operation is safe to proceed

    Raises:
        ExplicitEnumerationError: If strict=True and operation is unsafe

    Example:
        >>> f = bf.create(huge_function, n=100)
        >>> check_query_safety(f, "fourier")  # Returns False, warns
        >>> check_query_safety(f, "is_linear")  # Returns True
    """
    n = f.n_vars or 0
    access_type = get_access_type(f)

    # Unknown operations are assumed to enumerate the whole domain.
    op_info = QUERY_COMPLEXITY.get(operation, {"safe": False, "queries": lambda n, k: 2**n})

    # Query-based operations are always allowed.
    if op_info["safe"]:
        return True

    # Small n is allowed for every access type.
    if n <= max_safe_n:
        return True

    if access_type == AccessType.EXPLICIT:
        # The cost is only computed here, where a message is actually needed.
        query_count = _format_query_count(op_info["queries"](n, num_queries))
        msg = (
            f"Operation '{operation}' requires {query_count} queries for n={n}. "
            f"This may be slow. Use query-based alternatives or reduce n."
        )
        if strict:
            raise ExplicitEnumerationError(msg)
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn(msg, QuerySafetyWarning, stacklevel=2)
        return True  # Still allow but warn

    if access_type == AccessType.QUERY:
        query_count = _format_query_count(op_info["queries"](n, num_queries))
        msg = (
            f"UNSAFE: Operation '{operation}' would require ~{query_count} queries "
            f"on a query-access function with n={n}. "
            f"This is likely impossible (2^{n} evaluations). "
            f"Use query-based alternatives like estimate_fourier() or is_linear()."
        )
        if strict:
            raise ExplicitEnumerationError(msg)
        warnings.warn(msg, QuerySafetyWarning, stacklevel=2)
        return False  # Don't allow

    # NOTE(review): SYMBOLIC and STREAMING functions are allowed without a
    # warning even for large n - confirm this is intended.
    return True
class QueryModel:
    """
    Manages query complexity and safety for Boolean function operations.

    This class helps users understand and control the computational cost
    of operations on their functions.

    Example (the ``estimate_cost`` output is abbreviated; see that method
    for the full set of keys):
        >>> f = bf.create(my_function, n=30)
        >>> qm = QueryModel(f)
        >>> qm.can_compute("fourier")
        False
        >>> qm.can_compute("is_linear", num_queries=100)
        True
        >>> qm.estimate_cost("influences")
        {'queries': 32212254720, 'feasible': False}
    """

    def __init__(self, f: "BooleanFunction", max_queries: int = 10_000_000):
        """
        Initialize query model for a function.

        Args:
            f: BooleanFunction to analyze
            max_queries: Maximum acceptable query count
        """
        self.f = f
        self.n = f.n_vars or 0
        self.max_queries = max_queries
        self.access_type = get_access_type(f)

    @staticmethod
    def _format_time(time_us: Any) -> str:
        """Format a duration given in microseconds as a short string."""
        if time_us < 1000:
            return f"{time_us}µs"
        if time_us < 1_000_000:
            return f"{time_us/1000:.1f}ms"
        if time_us < 1_000_000_000:
            return f"{time_us/1_000_000:.1f}s"
        try:
            return f"{time_us/1_000_000_000:.1f}ks (hours+)"
        except OverflowError:
            # Counts such as 3**n outgrow the float range for large n, so
            # true division fails; report the order of magnitude instead.
            exponent = int(math.log10(time_us)) - 9
            return f"~1e{exponent}ks (hours+)"

    @staticmethod
    def _format_count(count: Any) -> str:
        """Format a query count with thousands separators, tolerating huge ints."""
        try:
            return f"{count:,}"
        except ValueError:
            # Python 3.11+ limits the number of decimal digits an int may be
            # converted to; fall back to the order of magnitude.
            return f"10^{int(math.log10(count))}"

    def can_compute(self, operation: str, **kwargs) -> bool:
        """Check if operation is computationally feasible.

        Keyword arguments are forwarded to :meth:`estimate_cost`.
        """
        return self.estimate_cost(operation, **kwargs)["feasible"]

    def estimate_cost(self, operation: str, num_queries: int = 100) -> Dict[str, Any]:
        """
        Estimate computational cost of an operation.

        Operations missing from ``QUERY_COMPLEXITY`` are treated as unsafe
        with a cost of 2^n queries and the description "Unknown".

        Args:
            operation: Name of the operation to estimate
            num_queries: Number of queries/samples for query-based operations

        Returns:
            Dict with keys:
            - queries: Estimated number of function evaluations
            - feasible: Whether ``queries`` is at most ``max_queries``
            - safe: Whether the operation is marked safe in QUERY_COMPLEXITY
            - time_estimate: Rough time estimate (assuming 1µs per query)
            - description: Human-readable description
            - access_type: Name of the function's AccessType
        """
        op_info = QUERY_COMPLEXITY.get(
            operation,
            {"safe": False, "queries": lambda n, k: 2**self.n, "description": "Unknown"},
        )
        queries = op_info["queries"](self.n, num_queries)

        return {
            "queries": queries,
            "feasible": queries <= self.max_queries,
            "safe": op_info["safe"],
            # Rough time estimate at 1µs per query
            "time_estimate": self._format_time(queries),
            "description": op_info["description"],
            "access_type": self.access_type.name,
        }

    def summary(self) -> Dict[str, Dict[str, Any]]:
        """Get cost summary for all operations."""
        return {op: self.estimate_cost(op) for op in QUERY_COMPLEXITY}

    def print_summary(self):
        """Print human-readable cost summary."""
        print(f"Query Model Summary for n={self.n}")
        print(f"Access type: {self.access_type.name}")
        print(f"Max queries: {self.max_queries:,}")
        print("-" * 60)

        costs = [(op, self.estimate_cost(op)) for op in QUERY_COMPLEXITY]

        print("\n✓ SAFE operations (query-based):")
        for op, cost in costs:
            if cost["safe"]:
                print(f" {op}: ~{cost['queries']} queries ({cost['time_estimate']})")

        print("\n⚠ POTENTIALLY UNSAFE operations (may enumerate):")
        for op, cost in costs:
            if not cost["safe"]:
                status = "✓" if cost["feasible"] else "✗"
                count = self._format_count(cost["queries"])
                print(f" {status} {op}: ~{count} queries ({cost['time_estimate']})")
def safe_alternatives(operation: str) -> Optional[str]:
    """
    Look up the query-safe counterpart of an unsafe operation.

    Args:
        operation: Unsafe operation name

    Returns:
        Name of safe alternative, or None if none exists
    """
    suggestions = dict(
        fourier="estimate_fourier (sample-based)",
        influences="estimate_influence (sample-based)",
        is_balanced="is_balanced_approx (sample-based)",
        total_influence="estimate_total_influence",
        is_junta="detect_influential_vars (sample-based)",
        constant_test="is_constant_approx (sample-based)",
    )
    if operation in suggestions:
        return suggestions[operation]
    return None