# Source code for sas.sascalc.fit.expression

# This program is public domain
"""
Parameter expression evaluator.

For systems in which constraints are expressed as string expressions rather
than python code, :func:`compile_constraints` can construct an expression
evaluator that substitutes the computed values of the expressions into the
parameters.

The compiler requires a symbol table, an expression set and a context.
The symbol table maps strings containing fully qualified names such as
'M1.c[3].full_width' to parameter objects with a 'value' property that
can be queried and set.  The expression set maps symbol names from the
symbol table to string expressions.  The context provides additional symbols
for the expressions in addition to the usual mathematical functions and
constants.

The expressions are compiled and interpreted by python, with only minimal
effort to make sure that they don't contain bad code.  The resulting
constraints function returns 0 so it can be used directly in a fit problem
definition.

Extracting the symbol table from the model depends on the structure of the
model.  If fitness.parameters() is set correctly, then this should simply
be a matter of walking the parameter data, remembering the path to each
parameter in the symbol table.  For compactness, dictionary elements should
be referenced by .name rather than ["name"].  Model name can be used as the
top level.

Getting the parameter expressions applied correctly is challenging.
The following monkey patch works by overriding model_update in FitProblem
so that, after setp(p) is called, the constraints expression can be
applied before telling the underlying fitness function that the model
is out of date::

        # Override model update so that parameter constraints are applied
        problem._model_update = problem.model_update
        def model_update():
            constraints()
            problem._model_update()
        problem.model_update = model_update

Ideally, this interface will change
"""
from __future__ import print_function

from copy import copy
import math
import re
from keyword import iskeyword

def standard_symbols(context=None):
    """
    Build the table of names available to constraint expressions.

    The table is filled, in this order, with:

    * every entry of the :mod:`math` module namespace;
    * the aliases *arcsin*, *arccos*, *arctan* and *arctan2* for the
      corresponding math inverse trig functions;
    * the entries of *context*, which therefore override the names above;
    * the builtin *id* function, which overrides any *id* in *context*.

    *context* is an optional mapping of additional names.  It is only read,
    never modified.  *None* (the default) means no additional names.

    Returns a new dictionary on every call.
    """
    symbols = dict(math.__dict__)
    symbols.update(
        arcsin=math.asin,
        arccos=math.acos,
        arctan=math.atan,
        arctan2=math.atan2,
    )
    if context is not None:
        symbols.update(context)
    # 'id' is set last so it is always the builtin.
    symbols['id'] = id
    return symbols
def _check_syntax(target, expr, html=False):
    """
    Return a list with one syntax error message for *expr*, or the empty list.

    *target* is the name the expression is assigned to; it is only used to
    build the message.  *expr* is compiled in "exec" mode, so any text that
    python accepts as a statement sequence passes this check.

    If *html* is True the message marks the location of the error: the tail
    of a single line expression is wrapped in <b>...</b>, an expression that
    ends too early is reported as "not complete", and a multiline expression
    is shown in a <pre> block with the line and column of the error.
    """
    try:
        compile(expr, expr, "exec")
    except SyntaxError as exc:
        plain = "Syntax error in expression '%s = %s'" % (target, expr)
        if not html:
            return [plain]
        if "\n" in expr:
            # Multiline expression.  Be lazy and just show line, col for
            # syntax error since this should never happen.
            return [
                f"Syntax error on line {exc.lineno} column {exc.offset} for {target}:\n<pre>\n{expr}\n</pre>"]
        offset = exc.offset
        if offset is None or offset < 1:
            # No usable column information, so there is nothing to highlight.
            return [plain]
        if offset > len(expr):
            # Single line expression with error after the expression.
            # Probably missing a closing paren, but it could
            # also be that the expression ends with an operator such as
            # "3+4+"
            return [f"Expression `{target} = {expr}` is not complete."]
        # Single line expression.  Wrap the tail of the expr after the
        # syntax error in <b>...</b>.  exc.lineno=1, exc.text = expr+"\n",
        # and exc.offset = location of the syntax error in expr.
        return [
            f"Syntax error in expression '{target} = {expr[:offset - 1]}<b>{expr[offset - 1:]}</b>'"]
    return []


def _check_free_variables(target, expr, symbol_table, html=False):
    """
    Return a list with one message naming the symbols in *expr* that are
    neither in *symbol_table* nor python keywords, or the empty list.

    *target* is the name the expression is assigned to; it is only used to
    build the message.  If *html* is True each unknown symbol is wrapped in
    <b>...</b> in the expression shown in the message.
    """
    undefined = sorted(sym for sym in _symbols(expr)
                       if sym not in symbol_table and not iskeyword(sym))
    if not undefined:
        return []
    undefined_str = ", ".join(undefined)
    if html:
        # Identify each symbol for replacement as everything between a
        # word boundary.  Since symbols can contain '.', we need negative
        # lookbehind and negative lookahead on a character set containing
        # '.' to check for the boundaries, and the '.' in the symbol has to
        # be escaped so it only matches a literal '.'.
        # All symbols are tagged in a single pass: tagging them one at a
        # time would let a later symbol (e.g. one called "b") match inside
        # the <b> tags inserted for an earlier one.
        alternatives = "|".join(re.escape(symbol) for symbol in undefined)
        pattern = f"(?<![a-zA-Z0-9_.])(?:{alternatives})(?![a-zA-Z0-9_.])"
        expr = re.sub(pattern, r"<b>\g<0></b>", expr)
    return ["Unknown parameters (%s) in expression '%s = %s'"
            % (undefined_str, target, expr)]


# Simple pattern which matches symbols.  Note that it will also match
# invalid substrings such as a3...9, but given syntactically correct
# input it will only match symbols.  The lookbehind stops a match from
# starting in the middle of a numeric literal, so that the "e5" of "1e5",
# the "x1F" of "0x1F" and the "j" of "2j" are not treated as symbols.
_SYMBOL_PATTERN = re.compile(r'(?<![0-9a-zA-Z_])([a-zA-Z_][a-zA-Z_0-9.]*)')


def _symbols(expr):
    """
    Return the set of symbol names that occur in the expression string *expr*.
    """
    return {m.group(0) for m in _SYMBOL_PATTERN.finditer(expr)}


def _substitute(expr, mapping):
    """
    Replace all occurrences of symbol s with mapping[s] for s in mapping.
    """
    # Walk the symbols left to right, copying the text between them
    # unchanged and swapping in the replacement for each mapped symbol.
    pieces = []
    offset = 0
    for match in _SYMBOL_PATTERN.finditer(expr):
        symbol = match.group(1)
        if symbol in mapping:
            pieces.append(expr[offset:match.start()])
            pieces.append(mapping[symbol])
            offset = match.end()
    pieces.append(expr[offset:])
    return "".join(pieces)


def _find_dependencies(symtab, exprs):
    """
    Returns a list of pair-wise dependencies from the parameter expressions.

    *symtab* gives the *{'parameter': value}* table of available parameters.

    *exprs* gives the *{'parameter': 'expr'}* table of parameter expressions.

    For example, if *exprs* is *{'p3': 'p1+p2'}* and p1, p2 are in *symtab*,
    this returns *[('p3', 'p1'), ('p3', 'p2')]* (pairs for one target are in
    no particular order).  For base expressions without dependencies, such
    as p4 = 2*pi, this returns *[('p4', None)]*.
    """
    deps = [(target, source)
            for target, expr in exprs.items()
            for source in _dependent_symbols(expr, symtab)]
    return deps


# Hack to deal with expressions without dependencies --- return a fake
# dependency of None.
# The better solution is fix order_dependencies so that it takes a
# dictionary of {symbol: dependency_list}, for which no dependencies
# is simply []; fix in parameter_mapping as well
def _dependent_symbols(expr, symtab):
    """
    Given an expression string and a symbol table, return the set of symbols
    used in the expression.  Symbols are only returned once even if they
    occur multiple times.  The return value is a set with the elements in
    no particular order.  Returns a set containing *None* if no dependencies
    (as needed by *order_dependencies*).

    This is the first step in computing a dependency graph.
    """
    deps = {sym for sym in _symbols(expr) if sym in symtab}
    return deps if deps else {None}


def _parameter_mapping(pairs):
    """
    Find the parameter substitution we need so that expressions can
    be evaluated without having to traverse a chain of
    model.layer.parameter.value

    *pairs* is a non-empty list of (target, source) dependencies.  Returns
    *(definition, substitution)* where *definition* maps 'Pn' to the n-th
    parameter name in sorted order and *substitution* maps each parameter
    name to 'Pn.value'.
    """
    left, right = zip(*pairs)
    # None appears as a source when an expression has no dependencies;
    # it is not a parameter, so it gets no Pn slot.
    pars = sorted(p for p in set(left + right) if p is not None)
    definition = {'P%d' % i: p for i, p in enumerate(pars)}
    substitution = {p: 'P%d.value' % i for i, p in enumerate(pars)}
    return definition, substitution
def no_constraints():
    """
    This parameter set has no constraints between the parameters.

    Returns 0, the same value the generated constraints function returns,
    so either can be used directly in a fit problem definition.
    """
    return 0
def compile_constraints(symtab, exprs, context={}):
    """
    Build and return a function to evaluate all parameter expressions in
    the proper order.

    Input:

        *symtab* is the symbol table for the model: { 'name': parameter }

        *exprs* is the set of computed symbols: { 'name': 'expression' }

        *context* is any additional context needed to evaluate the expression

    Return:

        updater function which sets parameter.value for each expression

    Raises:

        RuntimeError if any expression contains a syntax error, if any
        symbol used is not defined, or if there are circular dependencies
        between symbols.  The exception message lists every error found,
        one per line.

    This function is not terribly sophisticated, and it would be easy to
    trick.  However it handles the common cases cleanly and generates
    reasonable messages for the common errors.

    This code has not been fully audited for security.  While we have
    removed the builtins and the ability to import modules, there may
    be other vectors for users to perform more than simple function
    evaluations.  Unauthenticated users should not be running this code.

    Parameter names are assumed to contain only _.a-zA-Z0-9#[]

    Both names are provided for inverse functions, e.g., acos and arccos.

    Should try running the function to identify syntax errors before
    running it in a fit.

    Use help(fn) to see the code generated for the returned function fn.
    dis.dis(fn) will show the corresponding python vm instructions.
    """
    updater, problems = _compile_constraints(symtab, exprs, context=context)
    if not problems:
        return updater
    raise RuntimeError("\n".join(problems))
# Minimal stand-in for a parameter, used when checking constraints
class _Parameter:
    """
    Holder with a single *value* attribute, defaulting to 0.
    """

    def __init__(self, value=0):
        self.value = value
def check_constraints(symtab, exprs, context={}, html=False):
    """
    Returns a list of errors in *exprs* or the empty list if there are none.

    If the html flag is set to True, the list elements will have html <b>
    markups that allow the caller to control rendering:

        Unknown symbol: tags unknown symbols in *exprs*

        Syntax error: tags the beginning of a syntax error in *exprs*

        Cyclic dependency: tags comma separated parameters that have
        cyclic dependency

    All symbols must exist in *context* or in *symtab*. The symbols in
    *context* should be constants or functions. The symbols in *symtab*
    can be constants or parameter objects with a *value* attribute.

    The expressions are first compiled; if that reports errors they are
    returned as is.  Otherwise the compiled constraints function is run
    once and the message of any exception it raises is returned.

    Any parameters in *symtab* are copied with a shallow copy so they
    aren't overridden when the constraints are run.
    """
    # Every symbol needs a 'value' attribute for the compiled function to
    # run: objects that have one are shallow copied, anything else is
    # wrapped in a throwaway parameter.
    wrapped = {}
    for name, item in symtab.items():
        wrapped[name] = copy(item) if hasattr(item, 'value') else _Parameter(item)

    updater, problems = _compile_constraints(
        wrapped, exprs, context=context, html=html)
    if problems:
        return problems

    try:
        updater()
    except Exception as exc:
        problems.append(str(exc))
    return problems
