phyloframe tree

This example shows how to plot trees from the phyloframe library, using Seaborn to overlay scatter points.

Phyloframe represents phylogenies as DataFrames in the alife standard format (i.e., as an edge list in which each row names its ancestor). Such DataFrames can be passed to iplotx.tree() after conversion with phyloframe.legacy.alifestd_to_iplotx_pandas or phyloframe.legacy.alifestd_to_iplotx_polars.

import matplotlib.pyplot as plt
from phyloframe import legacy as pfl
import pandas as pd
import polars as pl
import seaborn as sns
import iplotx as ipx


def draw_scatter_tree(
    phylogeny_df,
    *,
    hue=None,
    size=None,
    style=None,
    c=None,
    ax=None,
    layout="vertical",
    scatter_kws=None,
    tree_kws=None,
):
    """Draw a phylogenetic tree with a seaborn scatter overlay.

    The tree is drawn with :func:`iplotx.tree`; the node coordinates it
    computes are then attached to the phylogeny table so that
    :func:`seaborn.scatterplot` can place one point per row on top of it.

    Adapted from `hstrat-synthesis
    <https://github.com/mmore500/hstrat-synthesis/blob/main/pylib/tree/_draw_scatter_tree.py>`_.

    Parameters
    ----------
    phylogeny_df : pandas.DataFrame or polars.DataFrame
        Alife-standard phylogeny with optional metadata columns.  Must
        contain an ``id`` column.  Any object exposing ``to_pandas()`` is
        converted to pandas before plotting.
    hue, size, style : str, optional
        Column names forwarded to :func:`seaborn.scatterplot`.
    c : str, sequence, or None
        Colour values.  A string naming a column of *phylogeny_df* selects
        that column (missing entries become ``"none"``); any other string
        or an explicit sequence is forwarded unchanged.  *None* is
        replaced by ``"none"``.
    ax : matplotlib.axes.Axes, optional
        Target axes.  The current axes (``plt.gca()``) are used if *None*.
    layout : str
        Tree layout forwarded to :func:`iplotx.tree`.
    scatter_kws : dict, optional
        Extra keyword arguments for :func:`seaborn.scatterplot`; these
        override the defaults ``legend=False`` and ``zorder=5``.
    tree_kws : dict, optional
        Extra keyword arguments for :func:`iplotx.tree`; these override
        the defaults ``margins=0.0`` and ``edge_linewidth=1.5``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the tree and scatter points were drawn on.
    """
    if ax is None:
        ax = plt.gca()
    if scatter_kws is None:
        scatter_kws = {}
    if tree_kws is None:
        tree_kws = {}

    # Look the converter up first instead of wrapping the call in
    # ``try/except AttributeError`` so that an AttributeError raised
    # *inside* the conversion is not silently swallowed.
    to_pandas = getattr(phylogeny_df, "to_pandas", None)
    if to_pandas is not None:
        phylogeny_df = to_pandas()

    tree_artist = ipx.tree(
        pfl.alifestd_to_iplotx_pandas(phylogeny_df),
        ax=ax,
        layout=layout,
        **{"margins": 0.0, "edge_linewidth": 1.5, **tree_kws},
    )

    # Extract node positions — radial layouts need Cartesian offsets
    ipx_layout = tree_artist.get_layout()
    if layout == "radial":
        xs, ys = tree_artist.get_nodes().get_offsets().T
    else:
        xs, ys = ipx_layout.T.to_numpy()

    # One id -> coordinate lookup per axis, keyed by each layout node's id.
    x_by_id = {}
    y_by_id = {}
    for node, x, y in zip(ipx_layout.index, xs, ys):
        x_by_id[node._id] = x
        y_by_id[node._id] = y

    # Map positions back onto the DataFrame
    node_ids = phylogeny_df["id"]
    plot_df = phylogeny_df.assign(
        __x__=node_ids.map(x_by_id).astype(float),
        __y__=node_ids.map(y_by_id).astype(float),
    )

    # Resolve colour argument.  Only strings that actually name a column are
    # treated as column references; other strings (e.g. "red" or "none") are
    # passed through as colour values rather than raising KeyError.
    if c is None:
        c = "none"
    elif isinstance(c, str) and c in plot_df.columns:
        c = plot_df[c].fillna("none").tolist()

    sns.scatterplot(
        plot_df,
        x="__x__",
        y="__y__",
        hue=hue,
        size=size,
        style=style,
        c=c,
        ax=ax,
        **{"legend": False, "zorder": 5, **scatter_kws},
    )
    return ax

