Skip to content

Internal Reference

Julia Interface

Functions for initializing the Julia environment and installing deps.

init_julia(julia_project=None)

Initialize julia binary, turning off compiled modules if needed.

Source code in pysr/julia_helpers.py
def init_julia(julia_project=None):
    """Initialize julia binary, turning off compiled modules if needed.

    The Julia project environment is configured (via
    `_process_julia_project` and `_set_julia_project_env`) before the
    `julia` package is asked to load the runtime.

    Returns the `julia.Main` module.

    Raises `FileNotFoundError` when `JuliaInfo.load` cannot find the `julia`
    executable, and `ImportError` when PyCall.jl has not been built.
    """
    global julia_initialized

    # Only run the library-conflict check on the first initialization.
    if not julia_initialized:
        _check_for_conflicting_libraries()

    from julia.core import JuliaInfo, UnsupportedPythonError

    _version_assertion()
    julia_project, is_shared = _process_julia_project(julia_project)
    _set_julia_project_env(julia_project, is_shared)

    try:
        info = JuliaInfo.load(julia="julia")
    except FileNotFoundError as err:
        # PATH may be unset altogether; a KeyError here would hide the
        # actual problem (no julia executable), so fall back to "".
        env_path = os.environ.get("PATH", "")
        raise FileNotFoundError(
            f"Julia is not installed in your PATH. Please install Julia and add it to your PATH.\n\nCurrent PATH: {env_path}",
        ) from err

    if not info.is_pycall_built():
        raise ImportError(_import_error_string())

    try:
        from julia import Main
    except UnsupportedPythonError:
        # Static python binary, so we turn off pre-compiled modules.
        from julia.core import Julia

        # The instance is created for its side effect; the reference is kept
        # alive until `Main` has been imported.
        jl = Julia(compiled_modules=False)  # noqa: F841
        from julia import Main

    julia_initialized = True
    return Main

install(julia_project=None, quiet=False)

Install PyCall.jl and all required dependencies for SymbolicRegression.jl.

Also updates the local Julia registry.

Source code in pysr/julia_helpers.py
def install(julia_project=None, quiet=False):  # pragma: no cover
    """
    Install PyCall.jl and every dependency needed by SymbolicRegression.jl.

    The local Julia registry is updated as part of the process. When `quiet`
    is true, Pkg output is sent to `devnull` (on Julia >= 1.6.0) and the
    closing "restart Python" warning is suppressed.
    """
    import julia

    _version_assertion()
    # JULIA_PROJECT must point at the pysr environment before installing.
    julia_project, is_shared = _process_julia_project(julia_project)
    _set_julia_project_env(julia_project, is_shared)

    julia.install(quiet=quiet)

    # is_shared is only true if the julia_project arg was None
    # (see _process_julia_project), so re-initialize with None in that case.
    Main = init_julia(None if is_shared else julia_project)

    Main.eval("using Pkg")

    io_arg = ""
    if is_julia_version_greater_eq(version=(1, 6, 0)):
        io_arg = "io=" + ("devnull" if quiet else "stderr")

    # An IO object cannot be passed through a Julia call (it would evaluate
    # to a PyObject), so the Pkg commands are issued as source via Main.eval.
    project_path = _escape_filename(julia_project)
    shared_flag = int(is_shared)
    Main.eval(
        f'Pkg.activate("{project_path}", shared = Bool({shared_flag}), {io_arg})'
    )
    if is_shared:
        # Install SymbolicRegression.jl:
        _add_sr_to_julia_project(Main, io_arg)

    for pkg_step in ("instantiate", "precompile"):
        Main.eval(f"Pkg.{pkg_step}({io_arg})")

    if quiet:
        return
    warnings.warn(
        "It is recommended to restart Python after installing PySR's dependencies,"
        " so that the Julia environment is properly initialized."
    )

Exporting to LaTeX

Functions to help export PySR equations to LaTeX.

to_latex(expr, prec=3, full_prec=True, **settings)

Convert sympy expression to LaTeX with custom precision.

Source code in pysr/export_latex.py
def to_latex(expr, prec=3, full_prec=True, **settings):
    """Render a sympy expression as LaTeX using a custom precision.

    `prec` is handed to `PreciseLatexPrinter`; `full_prec` is merged into the
    printer settings alongside any extra keyword `settings`.
    """
    printer_settings = {**settings, "full_prec": full_prec}
    return PreciseLatexPrinter(settings=printer_settings, prec=prec).doprint(expr)

