1import autograd.numpy as np
2from autograd.scipy.special import gammaln
3from autograd.scipy.misc import logsumexp
4from autograd.scipy.linalg import solve_triangular
6from ssm.util import one_hot
def flatten_to_dim(X, d):
    """
    Collapse all leading axes of X into one, keeping the last d axes.

    The result always has d + 1 axes: a single flattened leading axis
    followed by the trailing d axes of X.  When X.ndim == d the leading
    axis has length 1.

    Example:
        X = npr.rand(10, 5, 2, 2)
        flatten_to_dim(X, 4).shape # (1, 10, 5, 2, 2)
        flatten_to_dim(X, 3).shape # (10, 5, 2, 2)
        flatten_to_dim(X, 2).shape # (50, 2, 2)
        flatten_to_dim(X, 1).shape # (100, 2)

    Parameters
    ----------
    X : array_like
        The array to be flattened.  Must have at least d dimensions.

    d : int (> 0)
        The number of trailing dimensions to retain.

    Returns
    -------
    flat_X : array_like
        X reshaped to (-1,) + X.shape[-d:].
    """
    assert X.ndim >= d
    assert d > 0
    # Since 0 < d <= X.ndim, this slice is exactly the last d axes.
    trailing = X.shape[X.ndim - d:]
    return np.reshape(X, (-1,) + trailing)
def batch_mahalanobis(L, x):
    """
    Evaluate the squared Mahalanobis distance
    :math:`x^T M^{-1} x` where :math:`M = LL^T` is given by its factor L.

    Copied from PyTorch torch.distributions.multivariate_normal.

    Parameters
    ----------
    L : array_like (..., D, D)
        Cholesky factorization(s) of covariance matrix

    x : array_like (..., D)
        Points at which to evaluate the quadratic term

    Returns
    -------
    y : array_like (...,)
        squared Mahalanobis distance, computed as

        x^T (LL^T)^{-1} x = x^T L^{-T} L^{-1} x = ||L^{-1} x||^2
    """
    if not (x.ndim == 2 and L.ndim == 2):
        # General case: treat L as a stack of (D, D) factors, invert the
        # transpose of each one, and restore the original layout of L.
        stacked = flatten_to_dim(L, 2)
        inv_transposes = np.array([np.linalg.inv(factor.T) for factor in stacked])
        L_inv_T = np.reshape(inv_transposes, L.shape)
        # x^T L^{-T} is the (row-vector) whitened point; sum its squares.
        whitened = np.einsum('...i,...ij->...j', x, L_inv_T)
        return np.sum(whitened**2, axis=-1)

    # Common case x: (T, D) with a single L: (D, D).  One triangular solve
    # whitens all T points at once; the columns of the result are L^{-1} x_t.
    whitened = solve_triangular(L, x.T, lower=True)
    return np.sum(whitened**2, axis=0)
def _multivariate_normal_logpdf(data, mus, Sigmas, Ls=None):
    """
    Log density of a multivariate Gaussian for fully observed data.
    The arguments must share (or be broadcast compatible along) their
    leading dimensions.

    Parameters
    ----------
    data : array_like (..., D)
        The points at which to evaluate the log density

    mus : array_like (..., D)
        The mean(s) of the Gaussian distribution(s)

    Sigmas : array_like (..., D, D)
        The covariances(s) of the Gaussian distribution(s).  Only used to
        compute the Cholesky factor when Ls is not supplied.

    Ls : array_like (..., D, D)
        Optionally pass in the Cholesky decomposition of Sigmas

    Returns
    -------
    lps : array_like (...,)
        Log probabilities under the multivariate Gaussian distribution(s).
    """
    # Validate trailing dimensions
    D = data.shape[-1]
    assert mus.shape[-1] == D
    assert Sigmas.shape[-2] == Sigmas.shape[-1] == D
    if Ls is None:
        Ls = np.linalg.cholesky(Sigmas)                              # (..., D, D)
    else:
        assert Ls.shape[-2] == Ls.shape[-1] == D

    # -1/2 (x - mu)^T Sigma^{-1} (x - mu)
    quad = -0.5 * batch_mahalanobis(Ls, data - mus)                  # (...,)

    # log |Sigma|^{1/2} is the sum of the log diagonal of L.  Flattening
    # each (D, D) factor and striding by D + 1 picks out that diagonal.
    flat_Ls = np.reshape(Ls, Ls.shape[:-2] + (-1,))                  # (..., D*D)
    chol_diag = flat_Ls[..., ::D + 1]                                # (..., D)
    half_log_det = np.sum(np.log(abs(chol_diag)), axis=-1)           # (...,)

    return quad - 0.5 * D * np.log(2 * np.pi) - half_log_det         # (...,)
