Skip to content

Commit 6ed81fd

Browse files
committed
Addition of get_optimizer_by_name and of get_name for optimizer DT
1 parent 5110fc6 commit 6ed81fd

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

src/nf/nf_optimizers.f90

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ module nf_optimizers
1818
type, abstract :: optimizer_base_type
1919
real :: learning_rate = 0.01
2020
contains
21+
procedure :: get_name
2122
procedure(init), deferred :: init
2223
procedure(minimize), deferred :: minimize
2324
end type optimizer_base_type
@@ -312,4 +313,52 @@ pure subroutine minimize_adagrad(self, param, gradient)
312313

313314
end subroutine minimize_adagrad
314315

315-
end module nf_optimizers
316+
317+
! Utility Functions
318+
!! Returns the default optimizer corresponding to the provided name
319+
pure function get_optimizer_by_name(optimizer_name) result(res)
320+
character(len=*), intent(in) :: optimizer_name
321+
class(optimizer_base_type), allocatable :: res
322+
323+
select case(trim(optimizer_name))
324+
case('adagrad')
325+
allocate ( res, source = adagrad() )
326+
327+
case('adam')
328+
allocate ( res, source = adam() )
329+
330+
case('rmsprop')
331+
allocate ( res, source = rmsprop() )
332+
333+
case('sgd')
334+
allocate ( res, source = sgd() )
335+
336+
case default
337+
error stop 'optimizer_name must be one of: ' // &
338+
'"adagrad", "adam", "rmsprop", "sgd".'
339+
end select
340+
341+
end function get_optimizer_by_name
342+
343+
344+
!! Returns the name of the optimizer
345+
pure function get_name(self) result(name)
346+
class(optimizer_base_type), intent(in) :: self
347+
character(:), allocatable :: name
348+
349+
select type (self)
350+
class is (adagrad)
351+
name = 'adagrad'
352+
class is (adam)
353+
name = 'adam'
354+
class is (rmsprop)
355+
name = 'rmsprop'
356+
class is (sgd)
357+
name = 'sgd'
358+
class default
359+
error stop 'Unknown optimizer type.'
360+
end select
361+
362+
end function get_name
363+
364+
end module nf_optimizers

0 commit comments

Comments
 (0)