@@ -33,7 +33,7 @@ public func sigmoid(_ x: MLXArray) -> MLXArray {
3333/// - <doc:activations>
3434/// - ``ReLU``
3535public func relu( _ x: MLXArray ) -> MLXArray {
36- maximum ( x , 0 )
36+ compiledRelu ( x )
3737}
3838
3939/// Applies the Leaky Rectified Linear Unit.
@@ -103,6 +103,21 @@ public func relu6(_ x: MLXArray) -> MLXArray {
103103 compiledRelu6 ( x)
104104}
105105
106+ /// Applies the squared Rectified Linear Unit.
107+ ///
108+ /// This is:
109+ ///
110+ /// ```swift
111+ /// MLX.relu(x).square()
112+ /// ```
113+ ///
114+ /// /// ### See Also
115+ /// - <doc:activations>
116+ /// - ``ReLUSquared``
117+ public func reluSquared( _ x: MLXArray ) -> MLXArray {
118+ compiledReluSquared ( x)
119+ }
120+
106121@available ( * , deprecated, renamed: " softplus(_:) " )
107122@_documentation ( visibility: internal)
108123public func softPlus( _ x: MLXArray ) -> MLXArray {
@@ -458,6 +473,23 @@ open class ReLU6: Module, UnaryLayer {
458473 }
459474}
460475
476+ /// Applies the squared Rectified Linear Unit.
477+ ///
478+ /// This is:
479+ ///
480+ /// ```swift
481+ /// MLX.maximum(x, 0).square()
482+ ///
483+ ///
484+ /// ### See Also
485+ /// - <doc:activations>
486+ /// - ``reluSquared(_:)``
487+ open class ReLUSquared : Module , UnaryLayer {
488+ open func callAsFunction( _ x: MLXArray ) -> MLXArray {
489+ reluSquared ( x)
490+ }
491+ }
492+
461493@available ( * , deprecated, renamed: " Softmax " )
462494@_documentation ( visibility: internal)
463495open class SoftMax : Module , UnaryLayer {
@@ -852,3 +884,15 @@ private let compiledHardSwish: @Sendable (MLXArray) -> MLXArray = {
852884 return x * minimum( maxXPlus3, 6 ) / 6
853885 }
854886} ( )
887+
888+ private let compiledRelu : @Sendable ( MLXArray ) -> MLXArray = {
889+ compile ( shapeless: true ) { x in
890+ maximum ( x, 0 )
891+ }
892+ } ( )
893+
894+ private let compiledReluSquared : @Sendable ( MLXArray ) -> MLXArray = {
895+ compile ( shapeless: true ) { x in
896+ return relu ( x) . square ( )
897+ }
898+ } ( )
0 commit comments