@@ -18,6 +18,7 @@ module nf_optimizers
1818 type, abstract :: optimizer_base_type
1919 real :: learning_rate = 0.01
2020 contains
21+ procedure :: get_name
2122 procedure (init), deferred :: init
2223 procedure (minimize), deferred :: minimize
2324 end type optimizer_base_type
@@ -312,4 +313,52 @@ pure subroutine minimize_adagrad(self, param, gradient)
312313
313314 end subroutine minimize_adagrad
314315
315- end module nf_optimizers
316+
317+ ! Utility Functions
318+ ! ! Returns the default optimizer corresponding to the provided name
319+ pure function get_optimizer_by_name (optimizer_name ) result(res)
320+ character (len=* ), intent (in ) :: optimizer_name
321+ class(optimizer_base_type), allocatable :: res
322+
323+ select case (trim (optimizer_name))
324+ case (' adagrad' )
325+ allocate ( res, source = adagrad() )
326+
327+ case (' adam' )
328+ allocate ( res, source = adam() )
329+
330+ case (' rmsprop' )
331+ allocate ( res, source = rmsprop() )
332+
333+ case (' sgd' )
334+ allocate ( res, source = sgd() )
335+
336+ case default
337+ error stop ' optimizer_name must be one of: ' // &
338+ ' "adagrad", "adam", "rmsprop", "sgd".'
339+ end select
340+
341+ end function get_optimizer_by_name
342+
343+
344+ ! ! Returns the name of the optimizer
345+ pure function get_name (self ) result(name)
346+ class(optimizer_base_type), intent (in ) :: self
347+ character (:), allocatable :: name
348+
349+ select type (self)
350+ class is (adagrad)
351+ name = ' adagrad'
352+ class is (adam)
353+ name = ' adam'
354+ class is (rmsprop)
355+ name = ' rmsprop'
356+ class is (sgd)
357+ name = ' sgd'
358+ class default
359+ error stop ' Unknown optimizer type.'
360+ end select
361+
362+ end function get_name
363+
364+ end module nf_optimizers
0 commit comments