diff options
author | Benedek Rozemberczki <benedek.rozemberczki@gmail.com> | 2022-09-04 15:50:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-04 15:50:35 +0100 |
commit | 9d4a7add6d87839d6f0da59e913ebfb82612b2d2 (patch) | |
tree | 9561092f42695f7866ca42146964edcb24cc66b2 | |
parent | 57006a47b3f0286343559cbc2496cbff263b941f (diff) | |
parent | cbb8ba51c3f7485450f3e50858228e530da66ca0 (diff) | |
download | pytorch_geometric_temporal-9d4a7add6d87839d6f0da59e913ebfb82612b2d2.tar.gz pytorch_geometric_temporal-9d4a7add6d87839d6f0da59e913ebfb82612b2d2.tar.xz |
Merge pull request #187 from h3dema/patch-1
change on _set_hidden_state()
-rw-r--r-- | torch_geometric_temporal/nn/recurrent/temporalgcn.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/torch_geometric_temporal/nn/recurrent/temporalgcn.py b/torch_geometric_temporal/nn/recurrent/temporalgcn.py index d84c81c..f9eaa34 100644 --- a/torch_geometric_temporal/nn/recurrent/temporalgcn.py +++ b/torch_geometric_temporal/nn/recurrent/temporalgcn.py @@ -144,7 +144,9 @@ class TGCN2(torch.nn.Module): add_self_loops (bool): Adding self-loops for smoothing. Default is True. """ - def __init__(self, in_channels: int, out_channels: int, batch_size: int, improved: bool = False, cached: bool = False, + def __init__(self, in_channels: int, out_channels: int, + batch_size: int, # this entry is unnecessary, kept only for backward compatibility + improved: bool = False, cached: bool = False, add_self_loops: bool = True): super(TGCN2, self).__init__() @@ -153,7 +155,7 @@ class TGCN2(torch.nn.Module): self.improved = improved self.cached = cached self.add_self_loops = add_self_loops - self.batch_size = batch_size + self.batch_size = batch_size # not needed self._create_parameters_and_layers() def _create_update_gate_parameters_and_layers(self): @@ -178,7 +180,8 @@ class TGCN2(torch.nn.Module): def _set_hidden_state(self, X, H): if H is None: - H = torch.zeros(self.batch_size,X.shape[1], self.out_channels).to(X.device) #(b, 207, 32) + # can infer batch_size from X.shape, because X is [B, N, F] + H = torch.zeros(X.shape[0], X.shape[1], self.out_channels).to(X.device) #(b, 207, 32) return H def _calculate_update_gate(self, X, edge_index, edge_weight, H): |