def _compile_constraints(symtab, exprs, context={}, html=False):
    """
    Compile the parameter expressions into a single update function.

    *symtab* maps parameter names to objects with a *value* attribute.
    *exprs* maps target parameter names to expression strings.  *context*
    supplies extra names for the expressions.  *html* selects html markup
    in the error messages.

    Returns *(function, errors)*:

    * *(no_constraints, [])* when there are no expressions;
    * *(None, errors)* when any expression has a syntax error, uses an
      unknown symbol, or the expressions depend on each other cyclically;
    * *(eval_expressions, [])* otherwise, where eval_expressions() assigns
      every expression to the *value* of its target and returns 0.
    """
    errors = []

    # Check the syntax before compiling the complete function.
    available_symbols = standard_symbols(context)
    available_symbols.update(symtab)
    for k, v in exprs.items():
        errors.extend(_check_syntax(k, v, html=html))
        errors.extend(_check_free_variables(k, v, available_symbols, html=html))

    # Sort the parameters in the order they need to be evaluated
    # Note: order_dependencies raises an error if there are cyclic dependencies
    deps = _find_dependencies(symtab, exprs)
    if not deps:
        return no_constraints, []
    try:
        order = order_dependencies(deps)
    except Exception as exc:
        if html:
            errors.append("Cyclic dependency amongst parameters: "
                          "<b>%s</b>" % str(exc))
        else:
            errors.append("Cyclic dependency amongst parameters: %s" % str(exc))

    if errors:
        return None, errors

    # Rather than using the full path to the parameters in the parameter
    # expressions, instead use Pn, and substitute Pn.value for each occurrence
    # of the parameter in the expression.
    names = sorted(symtab.keys())
    parameters = dict(('P%d' % i, symtab[k]) for i, k in enumerate(names))
    mapping = dict((k, 'P%d.value' % i) for i, k in enumerate(names))

    # Add the parameters to the global context
    global_context = standard_symbols(context)
    global_context.update(parameters)
    local_context = {}

    # Define the constraints function.
    # The deps pairs are (target, source) and order_dependencies lists each
    # target ahead of the targets it depends on, so the list is walked
    # backwards: an expression must be assigned only after the computed
    # parameters it reads have been assigned.
    evaluation_order = list(reversed(order))
    assignments = ["=".join((p, exprs[p])) for p in evaluation_order]
    code = [_substitute(s, mapping) for s in assignments]
    # TODO: maybe wrap body in try...except block and return NaN?
    # The template is python source: the docstring and the body lines must
    # each sit on their own line at the function's indent.
    functiondef = """
def eval_expressions():
    '''
    %s
    '''
    %s
    return 0
""" % ("\n    ".join(assignments), "\n    ".join(code))

    #print("Function: "+functiondef)
    # CRUFT: python < 3.0;  doc builder isn't allowing the following exec
    # https://stackoverflow.com/questions/4484872/why-doesnt-exec-work-in-a-function-with-a-subfunction/41368813#comment73790496_41368813
    #exec(functiondef, global_context, local_context)
    source = functiondef
    # The assignments double as the "file name" of the compiled code.
    location = "\n    ".join(assignments)
    eval(compile(source, location, 'exec'), global_context, local_context)
    retfn = local_context['eval_expressions']

    # Remove garbage added to globals by exec
    global_context.pop('__doc__', None)
    global_context.pop('__name__', None)
    global_context.pop('__file__', None)
    global_context.pop('__builtins__')
    return retfn, errors
def order_dependencies(pairs):
    """
    Order the left-hand elements of *pairs*, a list or array of (a, b).

    The returned list holds each element that occurs as an *a* exactly
    once, arranged so that a comes before b for every pair (a, b) whose
    members are both in the list.  Elements that only occur as a *b* are
    not returned.

    Raises ValueError if the pairs contain a cycle; the message is the
    comma separated list of the elements still unresolved at that point.

    NOTE(review): with (target, source) pairs the result lists a target
    ahead of the targets it depends on, so a caller that needs sources
    evaluated first has to walk the result backwards.
    """
    ordered = []
    remaining = pairs

    # Break pairs into left set and right set
    # Note: pairs is array or list, so use "len(pairs) > 0" to check for empty.
    if len(remaining) > 0:
        lhs, rhs = (set(column) for column in zip(*remaining))

    while len(remaining) > 0:
        # Items which only occur on the right depend on nothing that is
        # still pending.
        free = rhs - lhs
        if not free:
            raise ValueError(", ".join(str(item) for item in lhs))

        # The possibly resolvable items are those that depend on the free ones
        candidates = set([a for a, b in remaining if b in free])
        remaining = [(a, b) for a, b in remaining if b not in free]
        if remaining:
            lhs, rhs = (set(column) for column in zip(*remaining))
            # Candidates still on the left have other pending dependencies.
            candidates = candidates - lhs
        ordered.extend(candidates)

    ordered.reverse()
    return ordered
