From 0fb3ccd09e27e0ac3a226d77d17661e614353c88 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 11 Dec 2023 22:53:40 -0500 Subject: [PATCH] x --- langchain_benchmarks/schema.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/langchain_benchmarks/schema.py b/langchain_benchmarks/schema.py index c135e10b..7d82edbe 100644 --- a/langchain_benchmarks/schema.py +++ b/langchain_benchmarks/schema.py @@ -4,6 +4,7 @@ import dataclasses import importlib import urllib +import warnings from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Type, Union from langchain.prompts import ChatPromptTemplate @@ -214,15 +215,26 @@ def _repr_html_(self) -> str: def filter( self, - Type: Optional[str], + *, + type: Optional[str] = None, dataset_id: Optional[str] = None, name: Optional[str] = None, description: Optional[str] = None, + Type: Optional[str] = None, # For backwards compatibility: ) -> Registry: """Filter the tasks in the registry.""" tasks = self.tasks - if Type is not None: - tasks = [task for task in tasks if task.__class__.__name__ == Type] + if Type and type: + raise ValueError("Cannot filter by both Type and type.") + + if Type is not None and type is None: + type_ = Type + warnings.warn("Type is deprecated, please use type instead.") + else: + type_ = type + + if type_ is not None: + tasks = [task for task in tasks if task.__class__.__name__ == type] if dataset_id is not None: tasks = [task for task in tasks if task.dataset_id == dataset_id] if name is not None: