# Source code for sas.fit.expression

# This program is public domain
"""
Parameter expression evaluator.

For systems in which constraints are expressed as string expressions rather
than python code, :func:`compile_constraints` can construct an expression
evaluator that substitutes the computed values of the expressions into the
parameters.

The compiler requires a symbol table, an expression set and a context.
The symbol table maps strings containing fully qualified names such as
'M1.c[3].full_width' to parameter objects with a 'value' property that
can be queried and set.  The expression set maps symbol names from the
symbol table to string expressions.  The context provides additional symbols
for the expressions in addition to the usual mathematical functions and
constants.

The expressions are compiled and interpreted by python, with only minimal
effort to make sure that they don't contain bad code.  The resulting
constraints function returns 0 so it can be used directly in a fit problem
definition.

Extracting the symbol table from the model depends on the structure of the
model.  If fitness.parameters() is set correctly, then this should simply
be a matter of walking the parameter data, remembering the path to each
parameter in the symbol table.  For compactness, dictionary elements should
be referenced by .name rather than ["name"].  Model name can be used as the
top level.

Getting the parameter expressions applied correctly is challenging.
The following monkey patch works by overriding model_update in FitProblem
so that, after setp(p) is called, the constraints expression can be
applied before telling the underlying fitness function that the model
is out of date::

        # Override model update so that parameter constraints are applied
        problem._model_update = problem.model_update
        def model_update():
            constraints()
            problem._model_update()
        problem.model_update = model_update

Ideally, this interface will change
"""
import math
import re

# Loose matcher for symbols: a letter or underscore followed by any run of
# letters, digits, underscores and dots.  It performs no validation, so
# malformed text such as a3...9 is matched too; given syntactically correct
# input it will only match symbols.
_symbol_pattern = re.compile(r'([a-zA-Z_][a-zA-Z_0-9.]*)')

def _symbols(expr,symtab):
    """
    Collect the symbol table entries referenced by an expression string.

    Every substring of *expr* matched by the symbol pattern is looked up in
    *symtab*; matches that are not keys of *symtab* are skipped.  The result
    is the set of ``symtab[name]`` values, so an entry is reported once no
    matter how many times its name occurs, and the elements come back in
    no particular order.

    This is the first step in computing a dependency graph.
    """
    referenced = set()
    for match in _symbol_pattern.finditer(expr):
        name = match.group(0)
        if name in symtab:
            referenced.add(symtab[name])
    return referenced

def _substitute(expr,mapping):
    """
    Return *expr* with every symbol that is a key of *mapping* replaced by
    the corresponding value.  Symbols that are not in *mapping*, and all
    text between symbols, are copied through unchanged.
    """
    def _replacement(match):
        # Only whole symbols are candidates for renaming; anything the
        # mapping does not know about is emitted exactly as it was found.
        name = match.group(1)
        if name in mapping:
            return mapping[name]
        return match.group(0)

    return _symbol_pattern.sub(_replacement, expr)

def _find_dependencies(symtab, exprs):
    """
    List the pair-wise dependencies implied by the parameter expressions.

    *exprs* maps each target to its expression string.  One
    ``(target, source)`` pair is produced for every symbol table entry the
    expression refers to; for example p3 = p1+p2 contributes (p3,p1) and
    (p3,p2).  A base expression without dependencies, such as p4 = 2*pi,
    contributes the single pair (p4, None).
    """
    pairs = []
    for target, expr in exprs.items():
        for source in _symbols_or_none(expr, symtab):
            pairs.append((target, source))
    return pairs

# Workaround for expressions that have no dependencies: report a single
# placeholder dependency of None so the target still shows up in the list
# of pairs.
# A cleaner design would have order_dependencies accept a dictionary of
# {symbol: dependency_list}, where "no dependencies" is simply []; the same
# change would be needed in parameter_mapping.
def _symbols_or_none(expr,symtab):
    """
    Return the set of symbols used by *expr*, or ``[None]`` when the
    expression does not use any symbol from *symtab*.
    """
    found = _symbols(expr,symtab)
    if found:
        return found
    return [None]