# ========= Test code ========
def _check(msg, pairs):
    """
    Run order_dependencies on *pairs* and verify the result: it must hold
    exactly the left-hand items of the pairs, each once, and for every pair
    (lo, hi) with both members present, lo must come before hi.

    Raises ValueError, with *msg* at the start of the message, on failure.
    """
    # Note: pairs is array or list, so use "len(pairs) > 0" to check for empty.
    columns = zip(*pairs) if len(pairs) > 0 else ([], [])
    firsts, _seconds = columns
    expected = set(firsts)

    n = order_dependencies(pairs)
    same_items = set(n) == expected and len(n) == len(expected)
    if not same_items:
        raise ValueError("%s expect %s to contain %s for %s"
                         % (msg, sorted(n), sorted(expected), pairs))

    # n has no duplicates at this point, so each item has one position.
    position = {item: index for index, item in enumerate(n)}
    for lo, hi in pairs:
        if lo not in position or hi not in position:
            continue
        if position[lo] >= position[hi]:
            raise ValueError("%s expect %s before %s in %s for %s"
                             % (msg, lo, hi, n, pairs))
def test_deps():
    """
    Exercise order_dependencies: empty input, small graphs as lists and as
    a numpy array, a cycle, a large random graph and two deep chains.
    """
    import numpy as np

    small = [(2, 7), (1, 5), (1, 4), (2, 1), (3, 1), (5, 6)]
    cases = [
        # Null case
        ("test empty", []),
        # Some dependencies
        ("test1", small),
        ("test1 renumbered", [(6, 1), (7, 3), (7, 4), (6, 7), (5, 7), (3, 2)]),
        ("test1 numpy", np.array(small)),
        # No dependencies
        ("test2", [(4, 1), (3, 2), (8, 4)]),
    ]
    for label, graph in cases:
        _check(label, graph)

    # Cycle test
    pairs = [(1, 4), (4, 3), (4, 5), (5, 1)]
    cycle_detected = False
    try:
        order_dependencies(pairs)
    except ValueError:
        cycle_detected = True
    if not cycle_detected:
        raise ValueError("test3 expect ValueError exception for %s" % (pairs,))

    # large test for gross speed check
    big = np.random.randint(4000, size=(1000, 2))
    big[:, 1] += 4000  # Avoid cycles
    _check("test-large", big)

    # depth tests
    k = 200
    low, high = range(0, k), range(1, k+1)
    _check("depth-1", np.array([low, high]).T)
    _check("depth-2", np.array([high, low]).T)
