Commit 00066390 authored by Jean-Marie Lepioufle's avatar Jean-Marie Lepioufle
Browse files

typo

parent 461d5caf
......@@ -77,8 +77,8 @@ class mlp_cnn(torch.nn.Module):
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2, stride=2))
def cnn_block(self,x):
out = self.cnn_layer1(x)
def cnn_block(self,x_img):
out = self.cnn_layer1(x_img)
out = self.cnn_layer2(out)
out = self.cnn_layer3(out)
out = self.cnn_layer4(out)
......@@ -94,6 +94,7 @@ class mlp_cnn(torch.nn.Module):
return out
def forward(self, x_f,x_img):
# features x[0]; images x[1]
x_f = self.mlp(x_f)
x_img = self.cnn_block(x_img)
x_img = torch.flatten(x_img,start_dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment