import base64
import importlib.resources
import io
from itertools import starmap
from pathlib import Path
from typing import Literal
import human_readable
import jinja2
import nested_pandas as npd
import numpy as np
import pandas as pd
from upath import UPath
from hats.catalog import CollectionProperties
from hats.catalog.catalog_collection import CatalogCollection
from hats.catalog.dataset import Dataset
from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset
from hats.catalog.index.index_catalog import IndexCatalog
from hats.catalog.margin_cache.margin_catalog import MarginCatalog
from hats.io import get_common_metadata_pointer, get_partition_info_pointer, templates
from hats.io.file_io import get_upath, read_parquet_file_to_pandas
from hats.io.paths import get_data_thumbnail_pointer
from hats.loaders.read_hats import read_hats
def _cone_code_example(column_table: pd.DataFrame, cat_props) -> dict | None:
if "example" not in column_table:
return None
ra = np.round(float(column_table.loc[cat_props.ra_column]["example"]))
if ra >= 360.0:
ra -= 360.0
dec = np.round(float(column_table.loc[cat_props.dec_column]["example"]))
if dec >= 90.0:
dec = 89.9
if dec <= -90.0:
dec = -89.9
return {"ra": ra, "dec": dec}
def _gen_metadata_table(catalog: Dataset, total_columns: int | None) -> dict[str, object]:
    """Collect the headline facts about a catalog as display strings.

    Parameters
    ----------
    catalog : Dataset
        Catalog whose ``catalog_info`` is summarised.
    total_columns : int | None
        Number of columns in the stored schema, or ``None`` to omit the
        column-count entry. One column is subtracted when the catalog
        declares a healpix column.

    Returns
    -------
    dict[str, object]
        Mapping from row label to formatted value, in display order. Entries
        whose source value is unavailable are left out.
    """
    info = catalog.catalog_info
    table = {}

    if info.total_rows is not None:
        table["Number of rows"] = f"{info.total_rows:,}"

    if total_columns is not None:
        shown = f"{total_columns - int(info.healpix_column is not None):,}"
        if info.default_columns is None:
            table["Number of columns"] = shown
        else:
            table["Number of columns (default columns)"] = f"{shown} ({len(info.default_columns):,})"

    if isinstance(catalog, HealpixDataset):
        table["Number of partitions"] = f"{len(catalog.get_healpix_pixels()):,}"

    extras = info.extra_dict()
    estsize_kb = extras.get("hats_estsize")
    if estsize_kb is not None:
        # The property is multiplied by 1024 before being handed to file_size.
        table["Size on disk"] = human_readable.file_size(int(estsize_kb) * 1024, binary=True)
    builder = extras.get("hats_builder")
    if builder is not None:
        table["HATS Builder"] = builder
    return table
def _fmt_count_percent(n: int, total: int) -> str:
    """Format a count followed by its percentage of ``total``.

    A zero count is rendered as ``"0"`` with no percentage. Shares that
    round below 0.01% are shown as ``<0.01%``.
    """
    if n == 0:
        return "0"
    percent = round(n / total * 100, 2)
    shown = "<0.01" if percent < 0.01 else f"{percent}"
    return f"{n:,} ({shown}%)"
def _hard_truncate(s: str, limit: int) -> str:
    """Cut ``s`` down to ``limit`` characters, ending with ``…`` when shortened."""
    fits = len(s) <= limit
    return s if fits else f"{s[: limit - 1]}…"
