From 609c2332e1d0b629ec81e8e8fc708c4a94256294 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Mon, 30 Aug 2021 16:16:13 +0000
Subject: [PATCH 1/7] add nonprojective entropy implementation

---
 torch_struct/distributions.py | 49 ++++++++++++++---------------------
 1 file changed, 19 insertions(+), 30 deletions(-)

diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py
index b6f233e..4e814b0 100644
--- a/torch_struct/distributions.py
+++ b/torch_struct/distributions.py
@@ -61,11 +61,7 @@ def log_prob(self, value):
         d = value.dim()
         batch_dims = range(d - len(self.event_shape))
 
-        v = self._struct().score(
-            self.log_potentials,
-            value.type_as(self.log_potentials),
-            batch_dims=batch_dims,
-        )
+        v = self._struct().score(self.log_potentials, value.type_as(self.log_potentials), batch_dims=batch_dims,)
 
         return v - self.partition
 
@@ -91,9 +87,7 @@ def cross_entropy(self, other):
             cross entropy (*batch_shape*)
         """
 
-        return self._struct(CrossEntropySemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
 
     def kl(self, other):
         """
@@ -105,9 +99,7 @@ def kl(self, other):
         Returns:
             cross entropy (*batch_shape*)
         """
-        return self._struct(KLDivergenceSemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
 
     @lazy_property
     def max(self):
@@ -140,9 +132,7 @@ def kmax(self, k):
             kmax (*k x batch_shape*)
         """
         with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).sum(
-                self.log_potentials, self.lengths, _raw=True
-            )
+            return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True)
 
     def topk(self, k):
         r"""
@@ -155,9 +145,7 @@ def topk(self, k):
             kmax (*k x batch_shape x event_shape*)
         """
         with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).marginals(
-                self.log_potentials, self.lengths, _raw=True
-            )
+            return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True)
 
     @lazy_property
     def mode(self):
@@ -186,9 +174,7 @@ def count(self):
 
     def gumbel_crf(self, temperature=1.0):
         with torch.enable_grad():
-            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
-                self.log_potentials, self.lengths
-            )
+            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(self.log_potentials, self.lengths)
         return st_gumbel
 
     # @constraints.dependent_property
@@ -219,9 +205,7 @@ def sample(self, sample_shape=torch.Size()):
         samples = []
         for k in range(nsamples):
             if k % 10 == 0:
-                sample = self._struct(MultiSampledSemiring).marginals(
-                    self.log_potentials, lengths=self.lengths
-                )
+                sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths)
                 sample = sample.detach()
             tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
             samples.append(tmp_sample)
@@ -301,9 +285,7 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
         super().__init__(log_potentials, lengths)
 
     def _struct(self, sr=None):
-        return self.struct(
-            sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap
-        )
+        return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap)
 
 
 class HMM(StructDistribution):
@@ -440,9 +422,7 @@ def __init__(self, log_potentials, lengths=None):
         event_shape = log_potentials[0].shape[1:]
         self.log_potentials = log_potentials
         self.lengths = lengths
-        super(StructDistribution, self).__init__(
-            batch_shape=batch_shape, event_shape=event_shape
-        )
+        super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape)
 
 
 class NonProjectiveDependencyCRF(StructDistribution):
@@ -504,4 +484,13 @@ def argmax(self):
 
     @lazy_property
     def entropy(self):
-        pass
+        """
+        Compute entropy efficiently using arc-factorization property.
+
+        See implementation notebook [here](https://colab.research.google.com/drive/1iUr78J901lMBlGVYpxSrRRmNJYX4FWyg?usp=sharing)
+        """
+        logZ = self.partition
+        p = self.marginals
+        phi = self.log_potentials
+        H = logZ - (p * phi).reshape(phi.shape[0], -1).sum(-1)
+        return H

From d1d19f9e790ff64963462638e644f4cedebcf65f Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Mon, 30 Aug 2021 17:03:25 +0000
Subject: [PATCH 2/7] add nonprojective entropy implementation

---
 torch_struct/distributions.py | 39 +++++++++++++++++++++++++++--------
 1 file changed, 30 insertions(+), 9 deletions(-)

diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py
index 4e814b0..3c66c41 100644
--- a/torch_struct/distributions.py
+++ b/torch_struct/distributions.py
@@ -61,7 +61,11 @@ def log_prob(self, value):
         d = value.dim()
         batch_dims = range(d - len(self.event_shape))
 
-        v = self._struct().score(self.log_potentials, value.type_as(self.log_potentials), batch_dims=batch_dims,)
+        v = self._struct().score(
+            self.log_potentials,
+            value.type_as(self.log_potentials),
+            batch_dims=batch_dims,
+        )
 
         return v - self.partition
 
@@ -87,7 +91,9 @@ def cross_entropy(self, other):
             cross entropy (*batch_shape*)
         """
 
-        return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
+        return self._struct(CrossEntropySemiring).sum(
+            [self.log_potentials, other.log_potentials], self.lengths
+        )
 
     def kl(self, other):
         """
@@ -99,7 +105,9 @@ def kl(self, other):
         Returns:
             cross entropy (*batch_shape*)
         """
-        return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
+        return self._struct(KLDivergenceSemiring).sum(
+            [self.log_potentials, other.log_potentials], self.lengths
+        )
 
     @lazy_property
     def max(self):
@@ -132,7 +140,9 @@ def kmax(self, k):
             kmax (*k x batch_shape*)
         """
         with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True)
+            return self._struct(KMaxSemiring(k)).sum(
+                self.log_potentials, self.lengths, _raw=True
+            )
 
     def topk(self, k):
         r"""
@@ -145,7 +155,9 @@ def topk(self, k):
             kmax (*k x batch_shape x event_shape*)
         """
         with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True)
+            return self._struct(KMaxSemiring(k)).marginals(
+                self.log_potentials, self.lengths, _raw=True
+            )
 
     @lazy_property
     def mode(self):
@@ -174,7 +186,9 @@ def count(self):
 
     def gumbel_crf(self, temperature=1.0):
         with torch.enable_grad():
-            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(self.log_potentials, self.lengths)
+            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
+                self.log_potentials, self.lengths
+            )
         return st_gumbel
 
     # @constraints.dependent_property
@@ -205,7 +219,9 @@ def sample(self, sample_shape=torch.Size()):
         samples = []
         for k in range(nsamples):
             if k % 10 == 0:
-                sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths)
+                sample = self._struct(MultiSampledSemiring).marginals(
+                    self.log_potentials, lengths=self.lengths
+                )
                 sample = sample.detach()
             tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
             samples.append(tmp_sample)
@@ -285,7 +301,9 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
         super().__init__(log_potentials, lengths)
 
     def _struct(self, sr=None):
-        return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap)
+        return self.struct(
+            sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap
+        )
 
 
 class HMM(StructDistribution):
@@ -422,7 +440,9 @@ def __init__(self, log_potentials, lengths=None):
         event_shape = log_potentials[0].shape[1:]
         self.log_potentials = log_potentials
         self.lengths = lengths
-        super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape)
+        super(StructDistribution, self).__init__(
+            batch_shape=batch_shape, event_shape=event_shape
+        )
 
 
 class NonProjectiveDependencyCRF(StructDistribution):
@@ -488,6 +508,7 @@ def entropy(self):
         """
         Compute entropy efficiently using arc-factorization property.
 
         See implementation notebook [here](https://colab.research.google.com/drive/1iUr78J901lMBlGVYpxSrRRmNJYX4FWyg?usp=sharing)
+        for derivation.
         """
         logZ = self.partition
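With patches 1 and 2 applied, the new `NonProjectiveDependencyCRF.entropy` is assembled purely from `partition`, `marginals`, and `log_potentials`, so it can be exercised end to end with random potentials. A minimal usage sketch (an illustration, not part of the patch; the batch and sentence sizes are arbitrary):

    import torch
    from torch_struct import NonProjectiveDependencyCRF

    # batch of 2 sentences, 5 tokens each; log_potentials[b, i, j] scores the arc i -> j
    log_potentials = torch.randn(2, 5, 5)
    dist = NonProjectiveDependencyCRF(log_potentials)

    # H[p] = log Z - sum over arcs of p(arc in T) * phi(arc), per the docstring derivation
    print(dist.entropy)  # tensor of shape (2,): one entropy value per batch element

Note that `entropy` is a `lazy_property`, so it is read as an attribute rather than called.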
From cad728fa11be0a0e2524a60e410449fd7c86bac8 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Wed, 1 Sep 2021 18:03:55 +0000
Subject: [PATCH 3/7] update docstring

---
 torch_struct/distributions.py | 54 +++++++++++++++--------------------
 1 file changed, 23 insertions(+), 31 deletions(-)

diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py
index 3c66c41..a6444e7 100644
--- a/torch_struct/distributions.py
+++ b/torch_struct/distributions.py
@@ -61,11 +61,7 @@ def log_prob(self, value):
         d = value.dim()
         batch_dims = range(d - len(self.event_shape))
 
-        v = self._struct().score(
-            self.log_potentials,
-            value.type_as(self.log_potentials),
-            batch_dims=batch_dims,
-        )
+        v = self._struct().score(self.log_potentials, value.type_as(self.log_potentials), batch_dims=batch_dims,)
 
         return v - self.partition
 
@@ -91,9 +87,7 @@ def cross_entropy(self, other):
             cross entropy (*batch_shape*)
         """
 
-        return self._struct(CrossEntropySemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
 
     def kl(self, other):
         """
@@ -105,9 +99,7 @@ def kl(self, other):
         Returns:
             cross entropy (*batch_shape*)
         """
-        return self._struct(KLDivergenceSemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
 
     @lazy_property
     def max(self):
@@ -140,9 +132,7 @@ def kmax(self, k):
            kmax (*k x batch_shape*)
         """
         with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).sum(
-                self.log_potentials, self.lengths, _raw=True
-            )
+            return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True)
 
     def topk(self, k):
         r"""
@@ -155,9 +145,7 @@ def topk(self, k):
             kmax (*k x batch_shape x event_shape*)
         """
         with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).marginals(
-                self.log_potentials, self.lengths, _raw=True
-            )
+            return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True)
 
     @lazy_property
     def mode(self):
@@ -186,9 +174,7 @@ def count(self):
 
     def gumbel_crf(self, temperature=1.0):
         with torch.enable_grad():
-            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
-                self.log_potentials, self.lengths
-            )
+            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(self.log_potentials, self.lengths)
         return st_gumbel
 
     # @constraints.dependent_property
@@ -219,9 +205,7 @@ def sample(self, sample_shape=torch.Size()):
         samples = []
         for k in range(nsamples):
             if k % 10 == 0:
-                sample = self._struct(MultiSampledSemiring).marginals(
-                    self.log_potentials, lengths=self.lengths
-                )
+                sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths)
                 sample = sample.detach()
             tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
             samples.append(tmp_sample)
@@ -301,9 +285,7 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
         super().__init__(log_potentials, lengths)
 
     def _struct(self, sr=None):
-        return self.struct(
-            sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap
-        )
+        return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap)
 
 
 class HMM(StructDistribution):
@@ -440,9 +422,7 @@ def __init__(self, log_potentials, lengths=None):
         event_shape = log_potentials[0].shape[1:]
         self.log_potentials = log_potentials
         self.lengths = lengths
-        super(StructDistribution, self).__init__(
-            batch_shape=batch_shape, event_shape=event_shape
-        )
+        super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape)
 
 
 class NonProjectiveDependencyCRF(StructDistribution):
@@ -507,8 +487,20 @@ def entropy(self):
         """
         Compute entropy efficiently using arc-factorization property.
 
-        See implementation notebook [here](https://colab.research.google.com/drive/1iUr78J901lMBlGVYpxSrRRmNJYX4FWyg?usp=sharing)
-        for derivation.
+        Algorithm derivation:
+        .. math::
+            {{
+            \begin{align}
+            H[p] &= E_{p(T)}[-\log p(T)]\\
+            &= -E_{p(T)}\big[ \log [\frac{1}{Z} \prod\limits_{(i,j) \in T} \exp\{\phi_{i,j}\}] \big]\\
+            &= -E_{p(T)}\big[ \sum\limits_{(i,j) \in T} \phi_{i,j} - \log Z \big]\\
+            &= \log Z -E_{p(T)}\big[\sum\limits_{(i,j) \in A} 1\{(i,j) \in T\} \phi_{i,j}\big]\\
+            &= \log Z - \sum\limits_{(i,j) \in A} p\big((i,j) \in T\big) \phi_{i,j}
+            \end{align}
+            }}
+
+        Returns:
+            entropy (*batch_shape*)
         """
         logZ = self.partition
         p = self.marginals

From b3ea558914a280809e42407da0221c4331bf541b Mon Sep 17 00:00:00 2001
From: Sasha Rush
Date: Fri, 3 Sep 2021 09:19:23 -0400
Subject: [PATCH 4/7] Update distributions.py

---
 torch_struct/distributions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py
index a6444e7..aedffcf 100644
--- a/torch_struct/distributions.py
+++ b/torch_struct/distributions.py
@@ -484,7 +484,7 @@ def argmax(self):
 
     @lazy_property
     def entropy(self):
-        """
+        r"""
         Compute entropy efficiently using arc-factorization property.
 
         Algorithm derivation:

From 3e17a5036c746d481c4a7f341ce382cd87a1217e Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Tue, 14 Sep 2021 15:37:40 +0000
Subject: [PATCH 5/7] refactor central entropy, xent, kl using marginals
 instead of semiring

---
 torch_struct/distributions.py | 114 +++++++++++++++++++++-------------
 1 file changed, 71 insertions(+), 43 deletions(-)

diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py
index a6444e7..d17e5bf 100644
--- a/torch_struct/distributions.py
+++ b/torch_struct/distributions.py
@@ -10,9 +10,6 @@
 from .semirings import (
     LogSemiring,
     MaxSemiring,
-    EntropySemiring,
-    CrossEntropySemiring,
-    KLDivergenceSemiring,
     MultiSampledSemiring,
     KMaxSemiring,
     StdSemiring,
@@ -61,20 +58,39 @@ def log_prob(self, value):
         d = value.dim()
         batch_dims = range(d - len(self.event_shape))
 
-        v = self._struct().score(self.log_potentials, value.type_as(self.log_potentials), batch_dims=batch_dims,)
+        v = self._struct().score(
+            self.log_potentials,
+            value.type_as(self.log_potentials),
+            batch_dims=batch_dims,
+        )
 
         return v - self.partition
 
     @lazy_property
     def entropy(self):
         """
-        Compute entropy for distribution :math:`H[z]`.
+        Compute entropy for distribution :math:`H[p]`.
+
+        Algorithm derivation:
+        .. math::
+            {{
+            \begin{align}
+            H[p] &= E_{p(z)}[-\log p(z)]\\
+            &= -E_{p(z)}\big[ \log [\frac{1}{Z} \prod\limits_{c \in \mathcal{C}} \exp\{\phi_c(z_c)\}] \big]\\
+            &= -E_{p(z)}\big[ \sum\limits_{c \in \mathcal{C}} \phi_{c}(z_c) - \log Z \big]\\
+            &= \log Z -E_{p(z)}\big[\sum\limits_{c \in \mathcal{C}} \phi_{c}(z_c)\big]\\
+            &= \log Z - \sum\limits_{c \in \mathcal{C}} p(z_c) \phi_{c}(z_c)
+            \end{align}
+            }}
 
         Returns:
             entropy (*batch_shape*)
         """
-
-        return self._struct(EntropySemiring).sum(self.log_potentials, self.lengths)
+        logZ = self.partition
+        p = self.marginals
+        phi = self.log_potentials
+        Hz = logZ - (p * phi).reshape(p.shape[0], -1).sum(-1)
+        return Hz
 
     def cross_entropy(self, other):
         """
@@ -86,8 +102,11 @@ def cross_entropy(self, other):
         Returns:
             cross entropy (*batch_shape*)
         """
-
-        return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
+        logZ = other.partition
+        p = self.marginals
+        phi_q = other.log_potentials
+        Hq = logZ - (p * phi_q).reshape(p.shape[0], -1).sum(-1)
+        return Hq
 
     def kl(self, other):
         """
@@ -97,9 +116,15 @@ def kl(self, other):
             other : Comparison distribution
 
         Returns:
-            cross entropy (*batch_shape*)
+            kl divergence (*batch_shape*)
         """
-        return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
+        logZp = self.partition
+        logZq = other.partition
+        p = self.marginals
+        phi_p = self.log_potentials
+        phi_q = other.log_potentials
+        KLpq = (p * (phi_p - phi_q)).reshape(p.shape[0], -1).sum(-1) - logZp + logZq
+        return KLpq
 
     @lazy_property
     def max(self):
@@ -132,7 +157,9 @@ def kmax(self, k):
             kmax (*k x batch_shape*)
         """
         with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True)
+            return self._struct(KMaxSemiring(k)).sum(
+                self.log_potentials, self.lengths, _raw=True
+            )
 
     def topk(self, k):
         r"""
@@ -145,7 +172,9 @@ def topk(self, k):
             kmax (*k x batch_shape x event_shape*)
         """
         with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True)
+            return self._struct(KMaxSemiring(k)).marginals(
+                self.log_potentials, self.lengths, _raw=True
+            )
 
     @lazy_property
     def mode(self):
@@ -174,7 +203,9 @@ def count(self):
 
     def gumbel_crf(self, temperature=1.0):
         with torch.enable_grad():
-            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(self.log_potentials, self.lengths)
+            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
+                self.log_potentials, self.lengths
+            )
         return st_gumbel
 
     # @constraints.dependent_property
@@ -205,7 +236,9 @@ def sample(self, sample_shape=torch.Size()):
         samples = []
         for k in range(nsamples):
             if k % 10 == 0:
-                sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths)
+                sample = self._struct(MultiSampledSemiring).marginals(
+                    self.log_potentials, lengths=self.lengths
+                )
                 sample = sample.detach()
             tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
             samples.append(tmp_sample)
@@ -285,7 +318,9 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
         super().__init__(log_potentials, lengths)
 
     def _struct(self, sr=None):
-        return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap)
+        return self.struct(
+            sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap
+        )
 
 
 class HMM(StructDistribution):
@@ -422,7 +457,9 @@ def __init__(self, log_potentials, lengths=None):
         event_shape = log_potentials[0].shape[1:]
         self.log_potentials = log_potentials
         self.lengths = lengths
-        super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape)
+        super(StructDistribution, self).__init__(
+            batch_shape=batch_shape, event_shape=event_shape
+        )
 
 
 class NonProjectiveDependencyCRF(StructDistribution):
@@ -451,6 +488,23 @@ def __init__(self, log_potentials, lengths=None, args={}, multiroot=False):
         super(NonProjectiveDependencyCRF, self).__init__(log_potentials, lengths, args)
         self.multiroot = multiroot
 
+    def log_prob(self, value):
+        """
+        Compute log probability over values :math:`p(z)`.
+
+        Parameters:
+            value (tensor): One-hot events (*sample_shape x batch_shape x event_shape*)
+
+        Returns:
+            log_probs (*sample_shape x batch_shape*)
+        """
+        s = value.shape
+        # assumes values do not have any 1s outside of the lengths
+        value_total_log_potentials = (
+            (value * self.log_potentials.expand(s)).reshape(*s[:-2], -1).sum(-1)
+        )
+        return value_total_log_potentials - self.partition
+
     @lazy_property
     def marginals(self):
         """
@@ -481,29 +535,3 @@ def argmax(self):
         (Currently not implemented)
         """
         pass
-
-    @lazy_property
-    def entropy(self):
-        """
-        Compute entropy efficiently using arc-factorization property.
-
-        Algorithm derivation:
-        .. math::
-            {{
-            \begin{align}
-            H[p] &= E_{p(T)}[-\log p(T)]\\
-            &= -E_{p(T)}\big[ \log [\frac{1}{Z} \prod\limits_{(i,j) \in T} \exp\{\phi_{i,j}\}] \big]\\
-            &= -E_{p(T)}\big[ \sum\limits_{(i,j) \in T} \phi_{i,j} - \log Z \big]\\
-            &= \log Z -E_{p(T)}\big[\sum\limits_{(i,j) \in A} 1\{(i,j) \in T\} \phi_{i,j}\big]\\
-            &= \log Z - \sum\limits_{(i,j) \in A} p\big((i,j) \in T\big) \phi_{i,j}
-            \end{align}
-            }}
-
-        Returns:
-            entropy (*batch_shape*)
-        """
-        logZ = self.partition
-        p = self.marginals
-        phi = self.log_potentials
-        H = logZ - (p * phi).reshape(phi.shape[0], -1).sum(-1)
-        return H
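Because patch 5 computes entropy, cross entropy, and KL from the same marginal-weighted potentials, the identity KL(p || q) = H[p, q] - H[p] holds by construction and gives a quick consistency check on any pair of distributions with matching event shapes. A sketch using linear-chain CRFs (the distribution choice and sizes here are illustrative assumptions, not part of the patch):

    import torch
    from torch_struct import LinearChainCRF

    # two distributions over the same event space: batch 2, 6 tokens, 4 tags
    p = LinearChainCRF(torch.randn(2, 5, 4, 4))  # potentials are (batch, N - 1, C, C)
    q = LinearChainCRF(torch.randn(2, 5, 4, 4))

    # (logZq - sum(p * phi_q)) - (logZp - sum(p * phi_p)) matches the KLpq expression above
    gap = p.cross_entropy(q) - p.entropy - p.kl(q)
    print(gap.abs().max())  # ~0, up to floating-point error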
From d3fa0d617808844bd3306e8e2de9866609fbe204 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Tue, 14 Sep 2021 17:41:06 +0000
Subject: [PATCH 6/7] attempt to fix docstr lint error and derivation

---
 torch_struct/distributions.py | 29 ++++++++++++++++++-----------
 1 file changed, 18 insertions(+), 11 deletions(-)

diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py
index 30ac4a7..c1b5930 100644
--- a/torch_struct/distributions.py
+++ b/torch_struct/distributions.py
@@ -69,20 +69,16 @@ def log_prob(self, value):
 
     @lazy_property
     def entropy(self):
-        """
+        r"""
         Compute entropy for distribution :math:`H[p]`.
 
         Algorithm derivation:
         .. math::
-            {{
-            \begin{align}
             H[p] &= E_{p(z)}[-\log p(z)]\\
-            &= -E_{p(z)}\big[ \log [\frac{1}{Z} \prod\limits_{c \in \mathcal{C}} \exp\{\phi_c(z_c)\}] \big]\\
-            &= -E_{p(z)}\big[ \sum\limits_{c \in \mathcal{C}} \phi_{c}(z_c) - \log Z \big]\\
-            &= \log Z -E_{p(z)}\big[\sum\limits_{c \in \mathcal{C}} \phi_{c}(z_c)\big]\\
-            &= \log Z - \sum\limits_{c \in \mathcal{C}} p(z_c) \phi_{c}(z_c)
+            &= -E_{p(z)}\big[ \log [\frac{1}{Z} \prod\limits_{c \in \mathcal{C}} \exp\{\phi_c(z_c)\}] \big]\\
+            &= -E_{p(z)}\big[ \sum\limits_{c \in \mathcal{C}} \phi_{c}(z_c) - \log Z \big]\\
+            &= \log Z -E_{p(z)}\big[\sum\limits_{c \in \mathcal{C}} \phi_{c}(z_c)\big]\\
+            &= \log Z - \sum\limits_{c \in \mathcal{C}} p(z_c) \phi_{c}(z_c)
 
         Returns:
             entropy (*batch_shape*)
@@ -82,8 +82,8 @@ def entropy(self):
         """
         logZ = self.partition
         p = self.marginals
         phi = self.log_potentials
-        Hz = logZ - (p * phi).reshape(p.shape[0], -1).sum(-1)
-        return Hz
+        Hp = logZ - (p * phi).reshape(p.shape[0], -1).sum(-1)
+        return Hp
 
     def cross_entropy(self, other):
@@ -532,3 +532,14 @@ def argmax(self):
         (Currently not implemented)
         """
         pass
+
+    def rsample(self, temp=1.0):
+        """ Do a marginal reparameterization sample to get a soft approximate sample.
+        """
+        noise = (torch.distributions.Gumbel(0, 1).sample(self.log_potentials.shape)).to(
+            self.log_potentials.device
+        )
+        noised_log_potentials = (self.log_potentials + noise) / temp
+        return NonProjectiveDependencyCRF(
+            noised_log_potentials, self.lengths, multiroot=self.multiroot
+        ).marginals

From e2cdc415f5313d76456eb40c9a9f065e558c5717 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Tue, 14 Sep 2021 17:44:50 +0000
Subject: [PATCH 7/7] attempt to fix docstr lint error and derivation

---
 torch_struct/distributions.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py
index c1b5930..aa85cfb 100644
--- a/torch_struct/distributions.py
+++ b/torch_struct/distributions.py
@@ -532,14 +532,3 @@ def argmax(self):
         (Currently not implemented)
         """
         pass
-
-    def rsample(self, temp=1.0):
-        """ Do a marginal reparameterization sample to get a soft approximate sample.
-        """
-        noise = (torch.distributions.Gumbel(0, 1).sample(self.log_potentials.shape)).to(
-            self.log_potentials.device
-        )
-        noised_log_potentials = (self.log_potentials + noise) / temp
-        return NonProjectiveDependencyCRF(
-            noised_log_potentials, self.lengths, multiroot=self.multiroot
-        ).marginals