@@ -198,209 +198,3 @@ def forward(self, x: torch.Tensor, style: torch.Tensor = None) -> torch.Tensor:
198198 x = x + identity
199199
200200 return x
201-
202-
203- ##############
204- ##############
205- ##############
206- # from typing import List, Tuple
207-
208- # import torch
209- # import torch.nn as nn
210-
211- # from .conv_block import ConvBlock
212- # from .misc_modules import ChannelPool
213-
214- # __all__ = ["ConvLayer"]
215-
216-
217- # class ConvLayer(nn.Module):
218- # def __init__(
219- # self,
220- # in_channels: int,
221- # out_channels: int,
222- # n_blocks: int = 2,
223- # layer_residual: bool = False,
224- # short_skip: str = "residual",
225- # style_channels: int = None,
226- # expand_ratios: Tuple[float, ...] = (1.0, 1.0),
227- # block_types: Tuple[str, ...] = ("basic", "basic"),
228- # normalizations: Tuple[str, ...] = ("bn", "bn"),
229- # activations: Tuple[str, ...] = ("relu", "relu"),
230- # convolutions: Tuple[str, ...] = ("conv", "conv"),
231- # kernel_sizes: Tuple[int, ...] = (3, 3),
232- # groups: Tuple[int, ...] = (1, 1),
233- # biases: Tuple[bool, ...] = (True, True),
234- # preactivates: Tuple[bool, ...] = (False, False),
235- # attentions: Tuple[str, ...] = (None, None),
236- # preattends: Tuple[bool, ...] = (False, False),
237- # use_styles: Tuple[bool, ...] = (False, False),
238- # **kwargs,
239- # ) -> None:
240- # """Chain conv-blocks in a ModuleDict to compose a full layer.
241-
242- # Optional:
243- # - add a style vector to the output at the end of each conv block(Cellpose)
244-
245- # Parameters
246- # ----------
247- # in_channels : int
248- # Number of input channels.
249- # out_channels : int
250- # Number of output channels.
251- # n_blocks : int, default=2
252- # Number of ConvBlocks used in this layer.
253- # layer_residual : bool, default=False
254- # Apply a layer level residual skip. I.e x + layer(x). NOTE: residual
255- # skips can be also applied inside the ConvBlocks, so this is justextra.
256- # style_channels : int, default=None
257- # Number of style vector channels. If None, style vectors are ignored.
258- # short_skip : str, default="residual"
259- # The name of the short skip method. One of: "residual", "dense","basic"
260- # expand_ratios : Tuple[float, ...], default=(1.0, 1.0):
261- # Expansion/Squeeze ratios for the out channels of each conv block.
262- # block_types : Tuple[str, ...], default=("basic", "basic")
263- # The name of the conv-blocks. Length of the tuple has toequal`n_blocks`
264- # One of: "basic". "mbconv", "fmbconv" "dws", "bottleneck".
265- # normalizations : Tuple[str, ...], default=("bn", "bn"):
266- # Normalization methods. One of: "bn", "bcn", "gn", "in", "ln", "lrn"
267- # activations : Tuple[str, ...], default=("relu", "relu")
268- # Activation methods. One of: "mish", "swish", "relu", "relu6", "rrelu",
269- # "selu", "celu", "gelu", "glu", "tanh", "sigmoid", "silu", "prelu",
270- # "leaky-relu", "elu", "hardshrink", "tanhshrink", "hardsigmoid"
271- # convolutions : Tuple[str, ...], default=("conv", "conv")
272- # The convolution method. One of: "conv", "wsconv", "scaled_wsconv"
273- # preactivates : Tuple[bool, ...], default=(False, False)
274- # Pre-activations flags for the conv-blocks.
275- # kernel_sizes : Tuple[int, ...], default=(3, 3)
276- # The size of the convolution kernels in each conv block.
277- # groups : int, default=(1, 1)
278- # Number of groups for the kernels in each convolution blocks.
279- # biases : Tuple[bool, ...], default=(True, True)
280- # Include bias terms in the convolution blocks.
281- # attentions : Tuple[str, ...], default=(None, None)
282- # Attention methods. One of: "se", "scse", "gc", "eca", None
283- # preattends : Tuple[bool, ...], default=(False, False)
284- # If True, Attention is applied at the beginning of forward pass.
285- # use_styles : Tuple[bool, ...], default=(False, False)
286- # If True and `style_channels` is not None, adds a style vec to the
287- # ConvBlock outputs.
288-
289- # Raises
290- # ------
291- # ValueError:
292- # If lengths of the tuple arguments are not equal to `n_blocks`.
293- # """
294- # super().__init__()
295- # self.layer_residual = layer_residual
296- # self.short_skip = short_skip
297- # self.in_channels = in_channels
298-
299- # illegal_args = [
300- # (k, a)
301- # for k, a in locals().items()
302- # if isinstance(a, tuple) and len(a) != n_blocks
303- # ]
304-
305- # if illegal_args:
306- # raise ValueError(
307- # f"All the tuple-arg lengths need to be equalto`n_blocks`={n_blocks}. "
308- # f"Illegal args: {illegal_args}"
309- # )
310-
311- # self.conv_blocks = nn.ModuleDict()
312- # blocks = list(range(n_blocks))
313- # for i in blocks:
314- # out = int(out_channels * expand_ratios[i])
315-
316- # conv_block = ConvBlock(
317- # name=block_types[i],
318- # in_channels=in_channels,
319- # out_channels=out,
320- # style_channels=style_channels,
321- # short_skip=short_skip,
322- # kernel_size=kernel_sizes[i],
323- # groups=groups[i],
324- # bias=biases[i],
325- # normalization=normalizations[i],
326- # convolution=convolutions[i],
327- # activation=activations[i],
328- # attention=attentions[i],
329- # preactivate=preactivates[i],
330- # preattend=preattends[i],
331- # use_style=use_styles[i],
332- # **kwargs,
333- # )
334- # self.conv_blocks[f"{short_skip}_{block_types[i]}_{i + 1}"] = conv_block
335-
336- # if short_skip == "dense":
337- # in_channels += conv_block.out_channels
338- # else:
339- # in_channels = conv_block.out_channels
340-
341- # self.out_channels = conv_block.out_channels
342-
343- # if short_skip == "dense":
344- # self.transition = ConvBlock(
345- # name="basic",
346- # in_channels=in_channels,
347- # short_skip="basic",
348- # out_channels=out_channels,
349- # same_padding=False,
350- # bias=False,
351- # kernel_size=1,
352- # convolution=conv_block.block.conv_choice,
353- # normalization=normalizations[-1],
354- # activation=activations[-1],
355- # preactivate=preactivates[-1],
356- # )
357- # self.out_channels = self.transition.out_channels
358-
359- # self.downsample = None
360- # if layer_residual and self.in_channels != self.out_channels:
361- # self.downsample = ChannelPool(
362- # in_channels=self.in_channels,
363- # out_channels=self.out_channels,
364- # convolution=convolutions[-1],
365- # normalization=normalizations[-1],
366- # )
367-
368- # def forward_features_dense(
369- # self, init_features: List[torch.Tensor], style: torch.Tensor = None
370- # ) -> torch.Tensor:
371- # """Dense forward pass."""
372- # features = [init_features]
373- # for conv_block in self.conv_blocks.values():
374- # new_features = conv_block(features, style)
375- # features.append(new_features)
376-
377- # x = torch.cat(features, 1)
378- # x = self.transition(x)
379-
380- # return x
381-
382- # def forward_features(
383- # self, x: torch.Tensor, style: torch.Tensor = None
384- # ) -> torch.Tensor:
385- # """Regular forward pass."""
386- # for conv_block in self.conv_blocks.values():
387- # x = conv_block(x, style)
388-
389- # return x
390-
391- # def forward(self, x: torch.Tensor, style: torch.Tensor = None) -> torch.Tensor:
392- # """Forward pass of the conv-layer."""
393- # if self.layer_residual:
394- # identity = x
395- # if self.downsample is not None:
396- # identity = self.downsample(x)
397-
398- # if self.short_skip == "dense":
399- # x = self.forward_features_dense(x, style)
400- # else:
401- # x = self.forward_features(x, style)
402-
403- # if self.layer_residual:
404- # x = x + identity
405-
406- # return x
0 commit comments