|
| 1 | +import dask.array as da |
| 2 | +import xarray as xr |
| 3 | + |
| 4 | +from sgkit.typing import ArrayLike |
| 5 | + |
| 6 | + |
def gramian(a: ArrayLike) -> ArrayLike:
    """Return the Gram matrix of ``a``, i.e. ``a.T`` dotted with ``a``.

    For a 2-D input of shape (n, m) the result has shape (m, m), where
    entry (i, j) is the inner product of columns i and j of ``a``.
    """
    transposed = a.T
    return transposed.dot(a)
| 10 | + |
| 11 | + |
def pc_relate(ds: xr.Dataset, maf: float = 0.01) -> xr.Dataset:
    """Compute PC-Relate as described in Conomos, et al. 2016 [1].

    Parameters
    ----------
    ds : `xr.Dataset`
        Dataset containing (V = num variants, S = num samples, D = ploidy, PC = num PC):
        * genotype calls: "call_genotype" (VxSxD)
        * genotype calls mask: "call_genotype_mask" (VxSxD)
        * sample PCs: "sample_pcs" (PCxS)
    maf : float
        individual minor allele frequency filter. If an individual's estimated
        individual-specific minor allele frequency at a SNP is less than this value,
        that SNP will be excluded from the analysis for that individual.
        The default value is 0.01. Must be in the open interval (0.0, 1.0).
        Note that a value of 0.5 or more excludes every SNP, because a SNP is
        dropped when the estimated frequency is <= maf or >= 1 - maf.

    Returns
    -------
    Dataset
        Dataset containing (S = num samples):
        pc_relate_phi: (S,S) ArrayLike
            pairwise recent kinship estimation matrix as float in [-0.5, 0.5].
            Its dimensions are named "sample_x" and "sample_y".

    Raises
    ------
    ValueError
        * If ploidy of provided dataset != 2
        * If maximum number of alleles in provided dataset != 2
        * Input dataset is missing any of the required variables
        * If maf is not in (0.0, 1.0) (this includes NaN)

    Warnings
    --------
    This function is only applicable to diploid, biallelic datasets.

    References
    ----------
    - [1] Conomos MP, Reiner AP, Weir BS & Thornton TA. 2016.
        "Model-free Estimation of Recent Genetic Relatedness."
        Am. J. Hum. Genet 98, 127–148.
    """
    # Written as a positive range test so that NaN (for which every comparison
    # is False) is rejected instead of silently disabling the MAF filter.
    if not (0.0 < maf < 1.0):
        raise ValueError("MAF must be between (0.0, 1.0)")
    # `Dataset.sizes` is the mapping from dimension name to length; using
    # `Dataset.dims` as a mapping is deprecated in xarray.
    if "ploidy" in ds.sizes and ds.sizes["ploidy"] != 2:
        raise ValueError("PC Relate only work for diploid genotypes")
    if "alleles" in ds.sizes and ds.sizes["alleles"] != 2:
        raise ValueError("PC Relate only work for biallelic genotypes")
    if "call_genotype" not in ds:
        raise ValueError("Input dataset must contain call_genotype")
    if "call_genotype_mask" not in ds:
        raise ValueError("Input dataset must contain call_genotype_mask")
    if "sample_pcs" not in ds:
        raise ValueError("Input dataset must contain sample_pcs variable")

    # A call counts as missing when any of its alleles is masked.
    call_g_mask = ds["call_genotype_mask"].any(dim="ploidy")
    # Sum of allele values across ploidy, with -1 as a placeholder for missing
    # calls (those entries are masked out again before they are used below).
    call_g = xr.where(call_g_mask, -1, ds["call_genotype"].sum(dim="ploidy"))  # type: ignore[no-untyped-call]

    # impute with variant mean
    # (mean over samples of the non-missing calls, kept as a (V, 1) column so
    # that it broadcasts against the (V, S) genotype matrix)
    variant_mean = (
        call_g.where(~call_g_mask)
        .mean(dim="samples")
        .expand_dims(dim="samples", axis=1)
    )
    imputed_g = da.where(call_g_mask, variant_mean, call_g)

    # 𝔼[gs|V] = 1β0 + Vβ, where 1 is a length _s_ vector of 1s, and β = (β1,...,βD)^T
    # is a length D vector of regression coefficients for each of the PCs
    pcs = ds["sample_pcs"]
    # Prepend a row of ones so the regression includes the intercept β0.
    pcsi = da.concatenate([da.ones((1, pcs.shape[1]), dtype=pcs.dtype), pcs], axis=0)
    # Note: dask qr decomp requires no chunking in one dimension, and because number of
    # components should be smaller than number of samples in most cases, we disable
    # chunking on number components
    pcsi = pcsi.T.rechunk((None, -1))

    q, r = da.linalg.qr(pcsi)
    # mu, eq: 3
    # inv(2 * r) yields half of the least-squares coefficients, so `mu` is the
    # fitted genotype divided by two, i.e. an individual-specific allele frequency.
    half_beta = da.linalg.inv(2 * r).dot(q.T).dot(imputed_g.T)
    mu = pcsi.dot(half_beta).T
    # phi, eq: 4
    # Exclude entries whose estimated frequency falls outside (maf, 1 - maf)
    # as well as entries with a missing call.
    mask = (mu <= maf) | (mu >= 1.0 - maf) | call_g_mask
    mu_mask = da.ma.masked_array(mu, mask=mask)
    variance = mu_mask.map_blocks(lambda i: i * (1.0 - i))
    # Masked entries are filled with 0 so they contribute nothing to the sums
    # computed by the gramians below.
    variance = da.ma.filled(variance, fill_value=0.0)
    stddev = da.sqrt(variance)
    centered_af = call_g / 2 - mu_mask
    centered_af = da.ma.filled(centered_af, fill_value=0.0)
    # NOTE: gramian could be a performance bottleneck, and we could explore
    # performance improvements like (or maybe sth else):
    # * calculating only the pairs we are interested in
    # * using an optimized einsum.
    phi = gramian(centered_af) / gramian(stddev)
    return xr.Dataset({"pc_relate_phi": (("sample_x", "sample_y"), phi)})
0 commit comments