author     Benedek Rozemberczki <benedek.rozemberczki@gmail.com>  2022-09-04 15:50:35 +0100
committer  GitHub <noreply@github.com>                            2022-09-04 15:50:35 +0100
commit     9d4a7add6d87839d6f0da59e913ebfb82612b2d2 (patch)
tree       9561092f42695f7866ca42146964edcb24cc66b2
parent     57006a47b3f0286343559cbc2496cbff263b941f (diff)
parent     cbb8ba51c3f7485450f3e50858228e530da66ca0 (diff)
Merge pull request #187 from h3dema/patch-1
Change to _set_hidden_state()
 torch_geometric_temporal/nn/recurrent/temporalgcn.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/torch_geometric_temporal/nn/recurrent/temporalgcn.py b/torch_geometric_temporal/nn/recurrent/temporalgcn.py
index d84c81c..f9eaa34 100644
--- a/torch_geometric_temporal/nn/recurrent/temporalgcn.py
+++ b/torch_geometric_temporal/nn/recurrent/temporalgcn.py
@@ -144,7 +144,9 @@ class TGCN2(torch.nn.Module):
         add_self_loops (bool): Adding self-loops for smoothing. Default is True.
     """
 
-    def __init__(self, in_channels: int, out_channels: int, batch_size: int, improved: bool = False, cached: bool = False,
+    def __init__(self, in_channels: int, out_channels: int,
+                 batch_size: int,  # this entry is unnecessary, kept only for backward compatibility
+                 improved: bool = False, cached: bool = False,
                  add_self_loops: bool = True):
         super(TGCN2, self).__init__()
@@ -153,7 +155,7 @@ class TGCN2(torch.nn.Module):
         self.improved = improved
         self.cached = cached
         self.add_self_loops = add_self_loops
-        self.batch_size = batch_size
+        self.batch_size = batch_size  # not needed, kept only for backward compatibility
         self._create_parameters_and_layers()
 
     def _create_update_gate_parameters_and_layers(self):
@@ -178,7 +180,8 @@
     def _set_hidden_state(self, X, H):
         if H is None:
-            H = torch.zeros(self.batch_size, X.shape[1], self.out_channels).to(X.device)  # (b, 207, 32)
+            # can infer batch_size from X.shape, because X is [B, N, F]
+            H = torch.zeros(X.shape[0], X.shape[1], self.out_channels).to(X.device)  # (b, 207, 32)
         return H
 
     def _calculate_update_gate(self, X, edge_index, edge_weight, H):
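
For context, a minimal standalone sketch of the patched logic (not part of the diff above). It assumes the input X is shaped [batch_size, num_nodes, in_features], as noted in the added comment, so the batch dimension can be read from X.shape[0] rather than from the stored self.batch_size. The helper name set_hidden_state and the sizes used here are illustrative only, not the library API.

import torch

def set_hidden_state(X: torch.Tensor, H, out_channels: int) -> torch.Tensor:
    if H is None:
        # The batch dimension comes from X itself, so a batch size fixed
        # at construction time is no longer required.
        H = torch.zeros(X.shape[0], X.shape[1], out_channels, device=X.device)
    return H

X = torch.randn(8, 207, 2)          # a batch of 8 graphs with 207 nodes and 2 features each
H = set_hidden_state(X, None, 32)   # -> torch.Size([8, 207, 32])

Because the hidden state is sized from the incoming batch, a model built with one nominal batch_size can still process batches of a different size (for example, a smaller final batch).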