diff --git a/python/sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py b/python/sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py index 9ca03167f9ed7b..c61c09d97975bb 100755 --- a/python/sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py +++ b/python/sparknlp/annotator/classifier_dl/deberta_for_sequence_classification.py @@ -175,7 +175,7 @@ def loadSavedModel(folder, spark_session): return DeBertaForSequenceClassification(java_model=jModel) @staticmethod - def pretrained(name="deberta_base_sequence_classifier_imdb", lang="en", remote_loc=None): + def pretrained(name="deberta_v3_xsmall_sequence_classifier_imdb", lang="en", remote_loc=None): """Downloads and loads a pretrained model. Parameters diff --git a/python/sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py b/python/sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py index 2db4b3b7ae7b2a..5112fa01e367d0 100755 --- a/python/sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py +++ b/python/sparknlp/annotator/classifier_dl/xlm_roberta_for_sequence_classification.py @@ -178,14 +178,14 @@ def loadSavedModel(folder, spark_session): return XlmRoBertaForSequenceClassification(java_model=jModel) @staticmethod - def pretrained(name="xlm_roberta_base_sequence_classifier_imdb", lang="en", remote_loc=None): + def pretrained(name="xlm_roberta_base_token_classifier_conll03", lang="en", remote_loc=None): """Downloads and loads a pretrained model. 
Parameters ---------- name : str, optional Name of the pretrained model, by default - "xlm_roberta_base_sequence_classifier_imdb" + "xlm_roberta_base_token_classifier_conll03" lang : str, optional Language of the pretrained model, by default "en" remote_loc : str, optional diff --git a/python/sparknlp/annotator/cv/florence2_transformer.py b/python/sparknlp/annotator/cv/florence2_transformer.py index 0ba21c707823d6..4617faddd6164f 100644 --- a/python/sparknlp/annotator/cv/florence2_transformer.py +++ b/python/sparknlp/annotator/cv/florence2_transformer.py @@ -174,7 +174,7 @@ def loadSavedModel(folder, spark_session, use_openvino=False): return Florence2Transformer(java_model=jModel) @staticmethod - def pretrained(name="florence2_base_ft_int4", lang="en", remote_loc=None): + def pretrained(name="florence_2_base_ft_int4", lang="en", remote_loc=None): """Downloads and loads a pretrained model.""" from sparknlp.pretrained import ResourceDownloader return ResourceDownloader.downloadModel(Florence2Transformer, name, lang, remote_loc) \ No newline at end of file diff --git a/python/sparknlp/annotator/cv/gemma3_for_multimodal.py b/python/sparknlp/annotator/cv/gemma3_for_multimodal.py index 6d6b8e5156adec..01045933e6928c 100644 --- a/python/sparknlp/annotator/cv/gemma3_for_multimodal.py +++ b/python/sparknlp/annotator/cv/gemma3_for_multimodal.py @@ -324,7 +324,7 @@ def loadSavedModel(folder, spark_session, use_openvino=False): return Gemma3ForMultiModal(java_model=jModel) @staticmethod - def pretrained(name="gemma3_4b_it_int4", lang="en", remote_loc=None): + def pretrained(name="gemma_3_4b_it_int4", lang="en", remote_loc=None): """Downloads and loads a pretrained model. 
Parameters diff --git a/python/sparknlp/annotator/seq2seq/phi4_transformer.py b/python/sparknlp/annotator/seq2seq/phi4_transformer.py index 8a55c191a30582..2db4a4712cefb7 100644 --- a/python/sparknlp/annotator/seq2seq/phi4_transformer.py +++ b/python/sparknlp/annotator/seq2seq/phi4_transformer.py @@ -56,7 +56,7 @@ class Phi4Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine): ... .setInputCols(["document"]) \ ... .setOutputCol("generation") - The default model is ``"phi-4"``, if no name is provided. For available pretrained models please see the `Models Hub `__. + The default model is ``"phi_4_mini_instruct_int8_openvino"``, if no name is provided. For available pretrained models please see the `Models Hub `__. Note ---- @@ -117,7 +117,7 @@ class Phi4Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine): >>> documentAssembler = DocumentAssembler() \ ... .setInputCol("text") \ ... .setOutputCol("documents") - >>> phi4 = Phi4Transformer.pretrained("phi-4") \ + >>> phi4 = Phi4Transformer.pretrained("phi_4_mini_instruct_int8_openvino") \ ... .setInputCols(["documents"]) \ ... .setMaxOutputLength(60) \ ... .setOutputCol("generation") @@ -365,13 +365,13 @@ def loadSavedModel(folder, spark_session, use_openvino = False): return Phi4Transformer(java_model=jModel) @staticmethod - def pretrained(name="phi-4", lang="en", remote_loc=None): + def pretrained(name="phi_4_mini_instruct_int8_openvino", lang="en", remote_loc=None): """Downloads and loads a pretrained model. 
Parameters ---------- name : str, optional - Name of the pretrained model, by default "phi-4" + Name of the pretrained model, by default "phi_4_mini_instruct_int8_openvino" lang : str, optional Language of the pretrained model, by default "en" remote_loc : str, optional @@ -384,4 +384,4 @@ def pretrained(name="phi-4", lang="en", remote_loc=None): The restored model """ from sparknlp.pretrained import ResourceDownloader - return ResourceDownloader.downloadModel(Phi4Transformer, name, lang, remote_loc) \ No newline at end of file + return ResourceDownloader.downloadModel(Phi4Transformer, name, lang, remote_loc) \ No newline at end of file diff --git a/python/sparknlp/common/properties.py b/python/sparknlp/common/properties.py index 7ed53ebe9a8e09..30277c4bf40b90 100644 --- a/python/sparknlp/common/properties.py +++ b/python/sparknlp/common/properties.py @@ -502,6 +502,29 @@ def getEngine(self): """ return self.getOrDefault(self.engine) + @classmethod + def pretrainedEngine(cls, name: str = "default", lang: str = "en", remote_loc: str = None, engine="onnx"): + """Downloads and loads a pretrained model. 
+ + Parameters + ---------- + name : str, optional + The name of the pretrained model, by default "default" + lang : str, optional + The language of the pretrained model, by default "en" + remote_loc : str, optional + Remote location of the model, by default None + engine : str, optional + The Deep Learning engine used for this model, by default "onnx" + + Returns + ------- + AnnotatorModel + Pretrained model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(cls, name, lang, remote_loc, engine) + class HasCandidateLabelsProperties: candidateLabels = Param(Params._dummy(), "candidateLabels", diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index 497cfa27964000..215af5c71ecece 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -713,13 +713,14 @@ def __init__(self, name, remote_loc="public/models", unzip=True): class _DownloadModel(ExtendedJavaWrapper): - def __init__(self, reader, name, language, remote_loc, validator): + def __init__(self, reader, name, language, remote_loc, engine, validator): super(_DownloadModel, self).__init__( "com.johnsnowlabs.nlp.pretrained." 
+ validator + ".downloadModel", reader, name, language, remote_loc, + engine ) @@ -775,12 +776,14 @@ def __init__(self): class _GetResourceSize(ExtendedJavaWrapper): - def __init__(self, name, language, remote_loc): + def __init__(self, name, language, remote_loc, annotator, engine): super(_GetResourceSize, self).__init__( "com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader.getDownloadSize", name, language, remote_loc, + annotator, + engine ) diff --git a/python/sparknlp/pretrained/resource_downloader.py b/python/sparknlp/pretrained/resource_downloader.py index 00ffd0848a275b..2326609694bd0a 100644 --- a/python/sparknlp/pretrained/resource_downloader.py +++ b/python/sparknlp/pretrained/resource_downloader.py @@ -59,7 +59,7 @@ class ResourceDownloader(object): """ @staticmethod - def downloadModel(reader, name, language, remote_loc=None, j_dwn='PythonResourceDownloader'): + def downloadModel(reader, name, language, remote_loc=None, engine = "onnx", j_dwn='PythonResourceDownloader'): """Downloads and loads a model with the default downloader. Usually this method does not need to be called directly, as it is called by the `pretrained()` method of the annotator. 
@@ -83,7 +83,7 @@ def downloadModel(reader, name, language, remote_loc=None, j_dwn='PythonResource Loaded pretrained annotator/pipeline """ print(name + " download started this may take some time.") - file_size = _internal._GetResourceSize(name, language, remote_loc).apply() + file_size = _internal._GetResourceSize(name, language, remote_loc, reader.name, engine).apply() if file_size == "-1": print("Can not find the model to download please check the name!") else: @@ -92,7 +92,7 @@ def downloadModel(reader, name, language, remote_loc=None, j_dwn='PythonResource t1 = threading.Thread(target=printProgress, args=(lambda: stop_threads,)) t1.start() try: - j_obj = _internal._DownloadModel(reader.name, name, language, remote_loc, engine, j_dwn).apply() except Py4JJavaError as e: sys.stdout.write("\n" + str(e)) raise e diff --git a/python/test/annotator/embeddings/camembert_embeddings_test.py b/python/test/annotator/embeddings/camembert_embeddings_test.py index a9795b7c68c550..a6d72a797eb4a8 100644 --- a/python/test/annotator/embeddings/camembert_embeddings_test.py +++ b/python/test/annotator/embeddings/camembert_embeddings_test.py @@ -48,3 +48,9 @@ def test_run(self): model = pipeline.fit(self.data) model.transform(self.data).show() + + def test_preferred_engine(self): + model_onnx = CamemBertEmbeddings.pretrainedEngine("camembert_base","fr",engine= "onnx") + model_tensorflow = CamemBertEmbeddings.pretrainedEngine("camembert_base","fr",engine= "tensorflow") + self.assertEqual(model_onnx.getEngine(), "onnx") + self.assertEqual(model_tensorflow.getEngine(), "tensorflow") diff --git a/src/main/scala/com/johnsnowlabs/nlp/HasPretrained.scala b/src/main/scala/com/johnsnowlabs/nlp/HasPretrained.scala index 41e2315783d4bc..db7760e4cfbe42 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/HasPretrained.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/HasPretrained.scala @@ -29,6 +29,8 @@ trait
HasPretrained[M <: PipelineStage] { val defaultLang: String = "en" + val defaultPreferredEngine: String = "onnx" + lazy val defaultLoc: String = ResourceDownloader.publicLoc implicit private val companion: DefaultParamsReadable[M] = @@ -38,17 +40,35 @@ trait HasPretrained[M <: PipelineStage] { s"${this.getClass.getName} does not have a default pretrained model. Please provide a model name." /** Java default argument interoperability */ - def pretrained(name: String, lang: String, remoteLoc: String): M = { + def pretrained( + name: String, + lang: String, + remoteLoc: String, + preferredEngine: String = "onnx"): M = { if (Option(name).isEmpty) throw new NotImplementedError(errorMsg) - ResourceDownloader.downloadModel(companion, name, Option(lang), remoteLoc) + ResourceDownloader.downloadModel(companion, name, Option(lang), remoteLoc, preferredEngine) } - def pretrained(name: String, lang: String): M = pretrained(name, lang, defaultLoc) + def pretrained(name: String, lang: String): M = + pretrained(name, lang, defaultLoc, defaultPreferredEngine) - def pretrained(name: String): M = pretrained(name, defaultLang, defaultLoc) + def pretrained(name: String): M = + pretrained(name, defaultLang, defaultLoc, defaultPreferredEngine) def pretrained(): M = - pretrained(defaultModelName.getOrElse(throw new Exception(errorMsg)), defaultLang, defaultLoc) + pretrained( + defaultModelName.getOrElse(throw new Exception(errorMsg)), + defaultLang, + defaultLoc, + defaultPreferredEngine) + + def pretrained(name: String, lang: String, remoteLoc: String): M = + pretrained(name, lang, remoteLoc, defaultPreferredEngine) + def pretrainedEngine(name: String, preferredEngine: String): M = + pretrained(name, defaultLang, defaultLoc, preferredEngine) + + def pretrainedEngine(name: String, lang: String, preferredEngine: String): M = + pretrained(name, lang, defaultLoc, preferredEngine) } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/Florence2Transformer.scala 
b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/Florence2Transformer.scala index a570d2cbdf07c4..4b1b883cd938b1 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/Florence2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/Florence2Transformer.scala @@ -330,7 +330,7 @@ class Florence2Transformer(override val uid: String) trait ReadablePretrainedFlorence2TransformerModel extends ParamsAndFeaturesReadable[Florence2Transformer] with HasPretrained[Florence2Transformer] { - override val defaultModelName: Some[String] = Some("florence2_base_ft_int4") + override val defaultModelName: Some[String] = Some("florence_2_base_ft_int4") /** Java compliant-overrides */ override def pretrained(): Florence2Transformer = super.pretrained() diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi4Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi4Transformer.scala index c4f9cb9df62d44..e55cd2f851eef4 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi4Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi4Transformer.scala @@ -250,12 +250,13 @@ class Phi4Transformer(override val uid: String) trait ReadablePretrainedPhi4TransformerModel extends ParamsAndFeaturesReadable[Phi4Transformer] with HasPretrained[Phi4Transformer] { - override val defaultModelName: Some[String] = Some("phi-4") + override val defaultModelName: Some[String] = Some("phi_4_mini_instruct_int8_openvino") override def pretrained(): Phi4Transformer = super.pretrained() override def pretrained(name: String): Phi4Transformer = super.pretrained(name) override def pretrained(name: String, lang: String): Phi4Transformer = super.pretrained(name, lang) + override def pretrained(name: String, lang: String, remoteLoc: String): Phi4Transformer = super.pretrained(name, lang, remoteLoc) } diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings.scala 
b/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings.scala index 657d012c04734e..aa05f43daa7d36 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5VEmbeddings.scala @@ -376,7 +376,7 @@ trait ReadablePretrainedE5VEmbeddings extends ParamsAndFeaturesReadable[E5VEmbeddings] with HasPretrained[E5VEmbeddings] { - override val defaultModelName: Some[String] = Some("e5v_1_5_7b_int4") + override val defaultModelName: Some[String] = Some("e5v_int4") /** Java compliant-overrides */ override def pretrained(): E5VEmbeddings = super.pretrained() diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala index c0e54ca1769634..3bbc562fb953dd 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala @@ -496,8 +496,18 @@ object ResourceDownloader { reader: DefaultParamsReadable[TModel], name: String, language: Option[String] = None, - folder: String = publicLoc): TModel = { - downloadModel(reader, ResourceRequest(name, language, folder)) + folder: String = publicLoc, + preferredEngine: String = "onnx"): TModel = { + + val annotator = reader.getClass.getSimpleName.replace("$", "") + downloadModel( + reader, + ResourceRequest( + name, + language, + folder, + annotator = Some(annotator), + engine = Some(preferredEngine))) } def downloadModel[TModel <: PipelineStage]( @@ -517,7 +527,7 @@ object ResourceDownloader { name: String, language: Option[String] = None, folder: String = publicLoc): PipelineModel = { - downloadPipeline(ResourceRequest(name, language, folder)) + downloadPipeline(ResourceRequest(name, language, folder, annotator = Some("PipelineModel"))) } def downloadPipeline(request: ResourceRequest): PipelineModel = { @@ -575,7 +585,9 @@ case class ResourceRequest( language: 
Option[String] = None, folder: String = ResourceDownloader.publicLoc, libVersion: Version = ResourceDownloader.libVersion, - sparkVersion: Version = ResourceDownloader.sparkVersion) + sparkVersion: Version = ResourceDownloader.sparkVersion, + annotator: Option[String] = None, + engine: Option[String] = None) /* convenience accessor for Py4J calls */ object PythonResourceDownloader { @@ -632,7 +644,6 @@ object PythonResourceDownloader { "XlnetForTokenClassification" -> XlnetForTokenClassification, "AlbertForSequenceClassification" -> AlbertForSequenceClassification, "BertForSequenceClassification" -> BertForSequenceClassification, - "DeBertaForSequenceClassification" -> DeBertaForSequenceClassification, "DistilBertForSequenceClassification" -> DistilBertForSequenceClassification, "LongformerForSequenceClassification" -> LongformerForSequenceClassification, "RoBertaForSequenceClassification" -> RoBertaForSequenceClassification, @@ -644,7 +655,6 @@ object PythonResourceDownloader { "Word2VecModel" -> Word2VecModel, "DeBertaEmbeddings" -> DeBertaEmbeddings, "DeBertaForSequenceClassification" -> DeBertaForSequenceClassification, - "DeBertaForTokenClassification" -> DeBertaForTokenClassification, "CamemBertEmbeddings" -> CamemBertEmbeddings, "AlbertForQuestionAnswering" -> AlbertForQuestionAnswering, "BertForQuestionAnswering" -> BertForQuestionAnswering, @@ -722,7 +732,8 @@ object PythonResourceDownloader { readerStr: String, name: String, language: String = null, - remoteLoc: String = null): PipelineStage = { + remoteLoc: String = null, + preferredEngine: String): PipelineStage = { val reader = keyToReader.getOrElse( if (typeMapper.contains(readerStr)) typeMapper(readerStr) else readerStr, @@ -734,7 +745,8 @@ object PythonResourceDownloader { reader.asInstanceOf[DefaultParamsReadable[PipelineStage]], name, Option(language), - correctedFolder) + correctedFolder, + preferredEngine) // Cast the model to the required type. 
This has to be done for each entry in the typeMapper map if (typeMapper.contains(readerStr) && readerStr == "ZeroShotNerModel") @@ -800,8 +812,19 @@ object PythonResourceDownloader { ResourceDownloader.listAvailableAnnotators().mkString("\n") } - def getDownloadSize(name: String, language: String = "en", remoteLoc: String = null): String = { + def getDownloadSize( + name: String, + language: String = "en", + remoteLoc: String = null, + annotator: String, + engine: String): String = { val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc) - ResourceDownloader.getDownloadSize(ResourceRequest(name, Option(language), correctedFolder)) + ResourceDownloader.getDownloadSize( + ResourceRequest( + name, + Option(language), + correctedFolder, + annotator = Some(annotator), + engine = Some(engine))) } } diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceMetadata.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceMetadata.scala index 992708e86c0992..a90463bedfdd6f 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceMetadata.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceMetadata.scala @@ -109,15 +109,46 @@ object ResourceMetadata { candidates: List[ResourceMetadata], request: ResourceRequest): Option[ResourceMetadata] = { - val compatibleCandidates = candidates - .filter(item => - item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined - && item.name == request.name - && (request.language.isEmpty || item.language.isEmpty || request.language.get == item.language.get) - && Version.isCompatible(request.libVersion, item.libVersion) - && Version.isCompatible(request.sparkVersion, item.sparkVersion)) - - val sortedResult = compatibleCandidates.sorted + val excludedAnnotators = Array("AutoGGUFModel","AutoGGUFVisionModel") + + val compatibleCandidates = candidates.filter(item => + item.readyToUse && + item.libVersion.isDefined && + item.sparkVersion.isDefined && + item.name == 
request.name && +        (request.annotator.isEmpty || item.annotator.isEmpty || excludedAnnotators.contains(item.annotator.get) || +          request.annotator.get.equalsIgnoreCase(item.annotator.get)) && +        (request.language.isEmpty || item.language.isEmpty || +          request.language.get == item.language.get) && +        Version.isCompatible(request.libVersion, item.libVersion) && +        Version.isCompatible(request.sparkVersion, item.sparkVersion)) + +    val defaultPriority = Seq("onnx", "tensorflow", "openvino", "unk") + +    val finalPriority = request.engine match { +      case Some(pref) => +        val p = pref.toLowerCase // in case the user types ONNX instead of onnx +        (Seq(p) ++ defaultPriority.filterNot(_ == p)).distinct +      case None => +        defaultPriority +    } + +    def enginePriority(engineOpt: Option[String]): Int = { +      val engine = engineOpt.map(_.toLowerCase).getOrElse("unk") +      finalPriority.indexOf(engine) match { +        case -1 => +          finalPriority.length // unknown engine → lowest priority ( therefore highest numerical value ) +        case idx => idx +      } +    } + +    // Engine → Spark → Lib → Time +    val sortedResult = compatibleCandidates.sortWith { (a, b) => +      val engineComp = enginePriority(a.engine) compare enginePriority(b.engine) +      if (engineComp != 0) engineComp > 0 +      else a < b // fallback to old logic +    } + sortedResult.lastOption } diff --git a/src/test/resources/resource-downloader/test_engine_metadata.json b/src/test/resources/resource-downloader/test_engine_metadata.json new file mode 100644 index 00000000000000..6fff859130e734 --- /dev/null +++ b/src/test/resources/resource-downloader/test_engine_metadata.json @@ -0,0 +1,6 @@ +{"name": "bert_base_uncased", "language": "en", "libVersion": {"parts": [1, 5]}, "sparkVersion": {"parts": [2]}, "readyToUse": true, "time": "2018-03-27T19:57:33.497Z", "isZipped": true, "checksum": "", "annotator": "BertEmbeddings","engine":"unk"} +{"name": "bert_base_uncased", "language": "en", "libVersion": {"parts": [6, 1, 5]}, "sparkVersion": {"parts": [3]}, "readyToUse": true,
"time": "2025-03-27T20:09:07.222Z", "isZipped": true, "checksum": "", "annotator": "BertEmbeddings","engine":"onnx"} +{"name": "bert_base_uncased", "language": "en", "libVersion": {"parts": [6, 1, 5]}, "sparkVersion": {"parts": [3]}, "readyToUse": true, "time": "2025-03-27T20:09:07.222Z", "isZipped": true, "checksum": "", "annotator": "BertEmbeddings","engine":"openvino"} +{"name": "bert_base_uncased", "language": "en", "libVersion": {"parts": [3, 3, 4]}, "sparkVersion": {"parts": [3]}, "readyToUse": true, "time": "2025-03-27T20:09:07.222Z", "isZipped": true, "checksum": "", "annotator": "BertEmbeddings","engine":"tensorflow"} +{"name": "bert_base_uncased", "language": "en", "libVersion": {"parts": [6, 2, 0]}, "sparkVersion": {"parts": [3]}, "readyToUse": true, "time": "2025-10-27T20:09:07.222Z", "isZipped": true, "checksum": "", "annotator": "BertSentenceEmbeddings","engine":"tensorflow"} +{"name": "testannotator", "language": "en", "libVersion": {"parts": [6, 0, 0]}, "sparkVersion": {"parts": [3]}, "readyToUse": true, "time": "2025-10-27T20:09:07.222Z", "isZipped": true, "checksum": "", "annotator": "BertSentenceEmbeddings","engine":"unk"} diff --git a/src/test/scala/com/johnsnowlabs/nlp/pretrained/ResourceMedataTest.scala b/src/test/scala/com/johnsnowlabs/nlp/pretrained/ResourceMedataTest.scala index 5035fed077bf1c..d821775e0f2e38 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/pretrained/ResourceMedataTest.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/pretrained/ResourceMedataTest.scala @@ -198,6 +198,65 @@ class ResourceMedataTest extends AnyFlatSpec { assert(versions.get.time == expectedTimestamp) } + it should "get correct model if two models have same name but belong to different annotators" in { + val resourcePath = "src/test/resources/resource-downloader/test_engine_metadata.json" + val mockResourceDownloader: MockResourceDownloader = new MockResourceDownloader(resourcePath) + val resourceMetadata = mockResourceDownloader.resources + val 
resourceRequest = ResourceRequest( + "bert_base_uncased", + libVersion = Version(List(6, 0, 0)), + sparkVersion = Version(List(3, 0)), + annotator = Some("BertEmbeddings")) + + val expectedAnnotator = "BertEmbeddings" + + val versions = ResourceMetadata.resolveResource(resourceMetadata, resourceRequest) + + assert(versions.get.annotator.get == expectedAnnotator) + } + + it should "get correct preferred engine in case model has preferred engine available" in { + val resourcePath = "src/test/resources/resource-downloader/test_engine_metadata.json" + val mockResourceDownloader: MockResourceDownloader = new MockResourceDownloader(resourcePath) + val resourceMetadata = mockResourceDownloader.resources + val resourceRequest = ResourceRequest( + "bert_base_uncased", + libVersion = Version(List(6, 0, 0)), + sparkVersion = Version(List(3, 0)), + annotator = Some("BertEmbeddings"), + engine = Some("tensorflow")) + + val expectedAnnotator = "BertEmbeddings" + val expectedEngine = "tensorflow" + + val versions = ResourceMetadata.resolveResource(resourceMetadata, resourceRequest) + + assert(versions.get.annotator.get == expectedAnnotator) + assert(versions.get.engine.get == expectedEngine) + } + + + it should "fall back to other model variant if preferred engine does not exist" in { + val resourcePath = "src/test/resources/resource-downloader/test_engine_metadata.json" + val mockResourceDownloader: MockResourceDownloader = new MockResourceDownloader(resourcePath) + val resourceMetadata = mockResourceDownloader.resources + val resourceRequest = ResourceRequest( + "testannotator", + libVersion = Version(List(6, 2, 0)), + sparkVersion = Version(List(3, 0)), + annotator = Some("BertSentenceEmbeddings"), + engine = Some("tensorflow")) + + val expectedAnnotator = "BertSentenceEmbeddings" + val expectedEngine = "unk" + + val versions = ResourceMetadata.resolveResource(resourceMetadata, resourceRequest) + + assert(versions.get.annotator.get == expectedAnnotator) + 
assert(versions.get.engine.get == expectedEngine) + } + + private def getTimestamp(date: String): Timestamp = { val UTC = TimeZone.getTimeZone("UTC") val dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")