|
12 | 12 | use nf_locally_connected1d_layer, only: locally_connected1d_layer |
13 | 13 | use nf_maxpool1d_layer, only: maxpool1d_layer |
14 | 14 | use nf_maxpool2d_layer, only: maxpool2d_layer |
15 | | - use nf_reshape_layer, only: reshape3d_layer |
16 | 15 | use nf_reshape2d_layer, only: reshape2d_layer |
| 16 | + use nf_reshape3d_layer, only: reshape3d_layer |
17 | 17 | use nf_linear2d_layer, only: linear2d_layer |
18 | 18 | use nf_self_attention_layer, only: self_attention_layer |
19 | 19 | use nf_embedding_layer, only: embedding_layer |
@@ -229,35 +229,22 @@ module function maxpool2d(pool_size, stride) result(res) |
229 | 229 | end function maxpool2d |
230 | 230 |
|
231 | 231 |
|
232 | | - module function reshape(output_shape) result(res) |
233 | | - integer, intent(in) :: output_shape(:) |
234 | | - type(layer) :: res |
235 | | - |
236 | | - res % name = 'reshape' |
237 | | - res % layer_shape = output_shape |
238 | | - |
239 | | - if (size(output_shape) == 3) then |
240 | | - allocate(res % p, source=reshape3d_layer(output_shape)) |
241 | | - else |
242 | | - error stop 'size(output_shape) of the reshape layer must == 3' |
243 | | - end if |
244 | | - |
245 | | - end function reshape |
246 | | - |
247 | | - module function reshape2d(output_shape) result(res) |
248 | | - integer, intent(in) :: output_shape(:) |
| 232 | + module function reshape2d(dim1, dim2) result(res) |
| 233 | + integer, intent(in) :: dim1, dim2 |
249 | 234 | type(layer) :: res |
250 | | - |
251 | 235 | res % name = 'reshape2d' |
252 | | - res % layer_shape = output_shape |
| 236 | + res % layer_shape = [dim1, dim2] |
| 237 | + allocate(res % p, source=reshape2d_layer(res % layer_shape)) |
| 238 | + end function reshape2d |
253 | 239 |
|
254 | | - if (size(output_shape) == 2) then |
255 | | - allocate(res % p, source=reshape2d_layer(output_shape)) |
256 | | - else |
257 | | - error stop 'size(output_shape) of the reshape layer must == 2' |
258 | | - end if |
259 | 240 |
|
260 | | - end function reshape2d |
| 241 | + module function reshape3d(dim1, dim2, dim3) result(res) |
| 242 | + integer, intent(in) :: dim1, dim2, dim3 |
| 243 | + type(layer) :: res |
| 244 | + res % name = 'reshape3d' |
| 245 | + res % layer_shape = [dim1, dim2, dim3] |
| 246 | + allocate(res % p, source=reshape3d_layer(res % layer_shape)) |
| 247 | + end function reshape3d |
261 | 248 |
|
262 | 249 |
|
263 | 250 | module function linear2d(out_features) result(res) |
|
0 commit comments