def _parameter_mapping(pairs):
    """
    Find the parameter substitution we need so that expressions can
    be evaluated without having to traverse a chain of
    model.layer.parameter.value

    *pairs* is an iterable of (target, source) two-item pairs.  A member
    may be None, which marks an expression with no dependencies and is
    ignored.

    Returns ``(definition, substitution)``: *definition* maps the short
    names 'P0', 'P1', ... to the distinct members of *pairs* in sorted
    order, and *substitution* maps each member to the matching
    'Pn.value' string.  An empty *pairs* yields two empty dictionaries.
    """
    members = set()
    for target, source in pairs:
        members.add(target)
        members.add(source)
    # None is only a placeholder for "no dependency", never a parameter.
    members.discard(None)

    # Sort once so that both dictionaries agree on the numbering.
    pars = sorted(members)
    definition = dict(('P%d'%i, p) for i,p in enumerate(pars))
    substitution = dict((p, 'P%d.value'%i) for i,p in enumerate(pars))
    return definition, substitution

def no_constraints():
    """
    This parameter set has no constraints between the parameters.

    Returns 0, the same value the generated constraints functions return,
    so the result can be used the same way whether or not there are any
    expressions to evaluate.
    """
    return 0
def compile_constraints(symtab, exprs, context=None):
    """
    Build and return a function to evaluate all parameter expressions in
    the proper order.

    Input:

        *symtab* is the symbol table for the model: { 'name': parameter }

        *exprs* is the set of computed symbols: { 'name': 'expression' }

        *context* is any additional context needed to evaluate the
        expression (a mapping of extra names; None means no extra names)

    Return:

        updater function which sets parameter.value for each expression
        and returns 0.  When *exprs* produces no dependencies at all,
        :func:`no_constraints` is returned instead.

    Raises:

        SyntaxError - improper expression syntax

        ValueError - expressions have circular dependencies

    Names used in an expression which are neither in *symtab*, *context*
    nor the math module are not detected here; they fail when the returned
    function is called.

    This function is not terribly sophisticated, and it would be easy to
    trick.  However it handles the common cases cleanly and generates
    reasonable messages for the common errors.

    This code has not been fully audited for security.  While we have
    removed the builtins and the ability to import modules, there may
    be other vectors for users to perform more than simple function
    evaluations.  Unauthenticated users should not be running this code.

    Both names are provided for inverse functions, e.g., acos and arccos.

    Should try running the function to identify syntax errors before
    running it in a fit.

    Use help(fn) to see the code generated for the returned function fn.
    dis.dis(fn) will show the corresponding python vm instructions.
    """
    # Sort the parameters in the order they need to be evaluated
    deps = _find_dependencies(symtab, exprs)
    if not deps:
        return no_constraints
    order = order_dependencies(deps)

    # Rather than using the full path to the parameters in the parameter
    # expressions, instead use Pn, and substitute Pn.value for each occurrence
    # of the parameter in the expression.
    names = sorted(symtab.keys())
    parameters = dict(('P%d'%i, symtab[k]) for i,k in enumerate(names))
    mapping = dict((k, 'P%d.value'%i) for i,k in enumerate(names))

    # Initialize dictionary with available functions
    namespace = {}
    namespace.update(math.__dict__)
    namespace.update(dict(arcsin=math.asin,arccos=math.acos,
                          arctan=math.atan,arctan2=math.atan2))
    if context is not None:
        namespace.update(context)
    namespace.update(parameters)
    namespace['id'] = id
    # Install an empty builtins table *before* the function is created.
    # Interpreters that bind a function's builtins at creation time would
    # otherwise hand the generated code the full builtins (including
    # __import__), and removing the entry afterwards would come too late.
    namespace['__builtins__'] = {}
    scope = {}

    # Define the constraints function
    assignments = ["=".join((p,exprs[p])) for p in order]
    code = [_substitute(s, mapping) for s in assignments]
    functiondef = """
def eval_expressions():
    '''
    %s
    '''
    %s
    return 0
"""%("\n    ".join(assignments),"\n    ".join(code))

    # NOTE: exec on user supplied expressions; see the security caveat in
    # the docstring.  The call form is accepted by Python 2 and Python 3.
    exec(functiondef, namespace, scope)
    retfn = scope['eval_expressions']

    # Remove module bookkeeping entries (copied from math or added by exec)
    # so they are not visible as names inside the expressions.
    for key in ('__doc__', '__name__', '__file__', '__builtins__'):
        namespace.pop(key, None)
    return retfn
