@@ -44,13 +44,14 @@ class SEFR(LinearClassifierMixin, BaseEstimator):
4444 Specifies if a constant (a.k.a. bias or intercept) should be
4545 added to the decision function.
4646
47- kernel : {'linear', 'poly', 'rbf', 'sigmoid'} or callable, default='linear'
47+ kernel : {'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'} or callable, default='linear'
4848 Specifies the kernel type to be used in the algorithm.
4949 If a callable is given, it is used to pre-compute the kernel matrix.
50+ If 'precomputed', X is treated as an already-computed kernel matrix and is used as-is instead of being passed through a kernel function.
5051
5152 gamma : float, default=None
5253 Kernel coefficient for 'rbf', 'poly' and 'sigmoid'. If None, then it is
53- set to 1.0 / n_features.
54+ set to 1.0 / n_features. Ignored when kernel='precomputed'.
5455
5556 degree : int, default=3
5657 Degree for 'poly' kernels. Ignored by other kernels.
@@ -80,7 +81,7 @@ class SEFR(LinearClassifierMixin, BaseEstimator):
8081 has feature names that are all strings.
8182
8283 X_fit_ : ndarray of shape (n_samples, n_features)
83- The training data, stored when a kernel is used.
84+ The training data seen during fit. Not stored when kernel='precomputed'.
8485
8586 Notes
8687 -----
@@ -100,7 +101,10 @@ class SEFR(LinearClassifierMixin, BaseEstimator):
100101
101102 _parameter_constraints : dict = {
102103 "fit_intercept" : ["boolean" ],
103- "kernel" : [StrOptions ({"linear" , "poly" , "rbf" , "sigmoid" }), callable ],
104+ "kernel" : [
105+ StrOptions ({"linear" , "poly" , "rbf" , "sigmoid" , "precomputed" }),
106+ callable ,
107+ ],
104108 "gamma" : [Interval (Real , 0 , None , closed = "left" ), None ],
105109 "degree" : [Interval (Integral , 1 , None , closed = "left" ), None ],
106110 "coef0" : [Real , None ],
@@ -144,28 +148,58 @@ def _more_tags(self) -> dict[str, bool]:
144148 }
145149
146150 def _check_X (self , X ) -> np .ndarray :
147- X = validate_data (
148- self ,
149- X ,
150- dtype = "numeric" ,
151- force_all_finite = True ,
152- reset = False ,
153- )
154- if X .shape [1 ] != self .n_features_in_ :
155- raise ValueError (
156- "Expected input with %d features, got %d instead."
157- % (self .n_features_in_ , X .shape [1 ])
151+ if self .kernel == "precomputed" :
152+ X = validate_data (
153+ self ,
154+ X ,
155+ dtype = "numeric" ,
156+ force_all_finite = True ,
157+ reset = False ,
158+ )
159+ # For precomputed kernels during prediction, X should be (n_test_samples, n_train_samples)
160+ if hasattr (self , "n_features_in_" ) and X .shape [1 ] != self .n_features_in_ :
161+ raise ValueError (
162+ f"Precomputed kernel matrix should have { self .n_features_in_ } columns "
163+ f"(number of training samples), got { X .shape [1 ]} ."
164+ )
165+ else :
166+ X = validate_data (
167+ self ,
168+ X ,
169+ dtype = "numeric" ,
170+ force_all_finite = True ,
171+ reset = False ,
158172 )
173+ if hasattr (self , "n_features_in_" ) and X .shape [1 ] != self .n_features_in_ :
174+ raise ValueError (
175+ "Expected input with %d features, got %d instead."
176+ % (self .n_features_in_ , X .shape [1 ])
177+ )
159178 return X
160179
161180 def _check_X_y (self , X , y ) -> tuple [np .ndarray , np .ndarray ]:
162- X , y = check_X_y (
163- X ,
164- y ,
165- dtype = "numeric" ,
166- force_all_finite = True ,
167- estimator = self ,
168- )
181+ if self .kernel == "precomputed" :
182+ # For precomputed kernels, X should be a square kernel matrix
183+ X , y = check_X_y (
184+ X ,
185+ y ,
186+ dtype = "numeric" ,
187+ force_all_finite = True ,
188+ estimator = self ,
189+ )
190+ if X .shape [0 ] != X .shape [1 ]:
191+ raise ValueError (
192+ f"Precomputed kernel matrix should be square, got shape { X .shape } ."
193+ )
194+ else :
195+ X , y = check_X_y (
196+ X ,
197+ y ,
198+ dtype = "numeric" ,
199+ force_all_finite = True ,
200+ estimator = self ,
201+ )
202+
169203 check_classification_targets (y )
170204
171205 if np .unique (y ).shape [0 ] == 1 :
@@ -180,6 +214,10 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]:
180214 return X , y
181215
182216 def _get_kernel_matrix (self , X , Y = None ):
217+ if self .kernel == "precomputed" :
218+ # X is already a kernel matrix
219+ return X
220+
183221 if Y is None :
184222 Y = self .X_fit_
185223
@@ -203,9 +241,10 @@ def fit(self, X, y, sample_weight=None) -> Self:
203241
204242 Parameters
205243 ----------
206- X : {array-like, sparse matrix} of shape (n_samples, n_features)
244+ X : {array-like, sparse matrix} of shape (n_samples, n_features) or (n_samples, n_samples)
207245 Training vector, where `n_samples` is the number of samples and
208246 `n_features` is the number of features.
247+ If kernel='precomputed', X must be a square kernel matrix of shape (n_samples, n_samples); a non-square X raises ValueError.
209248
210249 y : array-like of shape (n_samples,)
211250 Target vector relative to X.
@@ -219,15 +258,25 @@ def fit(self, X, y, sample_weight=None) -> Self:
219258 self
220259 Fitted estimator.
221260 """
222- _check_n_features (self , X = X , reset = True )
223- _check_feature_names (self , X = X , reset = True )
261+ if self .kernel == "precomputed" :
262+ _check_n_features (self , X = X , reset = True )
263+ _check_feature_names (self , X = X , reset = True )
264+ else :
265+ _check_n_features (self , X = X , reset = True )
266+ _check_feature_names (self , X = X , reset = True )
224267
225268 X , y = self ._check_X_y (X , y )
226- self .X_fit_ = X
269+
270+ # Store training data only for non-precomputed kernels
271+ if self .kernel != "precomputed" :
272+ self .X_fit_ = X
273+
227274 self .classes_ , y_ = np .unique (y , return_inverse = True )
228275
229276 if self .kernel == "linear" :
230277 K = X
278+ elif self .kernel == "precomputed" :
279+ K = X # X is already the kernel matrix
231280 else :
232281 K = self ._get_kernel_matrix (X )
233282
@@ -277,10 +326,14 @@ def fit(self, X, y, sample_weight=None) -> Self:
277326 def decision_function (self , X ):
278327 check_is_fitted (self )
279328 X = self ._check_X (X )
329+
280330 if self .kernel == "linear" :
281331 K = X
332+ elif self .kernel == "precomputed" :
333+ K = X # X is already a kernel matrix
282334 else :
283335 K = self ._get_kernel_matrix (X )
336+
284337 return (
285338 safe_sparse_dot (K , self .coef_ .T , dense_output = True ) + self .intercept_
286339 ).ravel ()
@@ -294,9 +347,10 @@ def predict_proba(self, X):
294347
295348 Parameters
296349 ----------
297- X : array-like of shape (n_samples, n_features)
350+ X : array-like of shape (n_samples, n_features) or (n_samples, n_train_samples)
298351 Vector to be scored, where `n_samples` is the number of samples and
299352 `n_features` is the number of features.
353+ If kernel='precomputed', X should have shape (n_samples, n_train_samples), where `n_train_samples` is the number of samples the estimator was fitted on.
300354
301355 Returns
302356 -------
@@ -324,9 +378,10 @@ def predict_log_proba(self, X):
324378
325379 Parameters
326380 ----------
327- X : array-like of shape (n_samples, n_features)
381+ X : array-like of shape (n_samples, n_features) or (n_samples, n_train_samples)
328382 Vector to be scored, where `n_samples` is the number of samples and
329383 `n_features` is the number of features.
384+ If kernel='precomputed', X should have shape (n_samples, n_train_samples), where `n_train_samples` is the number of samples the estimator was fitted on.
330385
331386 Returns
332387 -------
0 commit comments