def _format_example_value(
    value, *, float_precision: int = 4, soft_limit: int = 50, hard_limit: int = 70
) -> str:
    """Format an example value for display in a summary table.

    Missing values (``None`` and ``pd.NA``) are shown as ``*NULL*``.
    Floats are rounded to ``float_precision`` significant figures, with
    ``*NaN*`` and ``∞`` / ``-∞`` for the non-finite values.
    Lists are shown with as many items as fit within ``soft_limit``
    characters (always at least one), with a ``(N total)`` suffix when
    truncated. Any resulting string longer than ``hard_limit`` is
    truncated with ``…``.

    Parameters
    ----------
    value
        The value to format.
    float_precision : int
        Number of significant figures kept for floats.
    soft_limit : int
        Target maximum length of a list preview.
    hard_limit : int
        Absolute maximum length of list and generic string results.

    Returns
    -------
    str
        The display string.
    """
    # pd.NA would otherwise fall through to str() and render as "<NA>",
    # which is inconsistent with how None is shown.
    if value is None or value is pd.NA:
        return "*NULL*"
    if isinstance(value, (float, np.floating)):
        if np.isnan(value):
            return "*NaN*"
        if np.isinf(value):
            return "-∞" if value < 0 else "∞"
        return f"{value:.{float_precision}g}"
    if isinstance(value, (list, tuple, np.ndarray)):
        items = list(value)
        if not items:
            return "[]"
        fmt_kwargs = {"float_precision": float_precision, "soft_limit": soft_limit, "hard_limit": hard_limit}
        suffix = f", … ({len(items)} total)]"
        # The first item is always shown, even if it alone exceeds soft_limit.
        parts = [_format_example_value(items[0], **fmt_kwargs)]
        for item in items[1:]:
            candidate = _format_example_value(item, **fmt_kwargs)
            # Measure with the truncation suffix attached so the preview
            # still fits if we have to stop after this item.
            preview = "[" + ", ".join(parts + [candidate]) + suffix
            if len(preview) > soft_limit:
                break
            parts.append(candidate)
        closing = suffix if len(parts) < len(items) else "]"
        result = "[" + ", ".join(parts) + closing
    else:
        result = str(value)
    return _hard_truncate(result, hard_limit)
def _build_column_table(
    nf: npd.NestedFrame, default_columns, fmt_value=_format_example_value
) -> pd.DataFrame:
    """Describe every column of a NestedFrame, one output row per column.

    Nested columns are expanded into one row per sub-column. The result is
    indexed by column name and always has a ``dtype`` column. ``default``
    is added when any default columns are given, ``nested_into`` when the
    frame has nested columns, and ``example`` (formatted with ``fmt_value``
    from the first row) when the frame is not empty.
    """
    default_set = frozenset(default_columns or [])
    frame_is_empty = nf.empty
    track_default = len(default_set) > 0
    track_nested = len(nf.nested_columns) > 0
    track_example = not frame_is_empty

    names, dtypes, defaults, parents, examples = [], [], [], [], []
    for name, dt in nf.dtypes.items():
        first_cell = None if frame_is_empty else nf[name].iloc[0]

        if not isinstance(dt, npd.NestedDtype):
            names.append(name)
            dtypes.append(str(dt.pyarrow_dtype))
            defaults.append(name in default_set)
            parents.append(None)
            if track_example:
                examples.append(fmt_value(first_cell))
            continue

        subcolumns = nf.get_subcolumns(name)
        names.extend(subcolumns)
        dtypes.extend(f"list[{nf[sc].dtype.pyarrow_dtype}]" for sc in subcolumns)
        # A sub-column counts as default if it, or its parent, is listed.
        defaults.extend(name in default_set or sc in default_set for sc in subcolumns)
        parents.extend([name] * len(subcolumns))
        if track_example:
            if first_cell is None:
                examples.extend(fmt_value(None) for _ in subcolumns)
            else:
                examples.extend(fmt_value(series.to_list()) for _, series in first_cell.items())

    index = pd.Index(names, name="column")
    table = pd.DataFrame({"dtype": pd.Series(dtypes, dtype=str, index=index)}, index=index)
    optional_columns = (
        ("default", track_default, defaults, bool),
        ("nested_into", track_nested, parents, str),
        ("example", track_example, examples, object),
    )
    for label, enabled, values, kind in optional_columns:
        if enabled:
            table[label] = pd.Series(values, dtype=kind, index=index)
    return table
