@@ -4469,6 +4469,7 @@ def _get_same_padding_1d(input_size, kernel_size, stride):
44694469 pad_left, pad_right = _get_same_padding_1d(input_size, kernel_size, stride)
44704470 padding = pad_left # 对称填充
44714471 if pad_left != pad_right: # 非对称填充
4472+ # TODO(zrr1999) maybe mode="replicate"
44724473 x = torch.nn.functional.pad(x, (pad_left, pad_right))
44734474 padding = 0
44744475elif isinstance(padding, (list, tuple)):
@@ -4508,7 +4509,7 @@ def _get_same_padding_2d(input_size, kernel_size, stride):
45084509 pad_h, pad_w = _get_same_padding_2d(input_size, kernel_size, stride)
45094510 padding = (pad_h[0], pad_w[0]) # 对称填充
45104511 if pad_h[0] != pad_h[1] or pad_w[0] != pad_w[1]: # 非对称填充
4511- x = torch.nn.functional.pad(x, (pad_w[0], pad_w[1], pad_h[0], pad_h[1]))
4512+ x = torch.nn.functional.pad(x, (pad_w[0], pad_w[1], pad_h[0], pad_h[1]), mode="replicate" )
45124513 padding = 0
45134514elif isinstance(padding, (list, tuple)):
45144515 if len(padding) == 2: # [pad_height, pad_width]
@@ -4566,6 +4567,7 @@ def _get_same_padding_3d(input_size, kernel_size, stride):
45664567 pad_d, pad_h, pad_w = _get_same_padding_3d(input_size, kernel_size, stride)
45674568 padding = (pad_d[0], pad_h[0], pad_w[0]) # 对称填充
45684569 if pad_d[0] != pad_d[1] or pad_h[0] != pad_h[1] or pad_w[0] != pad_w[1]: # 非对称填充
4570+ # TODO(zrr1999) maybe mode="replicate"
45694571 x = torch.nn.functional.pad(x, (pad_w[0], pad_w[1], pad_h[0], pad_h[1], pad_d[0], pad_d[1]))
45704572 padding = 0
45714573elif isinstance(padding, (list, tuple)):
0 commit comments