def multivariate_normal_logpdf(data, mus, Sigmas, mask=None):
    """
    Log density of a multivariate Gaussian, optionally with partially
    observed data.  The arguments must share (or be broadcast compatible
    along) their leading dimensions.

    Without a mask this defers to `_multivariate_normal_logpdf`.  With a
    mask, rows are grouped by their pattern of observed entries and each
    group is evaluated using only the observed coordinates, i.e. the
    matching entries of the mean and the matching sub-block of the
    covariance.  Rows with no observed entries get log probability 0.

    Parameters
    ----------
    data : array_like (..., D)
        The points at which to evaluate the log density

    mus : array_like (..., D)
        The mean(s) of the Gaussian distribution(s)

    Sigmas : array_like (..., D, D)
        The covariances(s) of the Gaussian distribution(s)

    mask : array_like (..., D) bool
        Optional mask indicating which entries in the data are observed.
        Must have the same shape as data.

    Returns
    -------
    lps : array_like (...,)
        Log probabilities under the multivariate Gaussian distribution(s).
    """
    # Validate trailing dimensions
    D = data.shape[-1]
    assert mus.shape[-1] == D
    assert Sigmas.shape[-2] == Sigmas.shape[-1] == D

    # Fully observed data takes the direct route
    if mask is None:
        return _multivariate_normal_logpdf(data, mus, Sigmas)

    # Work out the output (batch) shape.  mus and Sigmas may have
    # different leading shapes, e.g. many means sharing one covariance.
    mean_batch = np.broadcast(data, mus).shape[:-1]
    cov_batch = np.broadcast(data[..., None], Sigmas).shape[:-2]
    assert len(mean_batch) == len(cov_batch)
    batch_shape = tuple(max(a, b) for a, b in zip(mean_batch, cov_batch))

    # Expand data and mask to the full batch shape
    full_data = np.broadcast_to(data, batch_shape + (D,))

    assert mask.dtype == bool
    assert mask.shape == data.shape
    full_mask = np.broadcast_to(mask, batch_shape + (D,))

    # One row per data point; find the distinct mask patterns and, for
    # every row, the index of the pattern it uses.
    flat_data = flatten_to_dim(full_data, 1)
    flat_mask = flatten_to_dim(full_mask, 1)
    patterns, pattern_of_row = np.unique(flat_mask, return_inverse=True, axis=0)

    # NaN-filled so that any row left unassigned trips the final check
    lls = np.nan * np.ones(flat_data.shape[0])

    # Boolean selectors that keep everything along Sigmas' leading axes
    keep_leading = [np.ones(sz, dtype=bool) for sz in Sigmas.shape[:-2]]

    for k, observed in enumerate(patterns):
        rows = np.where(pattern_of_row == k)[0]
        n_obs = np.sum(observed)
        if n_obs == 0:
            # Nothing observed for these rows
            lls[rows] = 0
            continue

        # Restrict data, means and covariances to the observed coordinates
        sub_data = flat_data[np.ix_(rows, observed)]
        sub_mus = mus[..., observed]
        sub_Sigmas = Sigmas[np.ix_(*keep_leading, observed, observed)]

        # Factor before broadcasting so each distinct covariance is
        # decomposed only once
        sub_Ls = np.linalg.cholesky(sub_Sigmas)

        # Broadcast to the batch shape, flatten, and keep this group's rows
        sub_mus = flatten_to_dim(np.broadcast_to(sub_mus, batch_shape + (n_obs,)), 1)[rows]
        sub_Ls = flatten_to_dim(np.broadcast_to(sub_Ls, batch_shape + (n_obs, n_obs)), 2)[rows]

        lls[rows] = _multivariate_normal_logpdf(sub_data, sub_mus, sub_Sigmas, Ls=sub_Ls)

    # Every row must have been filled with a finite value
    assert np.all(np.isfinite(lls))
    return np.reshape(lls, batch_shape)