def _gen_column_table(
    catalog: Dataset, empty_nf: npd.NestedFrame | None, fmt_value=_format_example_value
) -> pd.DataFrame:
    """Build the column table for a catalog, enriched with column statistics.

    The table is built from an example row when one can be read, otherwise
    from ``empty_nf``; with neither, an empty DataFrame is returned. When
    aggregated statistics are available, ``min_value`` and ``max_value``
    are added, plus ``rows`` if any column's row count differs from the
    catalog total and ``nulls`` if any column has nulls. Columns without
    statistics get ``*N/A*`` in those cells.
    """
    props = catalog.catalog_info
    frame = _get_example_row(catalog)
    if frame is None:
        frame = empty_nf
    if frame is None:
        return pd.DataFrame()

    table = _build_column_table(frame, props.default_columns, fmt_value)
    stats = catalog.aggregate_column_statistics(exclude_hats_columns=False)
    if stats.empty:
        return table

    without_stats = list(set(table.index) - set(stats.index))

    def _with_placeholders(series):
        for col in without_stats:
            series.loc[col] = "*N/A*"
        return series

    def _thousands(n):
        return f"{n:,}"

    for field in ("min_value", "max_value"):
        table[field] = _with_placeholders(stats[field].map(fmt_value))

    counts = stats["row_count"]
    if np.any(counts != props.total_rows):
        table["rows"] = _with_placeholders(counts.map(_thousands))

    null_counts = stats["null_count"]
    if null_counts.sum() > 0:
        formatted = [_fmt_count_percent(n, total) for n, total in zip(null_counts, counts)]
        table["nulls"] = _with_placeholders(pd.Series(formatted, dtype=str, index=stats.index))
    return table
def _join_catalog_uri(col_upath: str | None, path: str) -> str:
if col_upath is None:
return path
try:
upath = UPath(path)
except ValueError:
return path
if upath.protocol:
return path
try:
return str(UPath(col_upath) / path)
except ValueError:
return path
def _catalog_uris(properties: CollectionProperties, uri: str | None) -> dict[str, object]:
    """Gather names and URIs of a collection's primary, margin and index catalogs.

    The default margin and the default index, when set, are swapped into
    first position of their lists. URIs are resolved against ``uri``; the
    collection entry falls back to the ``<PATH>`` placeholder when ``uri``
    is empty.
    """

    def _swap_to_front(items, preferred):
        # Exchange positions with the first element; other items stay put.
        if preferred is not None:
            pos = items.index(preferred)
            items[0], items[pos] = items[pos], items[0]
        return items

    margins = _swap_to_front((properties.all_margins or []).copy(), properties.default_margin)
    index_columns = _swap_to_front(list(properties.all_indexes or {}), properties.default_index)

    primary_name = properties.hats_primary_table_url
    uris = {
        "collection": uri or "<PATH>",
        "primary": {"name": primary_name, "uri": _join_catalog_uri(uri, primary_name)},
    }
    uris["margins"] = [{"name": margin, "uri": _join_catalog_uri(uri, margin)} for margin in margins]
    uris["indexes"] = [
        {
            "column": column,
            "name": properties.all_indexes[column],
            "uri": _join_catalog_uri(uri, properties.all_indexes[column]),
        }
        for column in index_columns
    ]
    return uris
def _get_example_frame(catalog: Dataset, rng: np.random.Generator) -> npd.NestedFrame | None:
    """Load a frame of catalog rows from which an example row can be drawn.

    The data thumbnail file is used when it exists. Otherwise, for a
    ``HealpixDataset``, one pixel chosen with ``rng`` is read. ``None`` is
    returned when the catalog has no existing path or is neither case.
    """
    root = catalog.catalog_path
    if root is None or not root.exists():
        return None

    thumbnail_path = get_data_thumbnail_pointer(root)
    if thumbnail_path.exists():
        return read_parquet_file_to_pandas(thumbnail_path, is_dir=False)

    if isinstance(catalog, HealpixDataset):
        chosen_pixel = rng.choice(catalog.get_healpix_pixels())
        return catalog.read_pixel_to_pandas(chosen_pixel)
    return None
def _get_example_row(catalog: HealpixDataset) -> npd.NestedFrame | None:
    """Return a single-row nested frame with a random example row.

    The random generator is seeded with a fixed value, so the same catalog
    data yields the same example on every call.

    Parameters
    ----------
    catalog : HealpixDataset
        Catalog to sample; it is passed straight to ``_get_example_frame``.

    Returns
    -------
    npd.NestedFrame | None
        A one-row slice of the example frame, or ``None`` when no example
        frame is available or it contains no rows.
    """
    random_seed = 42
    rng = np.random.Generator(np.random.PCG64(random_seed))
    example_nf = _get_example_frame(catalog, rng)
    # A zero-row frame has nothing to sample, and rng.integers(0) would raise.
    if example_nf is None or len(example_nf) == 0:
        return None
    idx = rng.integers(len(example_nf))
    # Slice rather than index so the result is still a frame.
    return example_nf.iloc[idx : idx + 1]
