"""
Modules defining cross-correlated samples.
Has classes for both pure HOD cross-correlations
(i.e. number of cross-pairs) and for HaloModel-derived quantities
based on these cross-pairs.
To construct a :class:`CrossCorrelations` one need to specify the
halo models to be cross-correlated, and how they're correlated.
Examples
--------
Cross-correlating the same galaxy samples in different redshifts::
>>> from halomod import HaloModel
>>> from halomod.cross_correlations import CrossCorrelations, ConstantCorr
>>> cross = CrossCorrelations(cross_hod_model=ConstantCorr, halo_model_1_params=dict(z=1.0),
>>> halo_model_2_params=dict(z=0.0))
>>> pkcorr = cross.power_cross
`pkcorr` corresponds to the cross-power at `cross.halo_model_1.k_hm`.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import numpy as np
from hmf import Component, Framework
from hmf._internals._cache import cached_quantity, parameter, subframework
from hmf._internals._framework import get_mdl, pluggable
from scipy import integrate as intg
from scipy.interpolate import InterpolatedUnivariateSpline as _IUS
from . import tools
from .halo_model import TracerHaloModel
@pluggable
class _HODCross(ABC, Component):
    """Provides methods necessary to compute cross-correlation pairs for HOD models.

    Subclasses supply the intra-halo correlation coefficients (:meth:`R_ss`,
    :meth:`R_cs`, :meth:`R_sc`) and the zero-separation pair count
    (:meth:`self_pairs`). This base class combines them with the two HODs'
    mean occupations and scatters to give the mean number of cross-pairs per
    halo.

    Parameters
    ----------
    hods : sequence
        Exactly two HOD models, ``(hod_1, hod_2)``. Each must provide
        ``central_occupation``, ``satellite_occupation``, ``sigma_central``
        and ``sigma_satellite``.
    **model_params
        Model parameters, forwarded to the :class:`~hmf.Component` base.

    Raises
    ------
    ValueError
        If ``hods`` does not contain exactly two HOD models.
    """

    _defaults = {}

    def __init__(self, hods, **model_params):
        # Validate explicitly rather than with ``assert``, which is stripped
        # when Python runs with -O.
        if len(hods) != 2:
            raise ValueError(
                f"_HODCross requires exactly two HOD models, got {len(hods)}."
            )
        super().__init__(**model_params)
        self.hods = hods

    @abstractmethod
    def R_ss(self, m):
        r"""The cross-correlation of numbers of pairs within a halo.

        Notes
        -----
        Defined by

        .. math:: \langle T_1 T_2 \rangle = \langle T_1 \rangle \langle T_2 \rangle +
                  \sigma_1 \sigma_2 R_{ss},

        where :math:`T` is the total amount of tracer in the halo's profile (i.e. not
        counting the central component, if this exists).
        """

    @abstractmethod
    def R_cs(self, m):
        r"""The cross-correlation of central-satellite pairs within a halo.

        Central from the first HOD, satellite from the second.

        Notes
        -----
        Defined by

        .. math:: \langle T^c_1 T^s_2 \rangle =
                  \langle T^c_1 \rangle \langle T^s_2 \rangle +
                  \sigma^c_1 \sigma^s_2 R_{cs},

        where :math:`T^s` is the total amount of tracer in the halo's profile (i.e. not
        counting the central component, if this exists).
        """

    @abstractmethod
    def R_sc(self, m):
        r"""The cross-correlation of satellite-central pairs within a halo.

        Central from the second HOD, satellite from the first.

        Notes
        -----
        Defined by

        .. math:: \langle T^s_1 T^c_2 \rangle =
                  \langle T^s_1 \rangle \langle T^c_2 \rangle +
                  \sigma^s_1 \sigma^c_2 R_{sc},

        where :math:`T^s` is the total amount of tracer in the halo's profile (i.e. not
        counting the central component, if this exists).
        """

    @abstractmethod
    def self_pairs(self, m):
        r"""The expected number of cross-pairs at a separation of zero."""

    def ss_cross_pairs(self, m):
        r"""The mean number of satellite-satellite cross-pairs in a halo of mass m.

        Notes
        -----
        Given by

        .. math:: \langle T^s_1 T^s_2 \rangle - Q,

        where the first term is
        :math:`\langle T^s_1 \rangle \langle T^s_2 \rangle + \sigma^s_1 \sigma^s_2 R_{ss}`
        and :math:`Q` is :meth:`self_pairs`, the expected number of cross-pairs
        at zero separation.
        """
        h1, h2 = self.hods
        return (
            h1.satellite_occupation(m) * h2.satellite_occupation(m)
            + h1.sigma_satellite(m) * h2.sigma_satellite(m) * self.R_ss(m)
            - self.self_pairs(m)
        )

    def cs_cross_pairs(self, m):
        r"""The mean number of central-satellite cross-pairs in a halo of mass m.

        Central from the first HOD, satellite from the second.

        Notes
        -----
        .. math:: \langle T^c_1 T^s_2 \rangle.
        """
        h1, h2 = self.hods
        return h1.central_occupation(m) * h2.satellite_occupation(m) + h1.sigma_central(
            m
        ) * h2.sigma_satellite(m) * self.R_cs(m)

    def sc_cross_pairs(self, m):
        r"""The mean number of satellite-central cross-pairs in a halo of mass m.

        Satellite from the first HOD, central from the second.

        Notes
        -----
        .. math:: \langle T^s_1 T^c_2 \rangle.
        """
        h1, h2 = self.hods
        return h2.central_occupation(m) * h1.satellite_occupation(m) + h2.sigma_central(
            m
        ) * h1.sigma_satellite(m) * self.R_sc(m)
class ConstantCorr(_HODCross):
    """Correlation relation for constant cross-correlation pairs.

    Every correlation coefficient is a single mass-independent value taken
    from the model parameters (each defaults to ``0.0``). No cross-pairs are
    counted at zero separation.

    Parameters (via ``params``)
    ---------------------------
    R_ss, R_cs, R_sc : float
        Constant satellite-satellite, central-satellite and
        satellite-central correlation coefficients.
    """

    _defaults = {"R_ss": 0.0, "R_cs": 0.0, "R_sc": 0.0}

    def R_ss(self, m):
        """Constant satellite-satellite coefficient, ``params["R_ss"]`` (``m`` is ignored)."""
        return self.params["R_ss"]

    def R_cs(self, m):
        """Constant central-satellite coefficient, ``params["R_cs"]`` (``m`` is ignored)."""
        return self.params["R_cs"]

    def R_sc(self, m):
        """Constant satellite-central coefficient, ``params["R_sc"]`` (``m`` is ignored)."""
        return self.params["R_sc"]

    def self_pairs(self, m):
        """The expected number of cross-pairs at a separation of zero (always 0)."""
        return 0
def _tracer_bias_integral(hm):
    r"""Mass-integrate the bias- and profile-weighted occupation of one tracer.

    Evaluates

    .. math:: \int \frac{dn}{dm} \, b(m) \, N(m) \, u(k, m) \, dm
              = \int \frac{dn}{dm} \, b(m) \, N(m) \, u(k, m) \, m \, d\ln m

    on ``hm``'s own mass grid with ``hm``'s own halo bias, so the two tracer
    models being cross-correlated need not share a grid for this term.

    If ``hm.tracer_mmin`` is set and lies above the first grid mass, the
    integral starts at ``ln(tracer_mmin)`` and uses a spline integral. This is
    the same lower-bound convention as the auto-power 2-halo term, so the
    result varies smoothly with the minimum mass. Otherwise Simpson's rule is
    applied over the full grid.

    Parameters
    ----------
    hm : :class:`~halomod.halo_model.TracerHaloModel`
        The tracer halo model to integrate.

    Returns
    -------
    array
        The integral along the last (mass) axis of ``hm.tracer_profile_ukm``,
        presumably one value per ``hm.k``.
    """
    lnm = np.log(hm.m)
    # Include the Jacobian for log-mass integration: ∫ f(m) dm = ∫ f(m)·m d(ln m).
    integrand = (
        hm.dndm * hm.halo_bias * hm._total_occupation * hm.tracer_profile_ukm * hm.m
    )  # (k, m)

    xmin = hm.tracer_mmin
    if xmin is not None and xmin > hm.m[0]:
        lnxmin = np.log(xmin)
        return np.apply_along_axis(
            lambda f: _IUS(lnm, f).integral(lnxmin, lnm[-1]), -1, integrand
        )
    return intg.simpson(integrand, x=lnm)


class CrossCorrelations(Framework):
    r"""
    The Framework for cross-correlations.

    This class generates two :class:`~halomod.halo_model.TracerHaloModel`,
    and calculates their cross-correlation according to the cross-correlation
    model given.

    Parameters
    ----------
    cross_hod_model : class
        Model for the HOD of cross correlation (a :class:`_HODCross`
        subclass, resolved via ``get_mdl``).
    cross_hod_params : dict, optional
        Parameters for HOD used in cross-correlation.
    halo_model_1_params, halo_model_2_params : dict, optional
        Parameters for the tracers used in cross-correlation.
    """

    def __init__(
        self,
        cross_hod_model,
        cross_hod_params: dict | None = None,
        halo_model_1_params: dict | None = None,
        halo_model_2_params: dict | None = None,
    ):
        super().__init__()
        self.cross_hod_model = cross_hod_model
        self.cross_hod_params = cross_hod_params or {}
        self._halo_model_1_params = halo_model_1_params or {}
        self._halo_model_2_params = halo_model_2_params or {}

    @parameter("model")
    def cross_hod_model(self, val):
        return get_mdl(val, "_HODCross")

    @parameter("param")
    def cross_hod_params(self, val):
        return val

    @subframework
    def halo_model_1(self) -> TracerHaloModel:
        """Halo Model of the first tracer."""
        return TracerHaloModel(**self._halo_model_1_params)

    @subframework
    def halo_model_2(self) -> TracerHaloModel:
        """Halo Model of the second tracer."""
        return TracerHaloModel(**self._halo_model_2_params)

    # ===========================================================================
    # Cross-correlations
    # ===========================================================================
    @cached_quantity
    def cross_hod(self):
        """HOD model of the cross-correlation."""
        return self.cross_hod_model(
            [self.halo_model_1.hod, self.halo_model_2.hod], **self.cross_hod_params
        )

    @cached_quantity
    def power_1h_cross_fnc(self):
        """Total 1-halo cross-power, as a callable spline in ``k``.

        Integrates the satellite-satellite, satellite-central and
        central-satellite cross-pair counts of :attr:`cross_hod`, weighted by
        the tracer profiles, over halo model 1's mass function. The result is
        normalised by the product of both mean tracer densities. Masses where
        any pair count is NaN are excluded.
        """
        hm1, hm2 = self.halo_model_1, self.halo_model_2
        cross = self.cross_hod
        shape = hm1.m.shape

        # Evaluate each pair count once on the full grid. broadcast_to keeps
        # the masking below valid even if a model returns a scalar.
        ss = np.broadcast_to(cross.ss_cross_pairs(hm1.m), shape)
        sc = np.broadcast_to(cross.sc_cross_pairs(hm1.m), shape)
        cs = np.broadcast_to(cross.cs_cross_pairs(hm1.m), shape)
        mask = ~(np.isnan(ss) | np.isnan(sc) | np.isnan(cs))

        m = hm1.m[mask]
        u1 = hm1.tracer_profile_ukm[:, mask]
        # NOTE(review): hm1's mask is applied to hm2's profile, so this assumes
        # both halo models share the same mass grid.
        u2 = hm2.tracer_profile_ukm[:, mask]

        # Satellites of tracer 1 carry profile u1 and satellites of tracer 2
        # carry u2; centrals sit at the halo centre and carry no profile factor.
        integ = hm1.dndm[mask] * (
            u1 * u2 * ss[mask] + u1 * sc[mask] + u2 * cs[mask]
        )
        p = intg.simpson(integ, x=m)
        p /= hm1.mean_tracer_den * hm2.mean_tracer_den
        return tools.ExtendedSpline(hm1.k, p, lower_func="power_law", upper_func="power_law")

    @property
    def power_1h_cross(self):
        """Total 1-halo cross-power."""
        return self.power_1h_cross_fnc(self.halo_model_1.k_hm)

    @cached_quantity
    def corr_1h_cross_fnc(self):
        """The 1-halo term of the cross correlation (Hankel transform of the 1-halo power)."""
        corr = tools.hankel_transform(self.power_1h_cross_fnc, self.halo_model_1._r_table, "r")
        return tools.ExtendedSpline(
            self.halo_model_1._r_table,
            corr,
            lower_func="power_law",
            upper_func=tools._zero,
        )

    @cached_quantity
    def corr_1h_cross(self):
        """The 1-halo term of the cross correlation."""
        return self.corr_1h_cross_fnc(self.halo_model_1.r)

    @cached_quantity
    def power_2h_cross_fnc(self):
        """The 2-halo term of the cross-power spectrum.

        Uses spline integration from ``tracer_mmin`` (the same lower-bound
        convention as :attr:`~halomod.HaloModel.power_2h_auto_tracer`) so
        that the result varies smoothly as the HOD ``M_min`` parameter is
        changed and is consistent with the auto-power spectrum when both
        tracer populations are identical.
        """
        hm1, hm2 = self.halo_model_1, self.halo_model_2
        # Each model is integrated on its own grid with its own halo bias.
        b1 = _tracer_bias_integral(hm1)
        b2 = _tracer_bias_integral(hm2)
        p = (
            b1
            * b2
            * hm1._power_halo_centres_fnc(hm1.k)
            / (hm1.mean_tracer_den * hm2.mean_tracer_den)
        )
        return tools.ExtendedSpline(
            hm1.k,
            p,
            lower_func=hm1.linear_power_fnc,
            match_lower=True,
            upper_func="power_law",
        )

    @property
    def power_2h_cross(self):
        """The 2-halo term of the cross-power spectrum."""
        return self.power_2h_cross_fnc(self.halo_model_1.k_hm)

    @cached_quantity
    def corr_2h_cross_fnc(self):
        """The 2-halo term of the cross-correlation (Hankel transform of the 2-halo power)."""
        corr = tools.hankel_transform(
            self.power_2h_cross_fnc, self.halo_model_1._r_table, "r", h=1e-4
        )
        return tools.ExtendedSpline(
            self.halo_model_1._r_table,
            corr,
            lower_func="power_law",
            upper_func=tools._zero,
        )

    @cached_quantity
    def corr_2h_cross(self):
        """The 2-halo term of the cross-correlation."""
        return self.corr_2h_cross_fnc(self.halo_model_1.r)

    def power_cross_fnc(self, k):
        """Total tracer cross power spectrum (1-halo + 2-halo) at ``k``."""
        return self.power_1h_cross_fnc(k) + self.power_2h_cross_fnc(k)

    @property
    def power_cross(self):
        """Total tracer cross power spectrum."""
        return self.power_cross_fnc(self.halo_model_1.k_hm)

    def corr_cross_fnc(self, r):
        """The tracer cross correlation function at ``r``.

        Returns ``1 + xi_1h(r) + xi_2h(r)``, where the two terms are the Hankel
        transforms of :attr:`power_1h_cross_fnc` and :attr:`power_2h_cross_fnc`.

        NOTE(review): because of the constant ``+ 1`` this is not the plain
        transform of :meth:`power_cross_fnc`. Confirm the offset is the intended
        convention before comparing against auto-correlation functions.
        """
        return self.corr_1h_cross_fnc(r) + self.corr_2h_cross_fnc(r) + 1

    @property
    def corr_cross(self):
        """The tracer cross correlation function."""
        return self.corr_cross_fnc(self.halo_model_1.r)