Skip to content

Internal Reference

Julia Interface

Functions for initializing the Julia environment and installing dependencies.

init_julia(*args, **kwargs)

Source code in pysr/deprecated.py
def init_julia(*args, **kwargs):
    """Deprecated shim kept so that legacy call sites do not break.

    Any positional or keyword arguments are accepted and discarded. A
    ``FutureWarning`` is emitted explaining that Julia is now initialized
    automatically at import time.

    The warning is issued with ``stacklevel=2`` so that it is attributed to
    the line that called ``init_julia`` rather than to this shim.

    Returns:
        The ``jl`` object referenced from the enclosing module.
    """
    # Arguments are accepted only for backward compatibility.
    del args, kwargs
    warnings.warn(
        "The `init_julia` function has been removed. "
        "Julia is now initialized automatically at import time.",
        FutureWarning,
        stacklevel=2,
    )
    return jl

install(*args, **kwargs)

Source code in pysr/deprecated.py
def install(*args, **kwargs):
    """Deprecated shim kept so that legacy call sites do not break.

    Any positional or keyword arguments are accepted and discarded. A
    ``FutureWarning`` is emitted explaining that dependencies are now
    installed automatically via ``juliacall`` at import time.

    The warning is issued with ``stacklevel=2`` so that it is attributed to
    the line that called ``install`` rather than to this shim.

    Returns:
        None.
    """
    # Arguments are accepted only for backward compatibility.
    del args, kwargs
    warnings.warn(
        "The `install` function has been removed. "
        "PySR now uses the `juliacall` package to install its dependencies automatically at import time. ",
        FutureWarning,
        stacklevel=2,
    )

Exporting to LaTeX

Functions to help export PySR equations to LaTeX.

generate_table_environment(columns=['equation', 'complexity', 'loss'])

Source code in pysr/export_latex.py
def generate_table_environment(
    columns: list[str] = ["equation", "complexity", "loss"]
) -> tuple[str, str]:
    """Build the opening and closing LaTeX of a centred table.

    Args:
        columns: Keys of the columns to include, in display order. Each key
            must be one of ``"equation"``, ``"complexity"``, ``"loss"`` or
            ``"score"``. The default list is only read, never mutated.

    Returns:
        A ``(top, bottom)`` pair of strings. ``top`` opens the ``table``,
        ``center`` and ``tabular`` environments and writes the header row
        between ``\\toprule`` and ``\\midrule``; ``bottom`` writes
        ``\\bottomrule`` and closes the environments in reverse order.

    Raises:
        KeyError: If a key in ``columns`` is not one of the known columns.
    """
    # One centred ("c") column per requested key.
    alignment = "c" * len(columns)

    titles = {
        "equation": "Equation",
        "complexity": "Complexity",
        "loss": "Loss",
        "score": "Score",
    }
    header_row = " & ".join([titles[key] for key in columns]) + r" \\"

    opening = "\n".join(
        (
            r"\begin{table}[h]",
            r"\begin{center}",
            r"\begin{tabular}{@{}" + alignment + r"@{}}",
            r"\toprule",
            header_row,
            r"\midrule",
        )
    )
    closing = "\n".join(
        (
            r"\bottomrule",
            r"\end{tabular}",
            r"\end{center}",
            r"\end{table}",
        )
    )
    return opening, closing

Exporting to JAX

sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None)

Returns a function f and its parameters; the function takes an input matrix, and a list of arguments: f(X, parameters) where the parameters appear in the JAX equation.

Examples:

Let's create a function in SymPy:
```python
x, y = symbols('x y')
cosx = 1.0 * sympy.cos(x) + 3.2 * y
```
Let's get the JAX version. We pass the equation, and
the symbols required.
```python
f, params = sympy2jax(cosx, [x, y])
```
The order in which you supply the symbols is the order
in which you should supply the features when calling
the function `f` (shape `[nrows, nfeatures]`).
In this case, nfeatures=2 for x and y.
The `params` in this case will be
`jnp.array([1.0, 3.2])`. You pass these parameters
when calling the function, which will let you change them
and take gradients.

Let's generate some JAX data to pass:
```python
key = random.PRNGKey(0)
X = random.normal(key, (10, 2))
```

We can call the function with:
```python
f(X, params)

#> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
#                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
#                3.5427954 , -2.7479894 ], dtype=float32)
```

We can take gradients with respect
to the parameters for each row
using `jax.jacobian`:
```python
jac_f = jax.jacobian(f, argnums=1)
jac_f(X, params)

#> DeviceArray([[ 0.49364874, -0.9692889 ],
#               [ 0.8283714 , -0.0318858 ],
#               [-0.7447336 , -1.8784496 ],
#               [ 0.70755106, -0.3137085 ],
#               [ 0.944834  ,  1.767703  ],
#               [ 0.51673377,  1.4111717 ],
#               [ 0.87347716, -0.52637756],
#               [ 0.8760679 ,  1.0549792 ],
#               [ 0.9961824 ,  0.79581654],
#               [-0.88465923, -0.5822907 ]], dtype=float32)
```

We can also JIT-compile our function:
```python
compiled_f = jax.jit(f)
compiled_f(X, params)

#> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
#                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
#                3.5427954 , -2.7479894 ], dtype=float32)
```
Source code in pysr/export_jax.py
def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
    """Compile a SymPy expression into a JAX-callable function.

    Returns a pair ``(f, parameters)``. The function is called as
            f(X, parameters)
    where ``X`` is the input matrix and ``parameters`` holds the
    constants that appear in the JAX equation.

    If ``selection`` is given, the generated function first restricts
    ``X`` to ``X[:, list(selection)]`` before evaluating the expression.

    # Examples:

        Start with a SymPy expression:
        ```python
        x, y = symbols('x y')
        cosx = 1.0 * sympy.cos(x) + 3.2 * y
        ```
        Convert it to JAX by passing the expression together
        with the symbols it uses:
        ```python
        f, params = sympy2jax(cosx, [x, y])
        ```
        The order of the symbols is the order in which the
        feature columns must be supplied when calling `f`
        (input shape `[nrows, nfeatures]`).
        Here nfeatures=2, for x and y.
        `params` will be `jnp.array([1.0, 3.2])`. Because the
        parameters are passed in on every call, you can
        modify them and differentiate with respect to them.

        Generate some JAX data to evaluate on:
        ```python
        key = random.PRNGKey(0)
        X = random.normal(key, (10, 2))
        ```

        Call the function:
        ```python
        f(X, params)

        #> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
        #                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
        #                3.5427954 , -2.7479894 ], dtype=float32)
        ```

        Take the gradient with respect to the parameters
        for each row using `jax.jacobian`:
        ```python
        jac_f = jax.jacobian(f, argnums=1)
        jac_f(X, params)

        #> DeviceArray([[ 0.49364874, -0.9692889 ],
        #               [ 0.8283714 , -0.0318858 ],
        #               [-0.7447336 , -1.8784496 ],
        #               [ 0.70755106, -0.3137085 ],
        #               [ 0.944834  ,  1.767703  ],
        #               [ 0.51673377,  1.4111717 ],
        #               [ 0.87347716, -0.52637756],
        #               [ 0.8760679 ,  1.0549792 ],
        #               [ 0.9961824 ,  0.79581654],
        #               [-0.88465923, -0.5822907 ]], dtype=float32)
        ```

        The function can also be JIT-compiled:
        ```python
        compiled_f = jax.jit(f)
        compiled_f(X, params)

        #> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
        #                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
        #                3.5427954 , -2.7479894 ], dtype=float32)
        ```
    """
    _initialize_jax()

    # Filled in place by sympy2jaxtext with every float constant it finds.
    parameters = []
    body_text = sympy2jaxtext(expression, parameters, symbols_in, extra_jax_mappings)

    # Name of the generated function, derived from the expression and symbols.
    func_name = "A_" + str(abs(hash(str(expression) + str(symbols_in))))

    source_lines = [f"def {func_name}(X, parameters):"]
    if selection is not None:
        # Impose the feature selection before evaluating the expression.
        source_lines.append(f"    X = X[:, {list(selection)}]")
    source_lines.append("    return " + body_text)
    source = "\n".join(source_lines)

    # NOTE(review): the generated source is executed with exec; this assumes
    # the expression and mappings come from a trusted origin.
    namespace = {}
    exec(source, globals(), namespace)
    return namespace[func_name], jnp.array(parameters)

sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None)

Source code in pysr/export_jax.py
def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
    """Recursively render a SymPy expression as a string of JAX source code.

    Args:
        expr: The SymPy expression node to convert.
        parameters: List that is mutated in place. Every ``sympy.Float``
            encountered is appended as a Python float and referenced in the
            output as ``parameters[i]``.
        symbols_in: Sequence of symbols. A symbol is rendered as
            ``X[:, i]``, where ``i`` is its first position in this sequence.
        extra_jax_mappings: Optional dict mapping SymPy function classes to
            the text emitted for them. Entries here take precedence over the
            module's default ``_jnp_func_lookup`` table.

    Returns:
        The source text for ``expr``.

    Raises:
        KeyError: If ``expr.func`` has no entry in either mapping.
        IndexError: If a symbol in the expression is not in ``symbols_in``.
    """
    func = expr.func

    if issubclass(func, sympy.Float):
        parameters.append(float(expr))
        return f"parameters[{len(parameters) - 1}]"

    # sympy.Integer is a subclass of sympy.Rational, so integers are rendered
    # as float literals here too. The former separate Integer branch could
    # never be reached and has been removed; the emitted text is unchanged.
    if issubclass(func, (sympy.Rational, sympy.NumberSymbol)):
        return f"{float(expr)}"

    if issubclass(func, sympy.Symbol):
        for index, symbol in enumerate(symbols_in):
            if symbol == expr:
                return f"X[:, {index}]"
        # Same exception type as the previous out-of-range list access,
        # but with a message that names the offending symbol.
        raise IndexError(
            f"Symbol {expr} was not found in symbols_in ({symbols_in})."
        )

    if extra_jax_mappings is None:
        extra_jax_mappings = {}

    # User-supplied mappings override the defaults. Two lookups avoid
    # rebuilding a merged dict at every node of the expression tree.
    if func in extra_jax_mappings:
        jax_name = extra_jax_mappings[func]
    elif func in _jnp_func_lookup:
        jax_name = _jnp_func_lookup[func]
    else:
        raise KeyError(
            f"Function {func} was not found in JAX function mappings. "
            "Please add it to extra_jax_mappings in the format, e.g., "
            "{sympy.sqrt: 'jnp.sqrt'}."
        )

    rendered_args = [
        sympy2jaxtext(
            arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
        )
        for arg in expr.args
    ]

    # Products and sums are emitted as infix operators with each operand
    # parenthesized; everything else becomes a function call.
    if jax_name == MUL:
        return " * ".join(f"({arg})" for arg in rendered_args)
    if jax_name == ADD:
        return " + ".join(f"({arg})" for arg in rendered_args)
    joined_args = ", ".join(rendered_args)
    return f"{jax_name}({joined_args})"

Exporting to PyTorch

sympy2torch(expression, symbols_in, selection=None, extra_torch_mappings=None)

Returns a module for a given sympy expression with trainable parameters.

This function will assume the input to the module is a matrix X, where each column corresponds to each symbol you pass in symbols_in.

Source code in pysr/export_torch.py
def sympy2torch(expression, symbols_in, selection=None, extra_torch_mappings=None):
    """Wrap a sympy expression in a module with trainable parameters.

    The returned module expects a matrix X as input, whose columns
        correspond, in order, to the symbols given in `symbols_in`.
    """
    # Must run before SingleSymPyModule is looked up in the module globals.
    _initialize_torch()

    module = SingleSymPyModule(
        expression,
        symbols_in,
        selection=selection,
        extra_funcs=extra_torch_mappings,
    )
    return module