DarthReca commited on
Commit
0a07d7e
·
verified ·
1 Parent(s): b7aafb6

Update modeling_actu.py

Browse files
Files changed (1) hide show
  1. modeling_actu.py +271 -48
modeling_actu.py CHANGED
@@ -1,89 +1,203 @@
 
 
1
  import numpy as np
2
  import timm
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
- from einops import rearrange
7
  from segmentation_models_pytorch.base import SegmentationHead
8
  from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
9
  from timm.layers.create_act import create_act_layer
 
 
10
 
11
  from .convlstm import ConvLSTM
12
 
13
 
14
- class ACTU(nn.Module):
 
 
15
  def __init__(
16
  self,
17
- in_channels,
18
- kernel_size,
19
- padding,
20
- stride,
21
- backbone: str,
 
22
  bias=True,
23
  batch_first=True,
24
  bidirectional=False,
25
  original_resolution=(256, 256),
26
- act_layer: str = "sigmoid",
27
- n_classes: int = 1,
 
 
 
 
 
 
 
 
28
  **kwargs,
29
  ):
30
- super(ACTU, self).__init__()
31
- self.n_classes = n_classes
 
 
 
32
  self.backbone = backbone
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  self.encoder: nn.Module = timm.create_model(
35
- backbone, features_only=True, in_chans=in_channels
36
  )
37
 
38
  with torch.no_grad():
39
- embs = self.encoder.forward(
40
- torch.randn(1, in_channels, *original_resolution)
 
41
  )
42
- embs_shape = [e.shape for e in embs]
 
 
43
 
44
- # The ConvLSTM expects inputs of shape (B, T, feature_dim, H_enc, W_enc)
45
- # We assume the provided ConvLSTM code is available.
46
  self.convlstm = nn.ModuleList(
47
- ConvLSTM(
48
- in_channels=shape[1],
49
- hidden_channels=shape[1],
50
- kernel_size=kernel_size,
51
- padding=padding,
52
- stride=stride,
53
- bias=bias,
54
- batch_first=batch_first,
55
- bidirectional=bidirectional,
56
- )
57
- for shape in embs_shape
 
 
58
  )
59
- # If bidirectional, the hidden representation is concatenated from both directions.
60
- n_upsamples = int(np.log2(original_resolution[0] / embs_shape[-1][-2]))
61
- skip_channels_list = [shape[1] for shape in embs_shape[-(n_upsamples + 1) : -1]]
62
- skip_channels_list = skip_channels_list[::-1] # Reverse the list.
63
- encoder_channels = [e[1] for e in embs_shape]
 
 
 
 
 
 
 
64
 
65
  self.decoder = UnetDecoder(
66
- encoder_channels=[1, *encoder_channels],
67
- decoder_channels=encoder_channels[::-1],
68
- n_blocks=len(encoder_channels),
69
  )
 
70
  self.seg_head = nn.Sequential(
71
  SegmentationHead(
72
- in_channels=encoder_channels[0],
73
- out_channels=n_classes,
74
  ),
75
- create_act_layer(act_layer, inplace=True),
76
  )
77
- self.encoder_channels = encoder_channels
78
- self.embs_shape = embs_shape
79
 
80
- def forward(self, x: torch.Tensor, **kwargs):
81
- size = x.size()[-2:]
82
- # Process each time step through the encoder.
83
- x = self._encode_images(x)
84
- # Pass the encoded sequence through the ConvLSTM.
85
- x = self._encode_timeseries(x)
86
- return self._decode(x, size=size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  def _encode_images(self, x: torch.Tensor) -> list[torch.Tensor]:
89
  B = x.size(0)
@@ -107,3 +221,112 @@ class ACTU(nn.Module):
107
  trend_map, size=size, mode="bilinear", align_corners=False
108
  )
109
  return trend_map
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
  import numpy as np
4
  import timm
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
  from segmentation_models_pytorch.base import SegmentationHead
10
  from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
11
  from timm.layers.create_act import create_act_layer
12
+ from transformers import PretrainedConfig, PreTrainedModel
13
+ from transformers.modeling_outputs import SemanticSegmenterOutput
14
 