def order_dependencies(pairs):
    """
    Order the left-hand elements of *pairs* so that a comes before b in
    the returned list for every pair (a,b) in which both a and b appear
    as left-hand elements.

    *pairs* is any iterable of two-item pairs (list, tuple, generator,
    two-column array).  Elements which only ever appear on the right are
    not part of the result.  Elements resolved in the same pass are
    returned in no particular relative order.

    Raises ValueError if the pairs contain a cycle.

    NOTE(review): compile_constraints passes (target, source) pairs, for
    which this ordering lists a target ahead of the elements it depends
    on -- confirm that this is the intended direction.
    """
    # Take a private list of tuples: the input is scanned several times,
    # which a one-shot iterable cannot support, and comparing an array
    # against [] is not a reliable emptiness test.
    pairs = [(a, b) for a, b in pairs]
    order = []

    # Break pairs into left set and right set
    left = set(a for a, _ in pairs)
    right = set(b for _, b in pairs)
    while pairs:
        # Find which items only occur on the right
        independent = right - left
        if not independent:
            cycleset = ", ".join(str(s) for s in left)
            raise ValueError("Cyclic dependencies amongst %s"%cycleset)

        # The possibly resolvable items are those that depend on the independents
        dependent = set(a for a, b in pairs if b in independent)
        pairs = [(a, b) for a, b in pairs if b not in independent]
        left = set(a for a, _ in pairs)
        right = set(b for _, b in pairs)
        # Anything still on the left has an outstanding pair and must wait
        # for a later pass.
        order += dependent - left
    order.reverse()
    return order

# ========= Test code ========
def _check(msg,pairs):
    """
    Verify that order_dependencies(pairs) returns exactly the left-hand
    items of *pairs*, each of them once, and that for every pair (lo,hi)
    whose members are both in the result, lo is placed before hi.

    Raises Exception, with *msg* as the message prefix, on the first
    violation found.
    """
    # Iterating avoids comparing pairs against [], which is not a reliable
    # emptiness test when pairs is an array.
    items = set(lo for lo, _ in pairs)
    n = order_dependencies(pairs)
    if set(n) != items or len(n) != len(items):
        n.sort()
        items = sorted(items)
        raise Exception("%s expect %s to contain %s for %s"%(msg,n,items,pairs))
    # n holds distinct items at this point, so one lookup table replaces
    # the repeated linear searches.
    position = dict((item, i) for i, item in enumerate(n))
    for lo,hi in pairs:
        if lo in position and hi in position and position[lo] >= position[hi]:
            raise Exception("%s expect %s before %s in %s for %s"%(msg,lo,hi,n,pairs))
def test_deps():
    """
    Self-test for order_dependencies: empty input, small hand-written
    cases (as lists and as a numpy array), cycle detection, a larger
    random case and two long chains.  Raises on the first failure.
    """
    import numpy

    # Null case
    _check("test empty",[])

    # Some dependencies
    _check("test1",[(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)])
    _check("test1 renumbered",[(6,1),(7,3),(7,4),(6,7),(5,7),(3,2)])
    _check("test1 numpy",numpy.array([(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)]))

    # No dependencies
    _check("test2",[(4,1),(3,2),(8,4)])

    # Cycle test
    pairs = [(1,4),(4,3),(4,5),(5,1)]
    try:
        order_dependencies(pairs)
    except ValueError:
        pass  # the cycle is expected to be rejected
    else:
        raise Exception("test3 expect ValueError exception for %s"%(pairs,))

    # large test for gross speed check
    A = numpy.random.randint(4000,size=(1000,2))
    A[:,1] += 4000  # Avoid cycles
    _check("test-large",A)

    # depth tests
    k = 200
    A = numpy.array([range(0,k),range(1,k+1)]).T
    _check("depth-1",A)

    A = numpy.array([range(1,k+1),range(0,k)]).T
    _check("depth-2",A)
