
Commit b79c74c

Add ReLUSquared Activation Function (#250)
* feat: Add ReLUSquared activation function and refactor relu implementation
* feat: Document ReLUSquared activation function in activations and MLXNN files
1 parent b74974b commit b79c74c

File tree

3 files changed: +49 -1 lines changed


Source/MLXNN/Activations.swift

Lines changed: 45 additions & 1 deletion
@@ -33,7 +33,7 @@ public func sigmoid(_ x: MLXArray) -> MLXArray {
 /// - <doc:activations>
 /// - ``ReLU``
 public func relu(_ x: MLXArray) -> MLXArray {
-    maximum(x, 0)
+    compiledRelu(x)
 }
 
 /// Applies the Leaky Rectified Linear Unit.
@@ -103,6 +103,21 @@ public func relu6(_ x: MLXArray) -> MLXArray {
     compiledRelu6(x)
 }
 
+/// Applies the squared Rectified Linear Unit.
+///
+/// This is:
+///
+/// ```swift
+/// MLX.relu(x).square()
+/// ```
+///
+/// ### See Also
+/// - <doc:activations>
+/// - ``ReLUSquared``
+public func reluSquared(_ x: MLXArray) -> MLXArray {
+    compiledReluSquared(x)
+}
+
 @available(*, deprecated, renamed: "softplus(_:)")
 @_documentation(visibility: internal)
 public func softPlus(_ x: MLXArray) -> MLXArray {
@@ -458,6 +473,23 @@ open class ReLU6: Module, UnaryLayer {
     }
 }
 
+/// Applies the squared Rectified Linear Unit.
+///
+/// This is:
+///
+/// ```swift
+/// MLX.maximum(x, 0).square()
+/// ```
+///
+/// ### See Also
+/// - <doc:activations>
+/// - ``reluSquared(_:)``
+open class ReLUSquared: Module, UnaryLayer {
+    open func callAsFunction(_ x: MLXArray) -> MLXArray {
+        reluSquared(x)
+    }
+}
+
 @available(*, deprecated, renamed: "Softmax")
 @_documentation(visibility: internal)
 open class SoftMax: Module, UnaryLayer {
@@ -852,3 +884,15 @@ private let compiledHardSwish: @Sendable (MLXArray) -> MLXArray = {
         return x * minimum(maxXPlus3, 6) / 6
     }
 }()
+
+private let compiledRelu: @Sendable (MLXArray) -> MLXArray = {
+    compile(shapeless: true) { x in
+        maximum(x, 0)
+    }
+}()
+
+private let compiledReluSquared: @Sendable (MLXArray) -> MLXArray = {
+    compile(shapeless: true) { x in
+        return relu(x).square()
+    }
+}()
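
Below is a minimal usage sketch (not part of the commit) of the API added above, assuming the diff lands as shown. `reluSquared(_:)` and `ReLUSquared` come from this change; `relu(_:)`, the `MLXArray` array initializer, `allClose`, and `item` are existing MLX / MLXNN helpers:

```swift
import MLX
import MLXNN

// Element-wise: max(x, 0), then squared.
let x = MLXArray([-2.0, -0.5, 0.0, 1.5, 3.0] as [Float])

let y1 = reluSquared(x)    // functional form -> [0, 0, 0, 2.25, 9]
let y2 = ReLUSquared()(x)  // layer form, same values
let y3 = relu(x).square()  // the uncompiled equivalent from the doc comment

print(allClose(y1, y2).item(Bool.self))  // true
print(allClose(y1, y3).item(Bool.self))  // true
```

The `compiledRelu` and `compiledReluSquared` closures in the diff follow the same pattern as the other activations in this file: the body is traced once by `compile`, and `shapeless: true` lets the compiled graph be reused across inputs of different shapes.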

Source/MLXNN/Documentation.docc/MLXNN.md

Lines changed: 2 additions & 0 deletions
@@ -128,6 +128,7 @@ These can be used with ``Sequential``.
 - ``prelu(_:alpha:)``
 - ``relu(_:)``
 - ``relu6(_:)``
+- ``reluSquared(_:)``
 - ``selu(_:)``
 - ``silu(_:)``
 - ``sigmoid(_:)``
@@ -148,6 +149,7 @@ These can be used with ``Sequential``.
 - ``PReLU``
 - ``ReLU``
 - ``ReLU6``
+- ``ReLUSquared``
 - ``SELU``
 - ``SiLU``
 - ``Sigmoid``

Source/MLXNN/Documentation.docc/activations.md

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@ such as `alpha`.
 - ``prelu(_:alpha:)``
 - ``relu(_:)``
 - ``relu6(_:)``
+- ``reluSquared(_:)``
 - ``selu(_:)``
 - ``silu(_:)``
 - ``sigmoid(_:)``
@@ -45,6 +46,7 @@ such as `alpha`.
 - ``PReLU``
 - ``ReLU``
 - ``ReLU6``
+- ``ReLUSquared``
 - ``SELU``
 - ``SiLU``
 - ``Sigmoid``
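
The MLXNN.md page lists the new layer class under "These can be used with ``Sequential``", so here is a hedged sketch of dropping it into a tiny model. The `Sequential(layers:)` variadic initializer, `Linear`, and `MLXArray.ones` are assumed to match the current MLXNN / MLX API, and the model itself is hypothetical:

```swift
import MLX
import MLXNN

// Hypothetical two-layer MLP with the squared ReLU between the linear layers.
let model = Sequential(layers: Linear(4, 8), ReLUSquared(), Linear(8, 1))

let input = MLXArray.ones([2, 4])   // batch of 2 samples, 4 features each
let output = model(input)           // forward pass through all layers
print(output.shape)                 // [2, 1]
```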