15
  from .convlstm import ConvLSTM
16
 
17
 
18
+ class ACTUConfig(PretrainedConfig):
19
+ model_type = "actu"
20
+
21
  def __init__(
22
  self,
23
+ # Base ACTU parameters
24
+ in_channels: int = 3,
25
+ kernel_size: tuple[int, int] = (3, 3),
26
+ padding="same",
27
+ stride=(1, 1),
28
+ backbone="resnet34",
29
  bias=True,
30
  batch_first=True,
31
  bidirectional=False,
32
  original_resolution=(256, 256),
33
+ act_layer="sigmoid",
34
+ n_classes=1,
35
+ # Variant control parameters
36
+ use_dem_input: bool = False,
37
+ use_climate_branch: bool = False,
38
+ # Climate branch parameters
39
+ climate_seq_len=5,
40
+ climate_input_dim=6,
41
+ lstm_hidden_dim=128,
42
+ num_lstm_layers=1,
43
  **kwargs,
44
  ):
45
+ super().__init__(**kwargs)
46
+ self.in_channels = in_channels
47
+ self.kernel_size = kernel_size
48
+ self.padding = padding
49
+ self.stride = stride
50
  self.backbone = backbone
51
+ self.bias = bias
52
+ self.batch_first = batch_first
53
+ self.bidirectional = bidirectional
54
+ self.original_resolution = original_resolution
55
+ self.act_layer = act_layer
56
+ self.n_classes = n_classes
57
+
58
+ # Parameters to control variants
59
+ self.use_dem_input = use_dem_input
60
+ self.use_climate_branch = use_climate_branch
61
+ self.climate_seq_len = climate_seq_len
62
+ self.climate_input_dim = climate_input_dim
63
+ self.lstm_hidden_dim = lstm_hidden_dim
64
+ self.num_lstm_layers = num_lstm_layers
65
+
66
+ # Adjust in_channels if DEM is used
67
+ if self.use_dem_input:
68
+ self.in_channels += 1
69
+
70
+
71
+ class ACTUForImageSegmentation(PreTrainedModel):
72
+ config_class = ACTUConfig
73
+
74
+ def __init__(self, config: ACTUConfig):
75
+ super().__init__(config)
76
+ self.config = config
77
 
78
  self.encoder: nn.Module = timm.create_model(
79
+ config.backbone, features_only=True, in_chans=config.in_channels
80
  )
81
 
82
  with torch.no_grad():
83
+ dummy_input_channels = config.in_channels
84
+ dummy_input = torch.randn(
85
+ 1, dummy_input_channels, *config.original_resolution, device=self.device
86
  )
87
+ embs = self.encoder(dummy_input)
88
+ self.embs_shape = [e.shape for e in embs]
89
+ self.encoder_channels = [e[1] for e in self.embs_shape]
90
 
 
 
91
  self.convlstm = nn.ModuleList(
92
+ [
93
+ ConvLSTM(
94
+ in_channels=shape[1],
95
+ hidden_channels=shape[1],
96
+ kernel_size=config.kernel_size,
97
+ padding=config.padding,
98
+ stride=config.stride,
99
+ bias=config.bias,
100
+ batch_first=config.batch_first,
101
+ bidirectional=config.bidirectional,
102
+ )
103
+ for shape in self.embs_shape
104
+ ]
105
  )
106
+
107
+ if self.config.use_climate_branch:
108
+ self.climate_branch = ClimateBranchLSTM(
109
+ output_shapes=[e[1:] for e in self.embs_shape],
110
+ lstm_hidden_dim=config.lstm_hidden_dim,
111
+ climate_seq_len=config.climate_seq_len,
112
+ climate_input_dim=config.climate_input_dim,
113
+ num_lstm_layers=config.num_lstm_layers,
114
+ )
115
+ self.fusers = nn.ModuleList(
116
+ GatedFusion(enc, enc) for enc in self.encoder_channels
117
+ )
118
 
119
  self.decoder = UnetDecoder(
120
+ encoder_channels=[1] + self.encoder_channels,
121
+ decoder_channels=self.encoder_channels[::-1],
122
+ n_blocks=len(self.encoder_channels),
123
  )
124
+
125
  self.seg_head = nn.Sequential(
126
  SegmentationHead(
127
+ in_channels=self.encoder_channels[0],
128
+ out_channels=config.n_classes,
129
  ),
130
+ create_act_layer(config.act_layer, inplace=True),
131
  )
 
 
132
 
133
+ def forward(
134
+ self,
135
+ pixel_values: torch.Tensor,
136
+ climate: torch.Tensor = None,
137
+ dem: torch.Tensor = None,
138
+ labels: torch.Tensor = None,
139
+ **kwargs,
140
+ ) -> SemanticSegmenterOutput:
141
+ b, t = pixel_values.shape[:2]
142
+ original_size = pixel_values.shape[-2:]
143
+
144
+ # Handle DEM input
145
+ if self.config.use_dem_input:
146
+ if dem is None:
147
+ raise ValueError(
148
+ "DEM tensor must be provided when use_dem_input is True."
149
+ )
150
+ dem_repeated = repeat(dem, "b c h w -> b t c h w", t=t)
151
+ pixel_values = torch.cat([pixel_values, dem_repeated], dim=2)
152
+
153
+ # 1. Encode images per time step
154
+ encoded_sequence = self._encode_images(pixel_values)
155
+
156
+ # 2. Handle Climate Branch Fusion
157
+ if self.config.use_climate_branch:
158
+ if climate is None:
159
+ raise ValueError(
160
+ "Climate tensor must be provided when use_climate_branch is True."
161
+ )
162
+
163
+ climate_features = self.climate_branch(climate)
164
+
165
+ # Reshape for fusion
166
+ encoded_sequence_reshaped = [
167
+ rearrange(f, "b t c h w -> (b t) c h w") for f in encoded_sequence
168
+ ]
169
+ climate_features_reshaped = [
170
+ rearrange(f, "b t c h w -> (b t) c h w") for f in climate_features
171
+ ]
172
+
173
+ # Fuse features
174
+ fused_features = [
175
+ fuser(img, clim)
176
+ for fuser, img, clim in zip(
177
+ self.fusers, encoded_sequence_reshaped, climate_features_reshaped
178
+ )
179
+ ]
180
+
181
+ # Reshape back to sequence
182
+ encoded_sequence = [
183
+ rearrange(f, "(b t) c h w -> b t c h w", b=b) for f in fused_features
184
+ ]
185
+
186
+ # 3. Process sequence with ConvLSTM
187
+ temporal_features = self._encode_timeseries(encoded_sequence)
188
+
189
+ # 4. Decode to get the segmentation map
190
+ logits = self._decode(temporal_features, size=original_size)
191
+
192
+ loss = None
193
+ if labels is not None:
194
+ loss_fct = nn.CrossEntropyLoss()
195
+ loss = loss_fct(logits, labels.float().unsqueeze(1))
196
+
197
+ return SemanticSegmenterOutput(
198
+ loss=loss,
199
+ logits=logits,
200
+ )
201
 
202
  def _encode_images(self, x: torch.Tensor) -> list[torch.Tensor]:
203
  B = x.size(0)
 
221
  trend_map, size=size, mode="bilinear", align_corners=False
222
  )
223
  return trend_map