def expected_multivariate_normal_logpdf(E_xs, E_xxTs, E_mus, E_mumuTs, Sigmas, Ls=None):
    """
    Expected log density of a multivariate Gaussian when the data and the
    mean are known only through their first and second moments.  The
    arguments must share (or be broadcast compatible along) their leading
    dimensions.

    Parameters
    ----------
    E_xs : array_like (..., D)
        The expected value of the points at which to evaluate the log density
    E_xxTs : array_like (..., D, D)
        The second moment of the points at which to evaluate the log density
    E_mus : array_like (..., D)
        The expected mean(s) of the Gaussian distribution(s)
    E_mumuTs : array_like (..., D, D)
        The second moment of the mean
    Sigmas : array_like (..., D, D)
        The covariances(s) of the Gaussian distribution(s)
    Ls : array_like (..., D, D)
        Optionally pass in the Cholesky decomposition of Sigmas

    Returns
    -------
    lps : array_like (...,)
        Expected log probabilities under the multivariate Gaussian distribution(s).

    TODO
    ----
    - Allow for uncertainty in the covariance as well.
    """
    # Validate trailing dimensions
    D = E_xs.shape[-1]
    assert E_xxTs.shape[-2] == E_xxTs.shape[-1] == D
    assert E_mus.shape[-1] == D
    assert E_mumuTs.shape[-2] == E_mumuTs.shape[-1] == D
    assert Sigmas.shape[-2] == Sigmas.shape[-1] == D
    if Ls is None:
        Ls = np.linalg.cholesky(Sigmas)                              # (..., D, D)
    else:
        assert Ls.shape[-2] == Ls.shape[-1] == D

    # TODO: Figure out how to perform this computation without explicit inverse
    Sigma_invs = np.linalg.inv(Sigmas)

    # The expected quadratic form is a trace:
    #
    #   E[(x-mu)^T Sigma^{-1} (x-mu)]
    #       = Tr(Sigma^{-1} E[(x-mu)(x-mu)^T])
    #       = Tr(Sigma^{-1} A)
    #       = sum_{ij} [Sigma^{-1}]_{ij} * A_{ij}
    #
    # with A = E[xx^T] - E[x]E[mu]^T - E[mu]E[x]^T + E[mu mu^T]; the cross
    # moments factor into products of means in this expression.
    # Sigma^{-1} is symmetric, so both cross terms contribute equally
    # to the sum and can be merged:
    #
    #   A = E[xx^T] - 2 E[x]E[mu]^T + E[mu mu^T]
    cross = E_xs[..., :, None] * E_mus[..., None, :]                 # (..., D, D)
    As = E_xxTs - 2 * cross + E_mumuTs                               # (..., D, D)
    expected_quad = -0.5 * np.sum(Sigma_invs * As, axis=(-2, -1))    # (...,)

    # log |Sigma|^{1/2} from the diagonal of L: flatten each (D, D)
    # factor and take every (D + 1)-th entry.
    flat_Ls = np.reshape(Ls, Ls.shape[:-2] + (-1,))                  # (..., D*D)
    chol_diag = flat_Ls[..., ::D + 1]                                # (..., D)
    half_log_det = np.sum(np.log(abs(chol_diag)), axis=-1)           # (...,)

    return expected_quad - 0.5 * D * np.log(2 * np.pi) - half_log_det
def diagonal_gaussian_logpdf(data, mus, sigmasqs, mask=None):
    """
    Log density of a Gaussian with diagonal covariance.  The arguments
    must share (or be broadcast compatible along) their leading
    dimensions.  Per-coordinate log densities are multiplied by the mask
    and summed over the last axis.

    Parameters
    ----------
    data : array_like (..., D)
        The points at which to evaluate the log density

    mus : array_like (..., D)
        The mean(s) of the Gaussian distribution(s)

    sigmasqs : array_like (..., D)
        The diagonal variances(s) of the Gaussian distribution(s)

    mask : array_like (..., D) bool
        Optional mask indicating which entries in the data are observed.
        Defaults to all entries observed; must have the same shape as data.

    Returns
    -------
    lps : array_like (...,)
        Log probabilities under the diagonal Gaussian distribution(s).
    """
    # Validate trailing dimensions
    D = data.shape[-1]
    assert mus.shape[-1] == D
    assert sigmasqs.shape[-1] == D

    # Treat everything as observed unless told otherwise
    if mask is None:
        mask = np.ones_like(data, dtype=bool)
    assert mask.shape == data.shape

    # Per-coordinate log density: -1/2 log(2 pi s^2) - (x - mu)^2 / (2 s^2)
    log_norm = -0.5 * np.log(2 * np.pi * sigmasqs)
    per_dim = log_norm - 0.5 * (data - mus)**2 / sigmasqs
    return np.sum(per_dim * mask, axis=-1)
def multivariate_studentst_logpdf(data, mus, Sigmas, nus, Ls=None):
    """
    Log density of a multivariate Student's t distribution.  The
    arguments must share (or be broadcast compatible along) their leading
    dimensions.

    Parameters
    ----------
    data : array_like (..., D)
        The points at which to evaluate the log density

    mus : array_like (..., D)
        The mean(s) of the t distribution(s)

    Sigmas : array_like (..., D, D)
        The covariances(s) of the t distribution(s).  Only used to compute
        the Cholesky factor when Ls is not supplied.

    nus : array_like (...,)
        The degrees of freedom of the t distribution(s)

    Ls : array_like (..., D, D)
        Optionally pass in the Cholesky decomposition of Sigmas

    Returns
    -------
    lps : array_like (...,)
        Log probabilities under the multivariate Student's t distribution(s).
    """
    # Validate trailing dimensions
    D = data.shape[-1]
    assert mus.shape[-1] == D
    assert Sigmas.shape[-2] == Sigmas.shape[-1] == D
    if Ls is None:
        Ls = np.linalg.cholesky(Sigmas)                              # (..., D, D)
    else:
        assert Ls.shape[-2] == Ls.shape[-1] == D

    # Data-dependent term: -(nu + D)/2 * log(1 + d^2 / nu), where d^2 is
    # the squared Mahalanobis distance
    q = batch_mahalanobis(Ls, data - mus) / nus                      # (...,)
    lp = -0.5 * (nus + D) * np.log1p(q)                              # (...,)

    # log |Sigma|^{1/2} from the diagonal of L: flatten each (D, D)
    # factor and take every (D + 1)-th entry.
    flat_Ls = np.reshape(Ls, Ls.shape[:-2] + (-1,))                  # (..., D*D)
    chol_diag = flat_Ls[..., ::D + 1]                                # (..., D)
    half_log_det = np.sum(np.log(abs(chol_diag)), axis=-1)           # (...,)

    # Normalizer:
    #   log Gamma((nu + D)/2) - log Gamma(nu/2) - D/2 log(pi nu) - 1/2 log|Sigma|
    lp = lp + gammaln(0.5 * (nus + D)) - gammaln(0.5 * nus)          # (...,)
    lp = lp - 0.5 * D * np.log(np.pi) - 0.5 * D * np.log(nus)        # (...,)
    return lp - half_log_det
def expected_multivariate_studentst_logpdf(E_xs, E_xxTs, E_mus, E_mumuTs, Sigmas, nus, Ls=None):
    """
    Compute an expected log probability density of a multivariate
    Student's t distribution when the data and the mean are known only
    through their first and second moments.  This will broadcast as long
    as the arguments have the same (or at least broadcast compatible)
    leading dimensions.

    The expectation is taken of the quadratic form and then passed
    through log1p, i.e. the data term is -(nu + D)/2 * log(1 + E[q]/nu)
    rather than the exact E[log(1 + q/nu)].  Because log1p is concave,
    this is a lower bound on the exact expectation; the two agree when
    there is no uncertainty, in which case the result matches
    `multivariate_studentst_logpdf`.

    Parameters
    ----------
    E_xs : array_like (..., D)
        The expected value of the points at which to evaluate the log density
    E_xxTs : array_like (..., D, D)
        The second moment of the points at which to evaluate the log density
    E_mus : array_like (..., D)
        The expected mean(s) of the t distribution(s)
    E_mumuTs : array_like (..., D, D)
        The second moment of the mean
    Sigmas : array_like (..., D, D)
        The covariances(s) of the t distribution(s)
    nus : array_like (...,)
        The degrees of freedom of the t distribution(s)
    Ls : array_like (..., D, D)
        Optionally pass in the Cholesky decomposition of Sigmas

    Returns
    -------
    lps : array_like (...,)
        Expected log probabilities under the multivariate Student's t
        distribution(s).

    TODO
    ----
    - Allow for uncertainty in the covariance Sigmas and dof nus as well.
    """
    # Check inputs
    D = E_xs.shape[-1]
    assert E_xxTs.shape[-2] == E_xxTs.shape[-1] == D
    assert E_mus.shape[-1] == D
    assert E_mumuTs.shape[-2] == E_mumuTs.shape[-1] == D
    assert Sigmas.shape[-2] == Sigmas.shape[-1] == D
    if Ls is not None:
        assert Ls.shape[-2] == Ls.shape[-1] == D
    else:
        Ls = np.linalg.cholesky(Sigmas)                              # (..., D, D)

    # TODO: Figure out how to perform this computation without explicit inverse
    Sigma_invs = np.linalg.inv(Sigmas)

    # Compute  E[(x-mu)^T Sigma^{-1}(x-mu)]
    #        = Tr(Sigma^{-1} E[(x-mu)(x-mu)^T])
    #        = Tr(Sigma^{-1} A)
    #        = sum_{ij} [Sigma^{-1}]_{ij} * A_{ij}
    # where, using the symmetry of Sigma^{-1} to merge the cross terms,
    #
    # A = E[xx^T] - 2 E[x]E[mu]^T + E[mu mu^T]
    #
    As = E_xxTs - 2 * E_xs[..., :, None] * E_mus[..., None, :] + E_mumuTs   # (..., D, D)
    q = np.sum(Sigma_invs * As, axis=(-2, -1)) / nus                        # (...,)
    lp = -0.5 * (nus + D) * np.log1p(q)                                     # (...,)

    # Normalizer of the Student's t density (NOT the Gaussian one):
    #   log Gamma((nu + D)/2) - log Gamma(nu/2) - D/2 log(pi nu) - 1/2 log|Sigma|
    # The Gaussian constant -D/2 log(2 pi) only coincides with the first
    # three terms when D == 2.
    lp = lp + gammaln(0.5 * (nus + D)) - gammaln(0.5 * nus)                 # (...,)
    lp = lp - 0.5 * D * np.log(np.pi) - 0.5 * D * np.log(nus)               # (...,)

    # log |Sigma|^{1/2} is the sum of the log diagonal of L; flattening each
    # (D, D) factor and striding by D + 1 picks out that diagonal.
    L_diag = np.reshape(Ls, Ls.shape[:-2] + (-1,))[..., ::D + 1]            # (..., D)
    half_log_det = np.sum(np.log(abs(L_diag)), axis=-1)                     # (...,)
    lp = lp - half_log_det                                                  # (...,)

    return lp