# pylint: disable=import-outside-toplevel,import-error
def _generate_sky_coverage_images(catalog, name: str | None = None):
    """Render the pixel map and density map of a catalog as base64 WebP strings.

    Plot titles are only set when ``name`` is given. Returns the pair
    ``(pixel_map_b64, density_map_b64)``.
    """
    from matplotlib.colors import LogNorm

    from hats.inspection.visualize_catalog import plot_density

    if name:
        pixel_title = f"Catalog pixel map - {name}"
        density_title = f"Angular density of catalog {name}"
    else:
        pixel_title = density_title = None

    pixel_fig, _ = catalog.plot_pixels(plot_title=pixel_title)
    pixel_map_b64 = _fig_to_webp_base64(pixel_fig)

    density_fig, _ = plot_density(catalog, norm=LogNorm(), edgecolors="face", plot_title=density_title)
    return pixel_map_b64, _fig_to_webp_base64(density_fig)
# pylint: disable=import-outside-toplevel,import-error
def _fig_to_webp_base64(fig) -> str:
    """Encode a matplotlib figure as a base64 WebP string and close the figure.

    Parameters
    ----------
    fig
        Figure to render; it is closed before returning, including when
        rendering fails.

    Returns
    -------
    str
        ASCII base64 text of the WebP image bytes.
    """
    import matplotlib.pyplot as plt

    buffer = io.BytesIO()
    try:
        fig.savefig(buffer, format="webp", bbox_inches="tight")
    finally:
        # Close even if saving raised, so a failed render does not leave
        # the figure open.
        plt.close(fig)
    return base64.b64encode(buffer.getvalue()).decode("ascii")
# pylint: disable=import-outside-toplevel,import-error
def write_skymap_png(catalog_path: str | Path | UPath) -> None:
    """Write a ``skymap.png`` pixel coverage map to the catalog directory.

    For a catalog collection, the map is drawn from its main catalog.

    Parameters
    ----------
    catalog_path : str | Path | UPath
        Path to the catalog directory. The PNG will be written alongside
        the catalog's other files.
    """
    import matplotlib.pyplot as plt

    catalog_path = get_upath(catalog_path)
    catalog = read_hats(catalog_path)
    inner = catalog.main_catalog if isinstance(catalog, CatalogCollection) else catalog
    fig, _ = inner.plot_pixels()
    try:
        with (catalog_path / "skymap.png").open("wb") as f:
            fig.savefig(f, format="png", bbox_inches="tight")
    finally:
        # Close the figure even when opening or writing the file fails.
        plt.close(fig)