def test_expr():
    """
    Exercise the expression helpers, compile_constraints and
    check_constraints on a small set of fake parameters.
    """
    import inspect
    import dis

    symtab = {'a.b.x': 1, 'a.c': 2, 'a.b': 3, 'b.x': 4}
    expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b'

    # Check symbol lookup
    assert _dependent_symbols(expr, symtab) == set(['a.b.x', 'a.c', 'a.b'])

    # Check symbol rename
    assert _substitute(expr, {'a.b.x': 'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b'
    assert _substitute(expr, {'a.b': 'Q'}) == 'a.b.x + sin(4*pi*a.c) + a.b.x/Q'

    # Check dependency builder
    # Fake parameter class
    class TestParameter:
        def __init__(self, name, value=0, expression=''):
            self.path = name
            self.value = value
            self.expression = expression
        def iscomputed(self):
            # A parameter is computed when it carries a non-empty expression.
            return (self.expression != '')
        def __repr__(self):
            value = self.expression if self.iscomputed() else str(self.value)
            return self.path + '=' + value
    def world(*pars):
        # Build the (symtab, exprs) argument pair from a set of parameters.
        symtab = dict((p.path, p) for p in pars)
        exprs = dict((p.path, p.expression) for p in pars if p.iscomputed())
        return symtab, exprs
    p1 = TestParameter('G0.sigma', 5)
    p2 = TestParameter('other', expression='2*pi*sin(G0.sigma/.1875) + M1.G1')
    p3 = TestParameter('M1.G1', 6)
    # Same path as p3, but defined in terms of p2, which itself uses M1.G1.
    p3_circular = TestParameter('M1.G1', expression='other + 6')
    # Same path as p3, but defined in terms of itself.
    p3_self = TestParameter('M1.G1', expression='M1.G1')
    p4 = TestParameter('constant', expression='2*pi*35')

    # Simple chain
    assert (set(_find_dependencies(*world(p1, p2, p3)))
            == set([(p2.path, p1.path), (p2.path, p3.path)]))
    # Constant expression
    assert set(_find_dependencies(*world(p1, p4))) == set([(p4.path, None)])
    # No dependencies
    assert not set(_find_dependencies(*world(p1, p3)))

    # Check function builder
    fn = compile_constraints(*world(p1, p2, p3))

    # Inspect the resulting function
    if 0:
        print(inspect.getdoc(fn))
        print(dis.dis(fn))

    # Evaluate the function and see if it updates the
    # target value as expected
    fn()
    expected = 2*math.pi*math.sin(5/.1875) + 6
    assert p2.value == expected, "Value was %s, not %s" % (p2.value, expected)

    # Make sure check_constraints returns an empty list for these expressions.
    assert not check_constraints(*world(p1, p2, p3))

    # Check empty dependency set doesn't crash
    fn = compile_constraints(*world(p1, p3))
    fn()

    # Check that constants are evaluated properly
    fn = compile_constraints(*world(p4))
    fn()
    assert p4.value == 2*math.pi*35

    # Check that circular definitions get flagged
    try:
        fn = compile_constraints(*world(p1, p2, p3_circular))
    except Exception as exc:
        assert str(exc).startswith('Cyclic')
    else:
        raise RuntimeError("failed to raise error for cyclic dependency")
    try:
        fn = compile_constraints(*world(p1, p2, p3_self))
    except Exception as exc:
        assert str(exc).startswith('Cyclic')
    else:
        raise RuntimeError("failed to raise error for self dependency")

    # Make sure errors are returned from check_constraints
    errors = check_constraints(*world(p1, p2, p3_circular))
    assert len(errors) == 1 and errors[0].startswith('Cyclic')
    errors = check_constraints(*world(p1, p2, p3_self))
    assert len(errors) == 1 and errors[0].startswith('Cyclic')

    # Check additional context example; this also tests multiple expressions
    tbl = {'tbl_Si': 2.09}
    p5 = TestParameter('lookup', expression="tbl_Si")
    fn = compile_constraints(*world(p1, p2, p3, p5), context=tbl)
    fn()
    assert p5.value == 2.09, "Value for %s was %s" % (p5.expression, p5.value)
    #class Table:
    #    Si = 2.09
    #    values = {'Si': 2.07}
    #tbl = Table()
    #p5 = TestParameter('lookup', expression="tbl.Si")
    #fn = compile_constraints(*world(p1, p2, p3, p5), context=dict(tbl=tbl))
    #fn()
    #assert p5.value == 2.09, "Value for %s was %s"%(p5.expression, p5.value)
    #
    #p5.expression = "tbl.values['Si']"
    #fn = compile_constraints(*world(p1, p2, p3, p5), context=dict(tbl=tbl))
    #fn()
    #assert p5.value == 2.07, "Value for %s was %s" % (p5.expression, p5.value)

    # Verify that we capture invalid expressions
    for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2',
                 'piddle',
                 #'5; import sys; print("p0wned")',
                 '__import__("sys").argv'
                 ]:
        try:
            p6 = TestParameter('broken', expression=expr)
            fn = compile_constraints(*world(p6))
            fn()
        except Exception as msg:
            # Any exception, at compile time or at run time, counts as
            # the expression having been rejected.
            #print(msg)
            pass
        else:
            raise RuntimeError("Failed to raise error for %s" % (expr,))

    # Verify that check_constraints returns multiple errors.
    # Here symtab holds plain numbers rather than parameter objects.
    symtab = {
        'M1.sld': 1.0,
        'M1.sld_solvent': 6.0,
        'M1.radius': 50,
        'M1.scale': 1.0,
    }
    # 'M1.scal' is not in symtab, and M1.sld and M1.sld_solvent are defined
    # in terms of each other.
    exprs = {
        'M1.background': 'M1.scal/1e5',
        'M1.sld': 'M1.sld_solvent + 1',
        'M1.sld_solvent': 'M1.sld + 2',
    }
    errors = check_constraints(symtab, exprs)
    assert len(errors) == 2
    #print("\n".join(errors))


if __name__ == "__main__":
    test_expr()
    test_deps()