Radial scatter tree

A small phylogeny with scatter points displayed using the "radial" layout.

# Nine-node phylogeny as an edge list; the root (id 0) is its own ancestor.
# Only some rows carry a label and a group, the rest are left as None.
vertebrate_df = pd.DataFrame(
    {
        "id": [0, 1, 2, 3, 4, 5, 6, 7, 8],
        "ancestor_id": [0, 0, 0, 1, 1, 2, 2, 6, 6],
        "origin_time": [0, 2, 3, 5, 5, 6, 6, 7, 7],
        "taxon_label": [
            None, None, None,
            "Salmon", "Seahorse", "Parrot",
            None,
            "Cat", "Mouse",
        ],
        "group": [
            None, None, None,
            "fish", "fish", "bird",
            None,
            "mammal", "mammal",
        ],
    }
)

fig, ax = plt.subplots(figsize=(5, 5))
draw_scatter_tree(
    # Prefix every label with a newline before handing the table over.
    vertebrate_df.assign(taxon_label=lambda df: "\n" + df["taxon_label"]),
    hue="group",
    style="group",
    ax=ax,
    layout="radial",
    scatter_kws={
        "edgecolor": "white",
        "linewidth": 0.8,
        "legend": "brief",
        "markers": ["o", "v", "^"],
        "palette": "Set2",
        "s": 220,
        "zorder": 1,
    },
    tree_kws={
        "aspect": 1,
        "leaf_labels": True,
        "margins": 0.17,
        # Base "tree" style plus overrides for the vertex and leaf labels.
        "style": [
            "tree",
            {
                "vertex": {"label": {"color": "gray"}},
                "leaf": {
                    "label": {
                        "verticalalignment": "baseline",
                        "hmargin": 12,
                    },
                },
            },
        ],
    },
)
fig.tight_layout()
plt.show()
plot phyloframe tree

Vertical scatter tree

A small phylogeny with scatter points displayed using the "vertical" layout.

# Per-species metadata, joined onto the phylogeny by taxon label below.
body_weights = {
    "species": ["Lion", "Tiger", "Cougar", "Cheetah", "Domestic Cat"],
    "body weight": [190.0, 220.0, 70.0, 50.0, 4.5],
}
felid_df = (
    pfl.alifestd_from_newick_polars(
        "((Lion:0.05,Tiger:0.05):0.03,(Cougar:0.06,Cheetah:0.06):0.02,'Domestic Cat':0.05);",
    )
    # The ``is_leaf`` column used next is expected to come from this step.
    .pipe(pfl.alifestd_mark_leaves_polars)
    # Flag leaves by whether they are the domestic cat; non-leaf rows stay
    # null.  The flag is cast to a string before being used as ``hue``.
    .with_columns(
        domesticated=pl.when(pl.col("is_leaf"))
        .then(pl.col("taxon_label") == "Domestic Cat")
        .otherwise(None)
        .cast(str)
    )
    # Left join keeps every phylogeny row; rows without a matching species
    # get a null "body weight".
    .join(
        pl.DataFrame(body_weights),
        left_on="taxon_label",
        right_on="species",
        how="left",
    )
)

fig2, ax2 = plt.subplots(figsize=(5, 5))
draw_scatter_tree(
    felid_df,
    hue="domesticated",
    size="body weight",
    ax=ax2,
    layout="vertical",
    scatter_kws=dict(
        edgecolor="none",
        linewidth=0.5,
        legend="brief",
        palette="RdBu",
        sizes=(60, 800),
        zorder=-100,
    ),
    tree_kws=dict(
        leaf_labels=True,
        margins=0.15,
    ),
)
# Move the legend outside the axes, to the right of the plot area.
sns.move_legend(ax2, "upper left", bbox_to_anchor=(1, 0.8), frameon=False)
fig2.tight_layout()
# Match the radial example: without this call the second figure is never
# displayed when the file is run as a plain script.
plt.show()
plot phyloframe tree

Total running time of the script: (0 minutes 2.675 seconds)

Gallery generated by Sphinx-Gallery