def test_expr():
    """
    Self-test for symbol lookup, symbol substitution, dependency discovery
    and the function produced by compile_constraints, including extra
    context and rejection of invalid expressions.  Raises on the first
    failure.
    """
    import inspect
    import dis
    import math

    symtab = {'a.b.x':1, 'a.c':2, 'a.b':3, 'b.x':4}
    expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b'

    # Check symbol lookup
    assert _symbols(expr, symtab) == set([1,2,3])

    # Check symbol rename
    assert _substitute(expr,{'a.b.x':'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b'
    assert _substitute(expr,{'a.b':'Q'}) == 'a.b.x + sin(4*pi*a.c) + a.b.x/Q'

    # Check dependency builder
    # Fake parameter class
    class Parameter:
        def __init__(self, name, value=0, expression=''):
            self.path = name
            self.value = value
            self.expression = expression
        def iscomputed(self): return (self.expression != '')
        def __repr__(self): return self.path

    def world(*pars):
        # Build the (symtab, exprs) argument pair from a list of parameters
        symtab = dict((p.path,p) for p in pars)
        exprs = dict((p.path,p.expression) for p in pars if p.iscomputed())
        return symtab, exprs

    p1 = Parameter('G0.sigma',5)
    p2 = Parameter('other',expression='2*pi*sin(G0.sigma/.1875) + M1.G1')
    p3 = Parameter('M1.G1',6)
    p4 = Parameter('constant',expression='2*pi*35')

    # Simple chain
    assert set(_find_dependencies(*world(p1,p2,p3))) == set([(p2.path,p1),(p2.path,p3)])
    # Constant expression
    assert set(_find_dependencies(*world(p1,p4))) == set([(p4.path,None)])
    # No dependencies
    assert set(_find_dependencies(*world(p1,p3))) == set([])

    # Check function builder
    fn = compile_constraints(*world(p1,p2,p3))

    # Inspect the resulting function
    if 0:
        print(inspect.getdoc(fn))
        print(dis.dis(fn))

    # Evaluate the function and see if it updates the
    # target value as expected
    fn()
    expected = 2*math.pi*math.sin(5/.1875) + 6
    assert p2.value == expected,"Value was %s, not %s"%(p2.value,expected)

    # Check empty dependency set doesn't crash
    fn = compile_constraints(*world(p1,p3))
    fn()

    # Check that constants are evaluated properly
    fn = compile_constraints(*world(p4))
    fn()
    assert p4.value == 2*math.pi*35

    # Check additional context example; this also tests multiple
    # expressions
    class Table:
        Si = 2.09
        values = {'Si': 2.07}
    tbl = Table()
    p5 = Parameter('lookup',expression="tbl.Si")
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.09,"Value for %s was %s"%(p5.expression,p5.value)
    p5.expression = "tbl.values['Si']"
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.07,"Value for %s was %s"%(p5.expression,p5.value)

    # Verify that we capture invalid expressions
    for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2',
                 'piddle',
                 '5; import sys; print "p0wned"',
                 '__import__("sys").argv']:
        try:
            p6 = Parameter('broken',expression=expr)
            fn = compile_constraints(*world(p6))
            fn()
        except Exception:
            # Any failure while compiling or evaluating is the expected
            # outcome for these expressions.
            pass
        else:
            # A plain string cannot be raised; use a real exception so the
            # report names the offending expression.
            raise Exception("Failed to raise error for %s"%expr)
# Run the self-tests when this file is executed as a script.
if __name__ == "__main__":
    test_expr()
    test_deps()