def independent_studentst_logpdf(data, mus, sigmasqs, nus, mask=None):
    """
    Log density of a product of independent (per-coordinate) Student's t
    distributions.  The arguments must share (or be broadcast compatible
    along) their leading dimensions.  Per-coordinate log densities are
    multiplied by the mask and summed over the last axis.

    Parameters
    ----------
    data : array_like (..., D)
        The points at which to evaluate the log density

    mus : array_like (..., D)
        The mean(s) of the Student's t distribution(s)

    sigmasqs : array_like (..., D)
        The diagonal variances(s) of the Student's t distribution(s)

    nus : array_like (..., D)
        The degrees of freedom of the Student's t distribution(s)

    mask : array_like (..., D) bool
        Optional mask indicating which entries in the data are observed.
        Defaults to all entries observed; must have the same shape as data.

    Returns
    -------
    lps : array_like (...,)
        Log probabilities under the Student's t distribution(s).
    """
    # Validate trailing dimensions
    D = data.shape[-1]
    assert mus.shape[-1] == D
    assert sigmasqs.shape[-1] == D
    assert nus.shape[-1] == D

    # Treat everything as observed unless told otherwise
    if mask is None:
        mask = np.ones_like(data, dtype=bool)
    assert mask.shape == data.shape

    # Normalizer: log Gamma((nu+1)/2) - log Gamma(nu/2) - 1/2 log(pi nu s^2)
    log_gamma_ratio = gammaln(0.5 * (nus + 1)) - gammaln(0.5 * nus)
    half_log_scale = 0.5 * (np.log(np.pi) + np.log(nus) + np.log(sigmasqs))

    # Data term: -(nu+1)/2 * log(1 + (x - mu)^2 / (nu s^2))
    scaled_sq = (data - mus)**2 / (sigmasqs * nus)
    per_dim = (log_gamma_ratio - half_log_scale) - 0.5 * (nus + 1) * np.log(1.0 + scaled_sq)
    return np.sum(per_dim * mask, axis=-1)
def bernoulli_logpdf(data, logit_ps, mask=None):
    """
    Compute the log probability density of a Bernoulli distribution.
    This will broadcast as long as data and logit_ps have the same
    (or at least compatible) leading dimensions.  Per-coordinate log
    probabilities are multiplied by the mask and summed over the last axis.

    Parameters
    ----------
    data : array_like (..., D) bool or integer, values in {0, 1}
        The points at which to evaluate the log density.  Any boolean,
        signed or unsigned integer dtype is accepted.

    logit_ps : array_like (..., D)
        The logit(s) log p / (1 - p) of the Bernoulli distribution(s)

    mask : array_like (..., D) bool
        Optional mask indicating which entries in the data are observed.
        Defaults to all entries observed; must have the same shape as data.

    Returns
    -------
    lps : array_like (...,)
        Log probabilities under the Bernoulli distribution(s).
    """
    D = data.shape[-1]
    # Accept every boolean / integer dtype (kind 'b', 'i' or 'u'), not only
    # the platform default int, so e.g. int32 or uint8 spike arrays work.
    assert data.dtype.kind in ('b', 'i', 'u')
    # min()/max() raise on zero-size arrays, so only range-check real data
    assert data.size == 0 or (data.min() >= 0 and data.max() <= 1)
    assert logit_ps.shape[-1] == D

    # Check mask
    mask = mask if mask is not None else np.ones_like(data, dtype=bool)
    assert mask.shape == data.shape

    # Evaluate log probability
    # log Pr(x | p) = x * log(p) + (1-x) * log(1-p)
    #               = x * log(p / (1-p)) + log(1-p)
    #               = x * log(p / (1-p)) - log(1/(1-p))
    #               = x * log(p / (1-p)) - log(1 + p/(1-p)).
    #
    # Let u = log (p / (1-p)) = logit(p), then
    #
    # log Pr(x | p) = x * u - log(1 + e^u)
    #               = x * u - log(e^0 + e^u)
    #               = x * u - log(e^m * (e^-m + e^(u-m)))
    #               = x * u - m - log(exp(-m) + exp(u-m)).
    #
    # This holds for any m. We choose m = max(0, u) so that both exponents
    # are non-positive and exp cannot overflow.
    m = np.maximum(0, logit_ps)
    lls = data * logit_ps - m - np.log(np.exp(-m) + np.exp(logit_ps - m))
    return np.sum(lls * mask, axis=-1)
def poisson_logpdf(data, lambdas, mask=None):
    """
    Log probability of counts under a Poisson distribution.  data and
    lambdas must share (or be broadcast compatible along) their leading
    dimensions.  Per-coordinate log probabilities are multiplied by the
    mask and summed over the last axis.

    Parameters
    ----------
    data : array_like (..., D) int
        The points at which to evaluate the log density

    lambdas : array_like (..., D)
        The rates of the Poisson distribution(s)

    mask : array_like (..., D) bool
        Optional mask indicating which entries in the data are observed.
        Defaults to all entries observed; must have the same shape as data.

    Returns
    -------
    lps : array_like (...,)
        Log probabilities under the Poisson distribution(s).
    """
    # Validate dtype and trailing dimension
    D = data.shape[-1]
    allowed_dtypes = (int, np.int8, np.int16, np.int32, np.int64)
    assert data.dtype in allowed_dtypes
    assert lambdas.shape[-1] == D

    # Treat everything as observed unless told otherwise
    if mask is None:
        mask = np.ones_like(data, dtype=bool)
    assert mask.shape == data.shape

    # log Pr(x | lambda) = -log(x!) - lambda + x log(lambda)
    log_rates = np.log(lambdas)
    per_dim = -gammaln(data + 1) - lambdas + data * log_rates
    return np.sum(per_dim * mask, axis=-1)
