Skip to content

Commit f9035d9

Browse files
Merge pull request #544 from zrr1999/acc/pool
fix PoolRule
2 parents 3128bf1 + 23883bc commit f9035d9

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tester/paddle_to_torch/rules.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4469,6 +4469,7 @@ def _get_same_padding_1d(input_size, kernel_size, stride):
44694469
pad_left, pad_right = _get_same_padding_1d(input_size, kernel_size, stride)
44704470
padding = pad_left # 对称填充
44714471
if pad_left != pad_right: # 非对称填充
4472+
# TODO(zrr1999) maybe mode="replicate"
44724473
x = torch.nn.functional.pad(x, (pad_left, pad_right))
44734474
padding = 0
44744475
elif isinstance(padding, (list, tuple)):
@@ -4508,7 +4509,7 @@ def _get_same_padding_2d(input_size, kernel_size, stride):
45084509
pad_h, pad_w = _get_same_padding_2d(input_size, kernel_size, stride)
45094510
padding = (pad_h[0], pad_w[0]) # 对称填充
45104511
if pad_h[0] != pad_h[1] or pad_w[0] != pad_w[1]: # 非对称填充
4511-
x = torch.nn.functional.pad(x, (pad_w[0], pad_w[1], pad_h[0], pad_h[1]))
4512+
x = torch.nn.functional.pad(x, (pad_w[0], pad_w[1], pad_h[0], pad_h[1]), mode="replicate")
45124513
padding = 0
45134514
elif isinstance(padding, (list, tuple)):
45144515
if len(padding) == 2: # [pad_height, pad_width]
@@ -4566,6 +4567,7 @@ def _get_same_padding_3d(input_size, kernel_size, stride):
45664567
pad_d, pad_h, pad_w = _get_same_padding_3d(input_size, kernel_size, stride)
45674568
padding = (pad_d[0], pad_h[0], pad_w[0]) # 对称填充
45684569
if pad_d[0] != pad_d[1] or pad_h[0] != pad_h[1] or pad_w[0] != pad_w[1]: # 非对称填充
4570+
# TODO(zrr1999) maybe mode="replicate"
45694571
x = torch.nn.functional.pad(x, (pad_w[0], pad_w[1], pad_h[0], pad_h[1], pad_d[0], pad_d[1]))
45704572
padding = 0
45714573
elif isinstance(padding, (list, tuple)):

0 commit comments

Comments
 (0)