# pylint: disable=import-outside-toplevel,import-error
def write_partition_info_png(catalog_path: str | Path | UPath) -> None:
    """Write a ``partition_info.png`` angular density map to the catalog directory.

    For a catalog collection, the map is drawn from its main catalog.

    Parameters
    ----------
    catalog_path : str | Path | UPath
        Path to the catalog directory. The PNG will be written alongside
        the catalog's other files.
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import LogNorm

    from hats.inspection.visualize_catalog import plot_density

    catalog_path = get_upath(catalog_path)
    catalog = read_hats(catalog_path)
    inner = catalog.main_catalog if isinstance(catalog, CatalogCollection) else catalog
    fig, _ = plot_density(inner, norm=LogNorm(), edgecolors="face")
    try:
        with (catalog_path / "partition_info.png").open("wb") as f:
            fig.savefig(f, format="png", bbox_inches="tight")
    finally:
        # Close the figure even when opening or writing the file fails.
        plt.close(fig)
[docs]
def generate_summary(
catalog,
*,
fmt: Literal["markdown", "html"],
name: str,
description: str,
uri: str | None,
huggingface_metadata: bool,
jinja2_template: str | None = None,
) -> str:
"""Generate summary content for any HATS catalog or collection."""
if isinstance(catalog, CatalogCollection):
md_tmpl, html_tmpl = "collection_md_template.jinja2", "collection_html_template.jinja2"
elif isinstance(catalog, MarginCatalog):
md_tmpl, html_tmpl = "margin_md_template.jinja2", "margin_html_template.jinja2"
elif isinstance(catalog, IndexCatalog):
md_tmpl, html_tmpl = "index_md_template.jinja2", "index_html_template.jinja2"
else:
md_tmpl, html_tmpl = "catalog_md_template.jinja2", "catalog_html_template.jinja2"
env = jinja2.Environment(undefined=jinja2.StrictUndefined)
match fmt:
case "markdown":
tmpl_str = jinja2_template or importlib.resources.read_text(templates, md_tmpl)
case "html":
tmpl_str = jinja2_template or importlib.resources.read_text(templates, html_tmpl)
case _:
raise ValueError(f"Unsupported format: {fmt!r}. Expected 'markdown' or 'html'.")
template = env.from_string(tmpl_str)
is_collection = isinstance(catalog, CatalogCollection)
inner = catalog.main_catalog if is_collection else catalog
cat_props = inner.catalog_info
catalog_path = catalog.main_catalog_dir if is_collection else catalog.catalog_path
empty_nf = (
read_parquet_file_to_pandas(p) if (p := get_common_metadata_pointer(catalog_path)).exists() else None
)
column_table = _gen_column_table(inner, empty_nf)
col_props = catalog.collection_properties if is_collection else None
needs_sky = not isinstance(catalog, (MarginCatalog, IndexCatalog))
has_default_columns = bool(cat_props.default_columns) if needs_sky else None
cone_code_example = _cone_code_example(column_table, cat_props) if needs_sky else None
pixel_map_b64, density_map_b64 = None, None
if needs_sky:
try:
pixel_map_b64, density_map_b64 = _generate_sky_coverage_images(inner, name)
except ImportError:
pass
return template.render(
name=name,
description=description,
cat_props=cat_props,
uri=uri,
has_partition_info=get_partition_info_pointer(catalog_path).exists(),
huggingface_metadata=huggingface_metadata,
metadata_table=(
None
if isinstance(catalog, IndexCatalog)
else _gen_metadata_table(inner, total_columns=None if empty_nf is None else empty_nf.shape[1])
),
column_table=pd.DataFrame() if isinstance(catalog, IndexCatalog) else column_table,
catalog_dir_name=None if is_collection else catalog.catalog_path.name,
has_default_columns=has_default_columns,
cone_code_example=cone_code_example,
pixel_map_b64=pixel_map_b64,
density_map_b64=density_map_b64,
col_props=col_props,
uris=_catalog_uris(col_props, uri) if is_collection else None,
margin_thresholds=catalog.get_margin_thresholds() if is_collection else None,
)
[docs]
def write_catalog_summary_file(
catalog_path: str | Path | UPath,
*,
fmt: Literal["markdown", "html"],
filename: str | None = None,
output_dir: str | Path | UPath | None = None,
name: str | None = None,
description: str | None = None,
uri: str | None = None,
huggingface_metadata: bool = False,
jinja2_template: str | None = None,
) -> UPath:
"""Write a summary readme file for any HATS catalog or collection"""
from hats.catalog.catalog import Catalog
catalog_path = get_upath(catalog_path)
if fmt != "markdown" and huggingface_metadata:
raise ValueError("`huggingface_metadata=True` is supported only for `fmt='markdown'`")
match fmt:
case "markdown":
filename = filename or "README.md"
case "html":
filename = filename or "index.html"
case _:
raise ValueError(f"Unsupported format: {fmt!r}. Expected 'markdown' or 'html'.")
catalog = read_hats(catalog_path)
if isinstance(catalog, CatalogCollection):
name = name or catalog.collection_properties.name
description = description or f"This is the collection of HATS catalogs representing {name}."
elif isinstance(catalog, MarginCatalog):
name = name or catalog.catalog_info.catalog_name
description = description or f"This is the margin catalog for {name.removesuffix('_margin')}."
elif isinstance(catalog, IndexCatalog):
name = name or catalog.catalog_info.catalog_name
indexing_column = catalog.catalog_info.indexing_column
primary = catalog.catalog_info.primary_catalog
description = description or (
f"This index maps {indexing_column} values to partitions in the "
f"{primary} catalog for non-spatial lookups."
)
elif isinstance(catalog, Catalog):
name = name or catalog.catalog_info.catalog_name
description = description or f"This is the HATS catalog for {name}."
content = generate_summary(
catalog,
fmt=fmt,
name=name,
description=description,
uri=uri,
huggingface_metadata=huggingface_metadata,
jinja2_template=jinja2_template,
)
output_dir = catalog_path if output_dir is None else get_upath(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / filename
with output_path.open("w") as f:
f.write(content)
return output_path