def categorical_logpdf(data, logits, mask=None):
    """
    Log probability of class labels under a categorical distribution.
    data and logits must share (or be broadcast compatible along) their
    leading dimensions.  The logits are normalized with a log-sum-exp over
    the class axis; per-coordinate log probabilities are multiplied by the
    mask and summed over the last axis.

    Parameters
    ----------
    data : array_like (..., D) int (0 <= data < C)
        The points at which to evaluate the log density

    logits : array_like (..., D, C)
        The logits of the categorical distribution(s) with C classes

    mask : array_like (..., D) bool
        Optional mask indicating which entries in the data are observed.
        Defaults to all entries observed; must have the same shape as data.

    Returns
    -------
    lps : array_like (...,)
        Log probabilities under the categorical distribution(s).
    """
    # Validate dtype, label range and trailing dimensions
    D = data.shape[-1]
    C = logits.shape[-1]
    allowed_dtypes = (int, np.int8, np.int16, np.int32, np.int64)
    assert data.dtype in allowed_dtypes
    assert np.all((data >= 0) & (data < C))
    assert logits.shape[-2] == D

    # Treat everything as observed unless told otherwise
    if mask is None:
        mask = np.ones_like(data, dtype=bool)
    assert mask.shape == data.shape

    # Normalized log probabilities of every class
    log_probs = logits - logsumexp(logits, axis=-1, keepdims=True)   # (..., D, C)
    # Pick out the log probability of the observed class
    indicators = one_hot(data, C)                                    # (..., D, C)
    per_dim = np.sum(indicators * log_probs, axis=-1)                # (..., D)
    return np.sum(per_dim * mask, axis=-1)                           # (...,)
def vonmises_logpdf(data, mus, kappas, mask=None):
    """
    Compute the log probability density of a von Mises distribution.
    This will broadcast as long as data, mus, and kappas have the same
    (or at least compatible) leading dimensions.  Per-coordinate log
    densities are multiplied by the mask and summed over the last axis.

    Parameters
    ----------
    data : array_like (..., D)
        The points at which to evaluate the log density

    mus : array_like (..., D)
        The means of the von Mises distribution(s)

    kappas : array_like (..., D)
        The concentration of the von Mises distribution(s)

    mask : array_like (..., D) bool
        Optional mask indicating which entries in the data are observed.
        Defaults to all entries observed; must have the same shape as data.

    Returns
    -------
    lps : array_like (...,)
        Log probabilities under the von Mises distribution(s).

    Raises
    ------
    Exception
        If autograd.scipy.special.i0 cannot be imported.  The underlying
        ImportError is attached as the cause.
    """
    # Imported lazily so the rest of the module works with autograd
    # releases that do not provide i0.  Only ImportError is translated;
    # anything else (e.g. KeyboardInterrupt) propagates untouched.
    try:
        from autograd.scipy.special import i0
    except ImportError as err:
        raise Exception("von Mises relies on the function autograd.scipy.special.i0. "
                        "This is present in the latest Github code, but not on pypi. "
                        "Please use the Github version of autograd instead.") from err

    D = data.shape[-1]
    assert mus.shape[-1] == D
    assert kappas.shape[-1] == D

    # Check mask
    mask = mask if mask is not None else np.ones_like(data, dtype=bool)
    assert mask.shape == data.shape

    # log p(x) = kappa * cos(x - mu) - log(2 pi) - log(I0(kappa))
    ll = kappas * np.cos(data - mus) - np.log(2 * np.pi) - np.log(i0(kappas))
    return np.sum(ll * mask, axis=-1)