diff --git a/langchain_benchmarks/schema.py b/langchain_benchmarks/schema.py index c135e10b..7d82edbe 100644 --- a/langchain_benchmarks/schema.py +++ b/langchain_benchmarks/schema.py @@ -4,6 +4,7 @@ import dataclasses import importlib import urllib +import warnings from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Type, Union from langchain.prompts import ChatPromptTemplate @@ -214,15 +215,26 @@ def _repr_html_(self) -> str: def filter( self, - Type: Optional[str], + *, + type: Optional[str] = None, dataset_id: Optional[str] = None, name: Optional[str] = None, description: Optional[str] = None, + Type: Optional[str] = None, # For backwards compatibility: ) -> Registry: """Filter the tasks in the registry.""" tasks = self.tasks - if Type is not None: - tasks = [task for task in tasks if task.__class__.__name__ == Type] + if Type and type: + raise ValueError("Cannot filter by both Type and type.") + + if Type is not None and type is None: + type_ = Type + warnings.warn("Type is deprecated, please use type instead.") + else: + type_ = type + + if type_ is not None: + tasks = [task for task in tasks if task.__class__.__name__ == type] if dataset_id is not None: tasks = [task for task in tasks if task.dataset_id == dataset_id] if name is not None: