Skip to content

Internal Reference

Julia Interface

Functions for initializing the Julia environment and installing PySR's Julia dependencies.

init_julia(julia_project=None, quiet=False, julia_kwargs=None, return_aux=False)

Initialize the Julia binary, turning off precompiled modules if needed (e.g., when the Python library is statically linked).

Source code in pysr/julia_helpers.py
def init_julia(julia_project=None, quiet=False, julia_kwargs=None, return_aux=False):
    """Initialize julia binary, turning off compiled modules if needed.

    Sets up the requested Julia project, constructs ``julia.core.Julia`` with
    ``julia_kwargs`` and returns ``julia.Main``. If the Python library is
    static (``UnsupportedPythonError``), it retries with
    ``compiled_modules=False`` and warns.

    On calls after the first, it warns about user options that differ from
    the ones recorded at first initialization, because Julia has already
    started. If a different project is requested, it activates that project
    with ``Pkg.activate``.

    Args:
        julia_project: Julia project to use. It is processed by
            ``_process_julia_project`` and exported by
            ``_set_julia_project_env``.
        quiet: Forwarded to ``_get_io_arg`` for the ``Pkg.activate`` output.
        julia_kwargs: Keyword arguments for ``julia.core.Julia``. Defaults to
            ``{"optimize": 3}``.
        return_aux: If True, also return ``{"compiled_modules": ...}``,
            indicating whether precompiled modules are in use.

    Returns:
        ``julia.Main``, or ``(julia.Main, aux)`` when ``return_aux`` is True.

    Raises:
        FileNotFoundError: If the ``julia`` executable is not on PATH.
        ImportError: If PyCall.jl has not been built.
    """
    global julia_initialized
    global julia_kwargs_at_initialization
    global julia_activated_env

    if not julia_initialized:
        _check_for_conflicting_libraries()

    if julia_kwargs is None:
        julia_kwargs = {"optimize": 3}

    from julia.core import JuliaInfo, UnsupportedPythonError

    _julia_version_assertion()
    processed_julia_project, is_shared = _process_julia_project(julia_project)
    _set_julia_project_env(processed_julia_project, is_shared)

    try:
        info = JuliaInfo.load(julia="julia")
    except FileNotFoundError as err:
        # Use .get so an unset PATH cannot mask the real error with a KeyError.
        env_path = os.environ.get("PATH", "")
        raise FileNotFoundError(
            f"Julia is not installed in your PATH. Please install Julia and add it to your PATH.\n\nCurrent PATH: {env_path}",
        ) from err

    if not info.is_pycall_built():
        raise ImportError(_import_error())

    from julia.core import Julia

    try:
        Julia(**julia_kwargs)
    except UnsupportedPythonError:
        # Static python binary, so we turn off pre-compiled modules.
        julia_kwargs = {**julia_kwargs, "compiled_modules": False}
        Julia(**julia_kwargs)
        warnings.warn(
            "Your system's Python library is static (e.g., conda), so precompilation will be turned off. For a dynamic library, try using `pyenv` and installing with `--enable-shared`: https://github.com/pyenv/pyenv/blob/master/plugins/python-build/README.md#building-with---enable-shared."
        )

    # Absent key means Julia's default, i.e. compiled modules enabled.
    using_compiled_modules = julia_kwargs.get("compiled_modules", True)

    from julia import Main

    if julia_activated_env is None:
        julia_activated_env = processed_julia_project

    if julia_initialized and julia_kwargs_at_initialization is not None:
        # Report options that were absent or different at first initialization.
        # Compare via dict lookups rather than set(items()) so unhashable option
        # values do not crash. `compiled_modules` is skipped because it may
        # have been set automatically above rather than by the user.
        ignored_kwargs = {
            key: value
            for key, value in julia_kwargs.items()
            if key != "compiled_modules"
            and (
                key not in julia_kwargs_at_initialization
                or julia_kwargs_at_initialization[key] != value
            )
        }
        if ignored_kwargs:
            warnings.warn(
                "Julia has already started. The new Julia options "
                + str(ignored_kwargs)
                + " will be ignored."
            )

    if julia_initialized and julia_activated_env != processed_julia_project:
        Main.eval("using Pkg")

        io_arg = _get_io_arg(quiet)
        # Can't pass IO to Julia call as it evaluates to PyObject, so just directly
        # use Main.eval:
        Main.eval(
            f'Pkg.activate("{_escape_filename(processed_julia_project)}",'
            f"shared = Bool({int(is_shared)}), "
            f"{io_arg})"
        )

        julia_activated_env = processed_julia_project

    if not julia_initialized:
        # Store a copy so later mutation of the caller's dict cannot alter
        # the recorded startup options.
        julia_kwargs_at_initialization = dict(julia_kwargs)

    julia_initialized = True
    if return_aux:
        return Main, {"compiled_modules": using_compiled_modules}
    return Main

install(julia_project=None, quiet=False, precompile=None)

Install PyCall.jl and all required dependencies for SymbolicRegression.jl.

Also updates the local Julia registry.

Source code in pysr/julia_helpers.py
def install(julia_project=None, quiet=False, precompile=None):  # pragma: no cover
    """
    Install PyCall.jl and all required dependencies for SymbolicRegression.jl.

    Also updates the local Julia registry.

    Args:
        julia_project: Julia project to install into. It is processed by
            ``_process_julia_project``, like in ``init_julia``.
        quiet: Suppress Pkg output and the final restart recommendation.
        precompile: ``None`` means precompile only if ``init_julia`` reports
            that compiled modules are in use. Any other falsy value disables
            automatic precompilation by setting ``JULIA_PKG_PRECOMPILE_AUTO=0``
            in this process's environment. Truthy values run
            ``Pkg.precompile``.
    """
    import julia

    _julia_version_assertion()
    # Set JULIA_PROJECT so that we install in the pysr environment
    processed_julia_project, is_shared = _process_julia_project(julia_project)
    _set_julia_project_env(processed_julia_project, is_shared)

    # Same truthiness rule as the `if not precompile` check below, but only
    # for an explicit (non-None) value; None is resolved after init_julia.
    if precompile is not None and not precompile:
        os.environ["JULIA_PKG_PRECOMPILE_AUTO"] = "0"

    try:
        julia.install(quiet=quiet)
    except julia.tools.PyCallInstallError:
        # Attempt to reset PyCall.jl's build against this interpreter. The path
        # goes through the PYTHON environment variable instead of being spliced
        # into Julia source, so backslashes or quotes in the path (e.g. Windows
        # paths like C:\Users\...) are not misread as Julia string escapes.
        subprocess.run(
            ["julia", "-e", 'import Pkg; Pkg.build("PyCall")'],
            env={**os.environ, "PYTHON": sys.executable},
            check=False,  # a failed rebuild resurfaces in the retry below
        )
        # Try installing again:
        julia.install(quiet=quiet)

    Main, init_log = init_julia(julia_project, quiet=quiet, return_aux=True)
    io_arg = _get_io_arg(quiet)

    if precompile is None:
        precompile = init_log["compiled_modules"]

    if not precompile:
        Main.eval('ENV["JULIA_PKG_PRECOMPILE_AUTO"] = 0')

    if is_shared:
        # Install SymbolicRegression.jl:
        _add_sr_to_julia_project(Main, io_arg)

    Main.eval("using Pkg")
    Main.eval(f"Pkg.instantiate({io_arg})")

    if precompile:
        Main.eval(f"Pkg.precompile({io_arg})")

    if not quiet:
        warnings.warn(
            "It is recommended to restart Python after installing PySR's dependencies,"
            " so that the Julia environment is properly initialized."
        )

Exporting to LaTeX

Functions to help export PySR equations to LaTeX.

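generate_table_environment(columns=['equation', 'complexity', 'loss'])

Return the opening and closing LaTeX (a booktabs-style `table`/`tabular` environment) for a table with the given columns.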

Source code in pysr/export_latex.py
def generate_table_environment(
    columns: List[str] = ["equation", "complexity", "loss"]
) -> Tuple[str, str]:
    r"""Return the opening and closing LaTeX for a booktabs-style table.

    The opening part starts the ``table``, ``center`` and ``tabular``
    environments, declares one centered column per entry of ``columns``, and
    writes the header row between ``\toprule`` and ``\midrule``. The closing
    part writes ``\bottomrule`` and ends the environments. Presumably the
    caller places the body rows between the two parts.

    Args:
        columns: Column keys, in display order. Each must be one of
            ``"equation"``, ``"complexity"``, ``"loss"`` or ``"score"``.
            The list default is only read, never mutated, so sharing it
            across calls is harmless.

    Returns:
        A ``(top, bottom)`` tuple of newline-joined LaTeX strings.

    Raises:
        KeyError: If any column key is not recognized.
    """
    column_map = {
        "complexity": "Complexity",
        "loss": "Loss",
        "equation": "Equation",
        "score": "Score",
    }
    # Validate up front so the error names every bad key and the valid choices,
    # instead of a bare KeyError for the first unknown one.
    unknown = [col for col in columns if col not in column_map]
    if unknown:
        raise KeyError(
            f"Unknown column(s) {unknown}; valid columns are {sorted(column_map)}."
        )
    headers = [column_map[col] for col in columns]
    margins = "c" * len(headers)
    top_pieces = [
        r"\begin{table}[h]",
        r"\begin{center}",
        r"\begin{tabular}{@{}" + margins + r"@{}}",
        r"\toprule",
        " & ".join(headers) + r" \\",
        r"\midrule",
    ]

    bottom_pieces = [
        r"\bottomrule",
        r"\end{tabular}",
        r"\end{center}",
        r"\end{table}",
    ]
    top_latex_table = "\n".join(top_pieces)
    bottom_latex_table = "\n".join(bottom_pieces)

    return top_latex_table, bottom_latex_table

Exporting to JAX

sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None)

Returns a function `f` and its parameters. The function takes an input matrix and an array of parameters, `f(X, parameters)`, where the parameters are the constants that appear in the JAX equation.

Examples:

Let's create a function in SymPy:
```python
x, y = symbols('x y')
cosx = 1.0 * sympy.cos(x) + 3.2 * y
```
Let's get the JAX version. We pass the equation and
the symbols it requires.
```python
f, params = sympy2jax(cosx, [x, y])
```
The order in which you supply the symbols is the order in which
you should supply the features when calling
the function `f` (shape `[nrows, nfeatures]`).
In this case, nfeatures=2, for x and y.
The `params` in this case will be
`jnp.array([1.0, 3.2])`. You pass these parameters
when calling the function, which will let you change them
and take gradients.

Let's generate some JAX data to pass:
```python
key = random.PRNGKey(0)
X = random.normal(key, (10, 2))
```

We can call the function with:
```python
f(X, params)

#> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
#                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
#                3.5427954 , -2.7479894 ], dtype=float32)
```

Now we can use JAX to take gradients with respect
to the parameters, for each row:
```python
jac_f = jax.jacobian(f, argnums=1)
jac_f(X, params)

#> DeviceArray([[ 0.49364874, -0.9692889 ],
#               [ 0.8283714 , -0.0318858 ],
#               [-0.7447336 , -1.8784496 ],
#               [ 0.70755106, -0.3137085 ],
#               [ 0.944834  ,  1.767703  ],
#               [ 0.51673377,  1.4111717 ],
#               [ 0.87347716, -0.52637756],
#               [ 0.8760679 ,  1.0549792 ],
#               [ 0.9961824 ,  0.79581654],
#               [-0.88465923, -0.5822907 ]], dtype=float32)
```

We can also JIT-compile our function:
```python
compiled_f = jax.jit(f)
compiled_f(X, params)

#> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
#                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
#                3.5427954 , -2.7479894 ], dtype=float32)
```
Source code in pysr/export_jax.py
def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
    """Convert a SymPy expression into a JAX function plus its parameters.

    Returns a pair ``(f, params)``. ``f`` has the signature
    ``f(X, parameters)``: ``X`` is an input matrix whose columns follow the
    order of ``symbols_in``, and ``parameters`` holds the floating-point
    constants of the expression. ``params`` is a ``jnp.array`` of their
    initial values.

    Args:
        expression: SymPy expression to convert.
        symbols_in: Symbols used by ``expression``. Symbol ``i`` is read from
            column ``i`` of ``X``.
        selection: Optional column indices. When given, ``X`` is first
            restricted to ``X[:, selection]``.
        extra_jax_mappings: Optional extra mapping from SymPy functions to
            JAX function names as strings, e.g. ``{sympy.sqrt: "jnp.sqrt"}``.

    Examples:

        Let's create a function in SymPy:
        ```python
        x, y = symbols('x y')
        cosx = 1.0 * sympy.cos(x) + 3.2 * y
        ```
        Let's get the JAX version. We pass the equation and
        the symbols it requires.
        ```python
        f, params = sympy2jax(cosx, [x, y])
        ```
        The order in which you supply the symbols is the order in
        which you should supply the features when calling
        the function `f` (shape `[nrows, nfeatures]`).
        In this case, nfeatures=2, for x and y.
        The `params` in this case will be
        `jnp.array([1.0, 3.2])`. You pass these parameters
        when calling the function, which will let you change them
        and take gradients.

        Let's generate some JAX data to pass:
        ```python
        key = random.PRNGKey(0)
        X = random.normal(key, (10, 2))
        ```

        We can call the function with:
        ```python
        f(X, params)

        #> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
        #                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
        #                3.5427954 , -2.7479894 ], dtype=float32)
        ```

        Now we can use JAX to take gradients with respect
        to the parameters, for each row:
        ```python
        jac_f = jax.jacobian(f, argnums=1)
        jac_f(X, params)

        #> DeviceArray([[ 0.49364874, -0.9692889 ],
        #               [ 0.8283714 , -0.0318858 ],
        #               [-0.7447336 , -1.8784496 ],
        #               [ 0.70755106, -0.3137085 ],
        #               [ 0.944834  ,  1.767703  ],
        #               [ 0.51673377,  1.4111717 ],
        #               [ 0.87347716, -0.52637756],
        #               [ 0.8760679 ,  1.0549792 ],
        #               [ 0.9961824 ,  0.79581654],
        #               [-0.88465923, -0.5822907 ]], dtype=float32)
        ```

        We can also JIT-compile our function:
        ```python
        compiled_f = jax.jit(f)
        compiled_f(X, params)

        #> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
        #                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
        #                3.5427954 , -2.7479894 ], dtype=float32)
        ```
    """
    # NOTE(review): `jnp` appears to be bound at module level by
    # `_initialize_jax()`; it is only read here.
    global jnp

    _initialize_jax()

    # Float constants are collected here by sympy2jaxtext.
    param_values = []
    body_expr = sympy2jaxtext(expression, param_values, symbols_in, extra_jax_mappings)

    # Deterministic-within-process name derived from the expression and symbols.
    func_name = f"A_{abs(hash(str(expression) + str(symbols_in)))}"

    source_lines = [f"def {func_name}(X, parameters):"]
    if selection is not None:
        # Restrict the input to the selected feature columns first.
        source_lines.append(f"    X = X[:, {list(selection)}]")
    source_lines.append("    return " + body_expr)

    # Compile against this module's globals so the generated code can reach jnp etc.
    namespace = {}
    exec("\n".join(source_lines), globals(), namespace)
    return namespace[func_name], jnp.array(param_values)

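sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None)

Recursively translate a SymPy expression into JAX source code, appending its floating-point constants to `parameters`.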

Source code in pysr/export_jax.py
def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
    """Recursively translate a SymPy expression into JAX source text.

    Float constants are appended to ``parameters`` and emitted as
    ``parameters[i]``, so they can be trained. Integers and other rationals
    are inlined as literals. Symbols become ``X[:, i]``, where ``i`` is their
    position in ``symbols_in``. Other nodes are looked up in
    ``_jnp_func_lookup``, overridden by ``extra_jax_mappings``.

    Args:
        expr: SymPy expression (sub)tree to translate.
        parameters: List collecting float constants. It is mutated in place.
        symbols_in: Sequence of SymPy symbols that defines the column order
            of ``X``.
        extra_jax_mappings: Optional mapping from SymPy function classes to
            JAX function names (strings). It takes precedence over the
            built-in table.

    Returns:
        A string of Python/JAX source code.

    Raises:
        IndexError: If a symbol in ``expr`` is not in ``symbols_in``.
        KeyError: If a node's function has no JAX mapping.
    """
    if issubclass(expr.func, sympy.Float):
        parameters.append(float(expr))
        return f"parameters[{len(parameters) - 1}]"
    # sympy.Integer is a subclass of sympy.Rational, so it must be checked
    # first; otherwise integers would be emitted as floats (e.g. "2.0").
    if issubclass(expr.func, sympy.Integer):
        return f"{int(expr)}"
    if issubclass(expr.func, sympy.Rational):
        return f"{float(expr)}"
    if issubclass(expr.func, sympy.Symbol):
        column = next((i for i, sym in enumerate(symbols_in) if sym == expr), None)
        if column is None:
            # IndexError kept for compatibility with the previous list lookup.
            raise IndexError(
                f"Symbol {expr} was not found in symbols_in={list(symbols_in)}."
            )
        return f"X[:, {column}]"
    if extra_jax_mappings is None:
        extra_jax_mappings = {}
    try:
        _func = {**_jnp_func_lookup, **extra_jax_mappings}[expr.func]
    except KeyError:
        raise KeyError(
            f"Function {expr.func} was not found in JAX function mappings. "
            "Please add it to extra_jax_mappings in the format, e.g., "
            "{sympy.sqrt: 'jnp.sqrt'}."
        ) from None
    args = [
        sympy2jaxtext(
            arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
        )
        for arg in expr.args
    ]
    # Multiplication and addition are written as infix operators, with each
    # operand parenthesized to preserve precedence.
    if _func == MUL:
        return " * ".join(f"({arg})" for arg in args)
    if _func == ADD:
        return " + ".join(f"({arg})" for arg in args)
    return f'{_func}({", ".join(args)})'

Exporting to PyTorch

sympy2torch(expression, symbols_in, selection=None, extra_torch_mappings=None)

Returns a PyTorch module for a given SymPy expression, with trainable parameters.

This function assumes the input to the module is a matrix `X`, where each column corresponds to one of the symbols you pass in `symbols_in`.

Source code in pysr/export_torch.py
def sympy2torch(expression, symbols_in, selection=None, extra_torch_mappings=None):
    """Build a PyTorch module for a SymPy expression, with trainable parameters.

    The returned module assumes its input is a matrix ``X``, where each
    column corresponds to one of the symbols passed in ``symbols_in``.

    Args:
        expression: SymPy expression to convert.
        symbols_in: Symbols that define the columns of ``X``.
        selection: Optional feature selection, forwarded to
            ``SingleSymPyModule``.
        extra_torch_mappings: Optional extra SymPy-to-PyTorch mappings,
            forwarded as ``extra_funcs``.

    Returns:
        A ``SingleSymPyModule`` instance.
    """
    # NOTE(review): `SingleSymPyModule` appears to be bound at module level by
    # `_initialize_torch()`, so it is only looked up after that call.
    global SingleSymPyModule

    _initialize_torch()

    module = SingleSymPyModule(
        expression,
        symbols_in,
        selection=selection,
        extra_funcs=extra_torch_mappings,
    )
    return module