generate_single_table(equations, indices=None, precision=3, columns=['equation', 'complexity', 'loss', 'score'], max_equation_length=50, output_variable_name='y')

Generate a booktabs-style LaTeX table for a single set of equations.

Source code in pysr/export_latex.py
def generate_single_table(
    equations: pd.DataFrame,
    indices: List[int] = None,
    precision: int = 3,
    columns=["equation", "complexity", "loss", "score"],
    max_equation_length: int = 50,
    output_variable_name: str = "y",
):
    """Generate a booktabs-style LaTeX table for a single set of equations.

    `equations` is a DataFrame read positionally (`iloc`) for each entry of
    `indices` (all rows when `indices` is None). Each name in `columns`
    produces one cell per row:

    - "equation": the row's "sympy_format" rendered with `to_latex`; when the
      LaTeX string has `max_equation_length` characters or more it is wrapped
      in a `minipage`/`dmath*` environment instead of inline math.
    - "complexity": the row's "complexity" value as inline math.
    - "loss" / "score": the row's value rendered with `to_latex`.

    Only the DataFrame columns needed for the requested `columns` are read,
    so a frame without e.g. "score" works when "score" is not requested.

    Returns the full table as a single newline-joined string.
    Raises `ValueError` for a column name that is not listed above.
    """
    assert isinstance(equations, pd.DataFrame)

    # NOTE: `columns` is only read, never mutated, so the list default is safe.
    latex_top, latex_bottom = generate_table_environment(columns)
    latex_table_content = []

    if indices is None:
        indices = range(len(equations))

    for i in indices:
        row = equations.iloc[i]

        row_pieces = []
        for col in columns:
            if col == "equation":
                latex_equation = to_latex(row["sympy_format"], prec=precision)
                if len(latex_equation) < max_equation_length:
                    row_pieces.append(
                        "$" + output_variable_name + " = " + latex_equation + "$"
                    )
                else:
                    # Too long for one line: let dmath* break the equation.
                    broken_latex_equation = " ".join(
                        [
                            r"\begin{minipage}{0.8\linewidth}",
                            r"\vspace{-1em}",
                            r"\begin{dmath*}",
                            output_variable_name + " = " + latex_equation,
                            r"\end{dmath*}",
                            r"\end{minipage}",
                        ]
                    )
                    row_pieces.append(broken_latex_equation)
            elif col == "complexity":
                row_pieces.append("$" + str(row["complexity"]) + "$")
            elif col in ("loss", "score"):
                value = to_latex(sympy.Float(row[col]), prec=precision)
                row_pieces.append("$" + value + "$")
            else:
                raise ValueError(f"Unknown column: {col}")

        latex_table_content.append(
            " & ".join(row_pieces) + r" \\",
        )

    return "\n".join([latex_top, *latex_table_content, latex_bottom])

generate_multiple_tables(equations, indices=None, precision=3, columns=['equation', 'complexity', 'loss', 'score'], output_variable_names=None)

Generate multiple latex tables for a list of equation sets.

Source code in pysr/export_latex.py
def generate_multiple_tables(
    equations: List[pd.DataFrame],
    indices: List[List[int]] = None,
    precision: int = 3,
    columns=["equation", "complexity", "loss", "score"],
    output_variable_names: List[str] = None,
    max_equation_length: int = 50,
):
    """Generate multiple latex tables for a list of equation sets.

    One table is produced per entry of `equations` by calling
    `generate_single_table`, and the tables are joined with a blank line.

    `indices`, when given and non-empty, supplies the row indices for each
    table (`indices[i]` for table `i`); otherwise every row is used.
    `output_variable_names`, when given, supplies the left-hand-side name for
    each table; otherwise table `i` uses "y_{i}".
    `precision`, `columns` and `max_equation_length` are forwarded unchanged
    to every table.
    """
    latex_tables = [
        generate_single_table(
            table_equations,
            (None if not indices else indices[i]),
            precision=precision,
            columns=columns,
            max_equation_length=max_equation_length,
            output_variable_name=(
                "y_{" + str(i) + "}"
                if output_variable_names is None
                else output_variable_names[i]
            ),
        )
        for i, table_equations in enumerate(equations)
    ]

    return "\n\n".join(latex_tables)

generate_table_environment(columns=['equation', 'complexity', 'loss'])

Source code in pysr/export_latex.py
def generate_table_environment(columns=["equation", "complexity", "loss"]):
    """Build the opening and closing LaTeX of a booktabs table.

    Returns a `(top, bottom)` tuple of strings. `top` opens the
    table/center/tabular environments with one centered column per entry of
    `columns` and emits the header row; `bottom` closes them again.
    A column name without a known heading raises `KeyError`.
    """
    alignment = "c" * len(columns)
    headings = {
        "complexity": "Complexity",
        "loss": "Loss",
        "equation": "Equation",
        "score": "Score",
    }
    header_row = " & ".join([headings[name] for name in columns]) + r" \\"

    top_latex_table = "\n".join(
        (
            r"\begin{table}[h]",
            r"\begin{center}",
            r"\begin{tabular}{@{}" + alignment + r"@{}}",
            r"\toprule",
            header_row,
            r"\midrule",
        )
    )
    bottom_latex_table = "\n".join(
        (
            r"\bottomrule",
            r"\end{tabular}",
            r"\end{center}",
            r"\end{table}",
        )
    )
    return top_latex_table, bottom_latex_table

Exporting to JAX

sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None)

Returns a function f and its parameters; the function takes an input matrix, and a list of arguments: f(X, parameters) where the parameters appear in the JAX equation.

Examples:

Let's create a function in SymPy:
```python
x, y = symbols('x y')
cosx = 1.0 * sympy.cos(x) + 3.2 * y
```
Let's get the JAX version. We pass the equation, and
the symbols required.
```python
f, params = sympy2jax(cosx, [x, y])
```
The order you supply the symbols is the same order
you should supply the features when calling
the function `f` (shape `[nrows, nfeatures]`).
In this case, features=2 for x and y.
The `params` in this case will be
`jnp.array([1.0, 3.2])`. You pass these parameters
when calling the function, which will let you change them
and take gradients.

Let's generate some JAX data to pass:
```python
key = random.PRNGKey(0)
X = random.normal(key, (10, 2))
```

We can call the function with:
```python
f(X, params)

#> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
#                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
#                3.5427954 , -2.7479894 ], dtype=float32)
```

We can now take gradients with respect
to the parameters for each row
using JAX:
```python
jac_f = jax.jacobian(f, argnums=1)
jac_f(X, params)

#> DeviceArray([[ 0.49364874, -0.9692889 ],
#               [ 0.8283714 , -0.0318858 ],
#               [-0.7447336 , -1.8784496 ],
#               [ 0.70755106, -0.3137085 ],
#               [ 0.944834  ,  1.767703  ],
#               [ 0.51673377,  1.4111717 ],
#               [ 0.87347716, -0.52637756],
#               [ 0.8760679 ,  1.0549792 ],
#               [ 0.9961824 ,  0.79581654],
#               [-0.88465923, -0.5822907 ]], dtype=float32)
```

We can also JIT-compile our function:
```python
compiled_f = jax.jit(f)
compiled_f(X, params)

#> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
#                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
#                3.5427954 , -2.7479894 ], dtype=float32)
```
Source code in pysr/export_jax.py
def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
    """Convert a sympy expression into a JAX function and its parameters.

    Returns a pair `(f, parameters)`. The function is called as
            f(X, parameters)
    where `X` is an input matrix and `parameters` holds the constants that
    appear in the JAX equation. If `selection` is given, the generated
    function first restricts `X` to those columns.

    # Examples:

        Start from a SymPy expression:
        ```python
        x, y = symbols('x y')
        cosx = 1.0 * sympy.cos(x) + 3.2 * y
        ```
        Convert it to JAX by passing the equation together with
        the symbols it needs:
        ```python
        f, params = sympy2jax(cosx, [x, y])
        ```
        The order in which the symbols are supplied is the order
        in which the features must be supplied when calling
        the function `f` (shape `[nrows, nfeatures]`).
        Here there are 2 features, for x and y.
        The `params` in this case will be
        `jnp.array([1.0, 3.2])`. They are passed explicitly
        when calling the function, so they can be changed
        and differentiated.

        Generate some JAX data to pass:
        ```python
        key = random.PRNGKey(0)
        X = random.normal(key, (10, 2))
        ```

        Call the function:
        ```python
        f(X, params)

        #> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
        #                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
        #                3.5427954 , -2.7479894 ], dtype=float32)
        ```

        Take gradients with respect to the parameters
        for each row:
        ```python
        jac_f = jax.jacobian(f, argnums=1)
        jac_f(X, params)

        #> DeviceArray([[ 0.49364874, -0.9692889 ],
        #               [ 0.8283714 , -0.0318858 ],
        #               [-0.7447336 , -1.8784496 ],
        #               [ 0.70755106, -0.3137085 ],
        #               [ 0.944834  ,  1.767703  ],
        #               [ 0.51673377,  1.4111717 ],
        #               [ 0.87347716, -0.52637756],
        #               [ 0.8760679 ,  1.0549792 ],
        #               [ 0.9961824 ,  0.79581654],
        #               [-0.88465923, -0.5822907 ]], dtype=float32)
        ```

        The function can also be JIT-compiled:
        ```python
        compiled_f = jax.jit(f)
        compiled_f(X, params)

        #> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
        #                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
        #                3.5427954 , -2.7479894 ], dtype=float32)
        ```
    """
    global jax_initialized, jax, jnp, jsp

    _initialize_jax()

    parameters = []
    body_expr = sympy2jaxtext(expression, parameters, symbols_in, extra_jax_mappings)

    # The generated function gets a name derived from the expression and symbols.
    fn_name = "A_" + str(abs(hash(str(expression) + str(symbols_in))))

    source_parts = [f"def {fn_name}(X, parameters):\n"]
    if selection is not None:
        # Impose the feature selection:
        source_parts.append(f"    X = X[:, {list(selection)}]\n")
    source_parts.append("    return " + body_expr)

    # The source is executed against this module's globals; the new function
    # is collected from the local namespace dict.
    namespace = {}
    exec("".join(source_parts), globals(), namespace)
    return namespace[fn_name], jnp.array(parameters)

sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None)

Source code in pysr/export_jax.py
def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
    """Recursively translate a sympy expression into JAX source text.

    `parameters` is an output list: every `sympy.Float` encountered is
    appended to it and referenced in the text as `parameters[k]`.
    Symbols become `X[:, j]`, where `j` is the position of the first equal
    entry in `symbols_in`. Any other node is looked up by `expr.func`, first
    in `extra_jax_mappings` and then in `_jnp_func_lookup`, and its arguments
    are translated recursively.

    Returns the generated expression as a string.

    Raises `IndexError` if a symbol is not present in `symbols_in`, and
    `KeyError` if a function has no JAX mapping.
    """
    if issubclass(expr.func, sympy.Float):
        parameters.append(float(expr))
        return f"parameters[{len(parameters) - 1}]"
    elif issubclass(expr.func, sympy.Rational):
        # NOTE(review): sympy.Integer is a subclass of sympy.Rational, so
        # integers are caught here and emitted as floats; the Integer branch
        # below is never reached. Left as-is so generated code is unchanged.
        return f"{float(expr)}"
    elif issubclass(expr.func, sympy.Integer):
        return f"{int(expr)}"
    elif issubclass(expr.func, sympy.Symbol):
        for column, symbol in enumerate(symbols_in):
            if symbol == expr:
                return f"X[:, {column}]"
        # IndexError is kept as the exception type for backward compatibility.
        raise IndexError(f"Symbol {expr} was not found in symbols_in.")

    if extra_jax_mappings is None:
        extra_jax_mappings = {}

    # User-supplied mappings take precedence over the built-in table.
    if expr.func in extra_jax_mappings:
        _func = extra_jax_mappings[expr.func]
    elif expr.func in _jnp_func_lookup:
        _func = _jnp_func_lookup[expr.func]
    else:
        raise KeyError(
            f"Function {expr.func} was not found in JAX function mappings. "
            "Please add it to extra_jax_mappings in the format, e.g., "
            "{sympy.sqrt: 'jnp.sqrt'}."
        )

    args = [
        sympy2jaxtext(
            arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
        )
        for arg in expr.args
    ]
    # Products and sums are written with infix operators; everything else is
    # emitted as a function call.
    if _func == MUL:
        return " * ".join(["(" + arg + ")" for arg in args])
    if _func == ADD:
        return " + ".join(["(" + arg + ")" for arg in args])
    return f'{_func}({", ".join(args)})'

Exporting to PyTorch

sympy2torch(expression, symbols_in, selection=None, extra_torch_mappings=None)

Returns a module for a given sympy expression with trainable parameters.

This function will assume the input to the module is a matrix X, where each column corresponds to each symbol you pass in symbols_in.

Source code in pysr/export_torch.py
def sympy2torch(expression, symbols_in, selection=None, extra_torch_mappings=None):
    """Build a module with trainable parameters from a sympy expression.

    The module expects its input to be a matrix X whose columns correspond,
    in order, to the symbols passed in `symbols_in`.
    """
    # presumably defined at module level by _initialize_torch() — confirm
    global SingleSymPyModule

    _initialize_torch()

    module = SingleSymPyModule(
        expression,
        symbols_in,
        selection=selection,
        extra_funcs=extra_torch_mappings,
    )
    return module