224
+
225
+
226
+ class ClimateBranchLSTM(nn.Module):
227
+ """
228
+ Processes climate time series data using an LSTM.
229
+ Input shape: (B, T, T_1, C_clim) -> e.g., (B, 5, 6, 5)
230
+ Output shape: (B, T, output_dim) -> e.g., (B, 5, 128)
231
+ """
232
+
233
+ def __init__(
234
+ self,
235
+ output_shapes: list[tuple[int, int, int]],
236
+ climate_input_dim=5,
237
+ climate_seq_len=6,
238
+ lstm_hidden_dim=64,
239
+ num_lstm_layers=1,
240
+ ):
241
+ super().__init__()
242
+ self.climate_seq_len = climate_seq_len
243
+ self.climate_input_dim = climate_input_dim
244
+ self.lstm_hidden_dim = lstm_hidden_dim
245
+ self.num_lstm_layers = num_lstm_layers
246
+ self.proj_dim = 128
247
+ self.output_shapes = output_shapes
248
+
249
+ self.lstm = nn.LSTM(
250
+ input_size=climate_input_dim,
251
+ hidden_size=lstm_hidden_dim,
252
+ num_layers=num_lstm_layers,
253
+ batch_first=True, # Crucial: expects input shape (batch, seq_len, features)
254
+ dropout=0.3 if num_lstm_layers > 1 else 0,
255
+ bidirectional=False,
256
+ )
257
+
258
+ # Linear layer to project LSTM output to the desired final dimension
259
+ self.fc = nn.Linear(lstm_hidden_dim, self.proj_dim)
260
+
261
+ self.upsamples = nn.ModuleList(
262
+ _build_upsampler(self.proj_dim, *shape[:2]) for shape in output_shapes
263
+ )
264
+
265
+ def forward(self, climate_data: torch.Tensor) -> list[torch.Tensor]:
266
+ # climate_data shape: (B, T, T_1, C_clim), e.g., (B, 5, 6, 5)
267
+ B_img, B_cli, T, C = climate_data.shape
268
+
269
+ # Reshape for LSTM: Treat each sequence independently
270
+ lstm_input = rearrange(climate_data, "Bi Bc T C -> (Bi Bc) T C")
271
+
272
+ # Pass through LSTM
273
+ _, (hidden, _) = self.lstm.forward(lstm_input)
274
+ # Get the last layer's hidden state
275
+ last_hidden = (
276
+ hidden[[hidden.size(0) // 2, -1]] if self.lstm.bidirectional else hidden[-1]
277
+ )
278
+ if last_hidden.ndim == 3:
279
+ last_hidden = hidden.mean(dim=0)
280
+
281
+ # Pass the final hidden state through the fully connected layer(s) and upsample
282
+ climate_features = self.fc(last_hidden)
283
+ climate_features = rearrange(climate_features, "b c -> b c 1 1")
284
+ climate_features = [
285
+ rearrange(
286
+ u(climate_features), "(Bi Bc) C H W -> Bi Bc C H W", Bi=B_img, Bc=B_cli
287
+ )
288
+ for u in self.upsamples
289
+ ]
290
+
291
+ return climate_features
292
+
293
+
294
+ class GatedFusion(nn.Module):
295
+ def __init__(self, img_channels, clim_channels):
296
+ super().__init__()
297
+ self.gate = nn.Sequential(
298
+ nn.Sequential(
299
+ nn.Conv2d(
300
+ img_channels + clim_channels, img_channels, kernel_size=3, padding=1
301
+ ),
302
+ nn.ReLU(inplace=True),
303
+ nn.Conv2d(img_channels, img_channels, kernel_size=1),
304
+ nn.Sigmoid(), # Gate values between 0 and 1
305
+ )
306
+ )
307
+
308
+ def forward(self, img_feat, clim_feat):
309
+ gate = self.gate(torch.cat([img_feat, clim_feat], dim=1))
310
+ return gate * img_feat + (1 - gate) * clim_feat
311
+
312
+
313
+ def _build_upsampler(
314
+ in_channels: int, target_channels: int, target_h: int
315
+ ) -> nn.Sequential:
316
+ layers = []
317
+ current_h = 1
318
+
319
+ # Expand to target channels early (e.g., 1x1 → 1x1 with target_channels)
320
+ layers += [nn.Conv2d(in_channels, target_channels, kernel_size=1), nn.GELU()]
321
+
322
+ # Upsample spatially to target_h
323
+ while current_h < target_h:
324
+ next_h = min(current_h * 2, target_h)
325
+ layers += [
326
+ nn.Upsample(scale_factor=2, mode="nearest"),
327
+ nn.Conv2d(target_channels, target_channels, kernel_size=3, padding=1),
328
+ nn.GELU(),
329
+ ]
330
+ current_h = next_h
331
+
332
+ return nn.Sequential(*layers)