Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def loadSavedModel(folder, spark_session):
return DeBertaForSequenceClassification(java_model=jModel)

@staticmethod
def pretrained(name="deberta_base_sequence_classifier_imdb", lang="en", remote_loc=None):
def pretrained(name="deberta_v3_xsmall_sequence_classifier_imdb", lang="en", remote_loc=None):
"""Downloads and loads a pretrained model.

Parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,14 @@ def loadSavedModel(folder, spark_session):
return XlmRoBertaForSequenceClassification(java_model=jModel)

@staticmethod
def pretrained(name="xlm_roberta_base_sequence_classifier_imdb", lang="en", remote_loc=None):
def pretrained(name="xlm_roberta_base_token_classifier_conll03", lang="en", remote_loc=None):
"""Downloads and loads a pretrained model.

Parameters
----------
name : str, optional
Name of the pretrained model, by default
"xlm_roberta_base_sequence_classifier_imdb"
"xlm_roberta_base_token_classifier_conll03"
lang : str, optional
Language of the pretrained model, by default "en"
remote_loc : str, optional
Expand Down
2 changes: 1 addition & 1 deletion python/sparknlp/annotator/cv/florence2_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def loadSavedModel(folder, spark_session, use_openvino=False):
return Florence2Transformer(java_model=jModel)

@staticmethod
def pretrained(name="florence2_base_ft_int4", lang="en", remote_loc=None):
def pretrained(name="florence_2_base_ft_int4", lang="en", remote_loc=None):
    """Downloads and loads a pretrained Florence2Transformer.

    Parameters
    ----------
    name : str, optional
        Name of the pretrained model, by default "florence_2_base_ft_int4"
    lang : str, optional
        Language of the pretrained model, by default "en"
    remote_loc : str, optional
        Remote location handed through to
        ``ResourceDownloader.downloadModel``, by default None

    Returns
    -------
    Florence2Transformer
        The object returned by ``ResourceDownloader.downloadModel``
    """
    # Imported lazily, exactly as the other ``pretrained`` helpers do.
    from sparknlp.pretrained import ResourceDownloader

    restored = ResourceDownloader.downloadModel(Florence2Transformer, name, lang, remote_loc)
    return restored
2 changes: 1 addition & 1 deletion python/sparknlp/annotator/cv/gemma3_for_multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def loadSavedModel(folder, spark_session, use_openvino=False):
return Gemma3ForMultiModal(java_model=jModel)

@staticmethod
def pretrained(name="gemma3_4b_it_int4", lang="en", remote_loc=None):
def pretrained(name="gemma_3_4b_it_int4", lang="en", remote_loc=None):
"""Downloads and loads a pretrained model.

Parameters
Expand Down
10 changes: 5 additions & 5 deletions python/sparknlp/annotator/seq2seq/phi4_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class Phi4Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
... .setInputCols(["document"]) \
... .setOutputCol("generation")

The default model is ``"phi-4"``, if no name is provided. For available pretrained models please see the `Models Hub <https://huggingface.co/microsoft/phi-4>`__.
The default model is ``"phi_4_mini_instruct_int8_openvino"``, if no name is provided. For available pretrained models please see the `Models Hub <https://huggingface.co/microsoft/phi-4>`__.

Note
----
Expand Down Expand Up @@ -117,7 +117,7 @@ class Phi4Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
>>> documentAssembler = DocumentAssembler() \
... .setInputCol("text") \
... .setOutputCol("documents")
>>> phi4 = Phi4Transformer.pretrained("phi-4") \
>>> phi4 = Phi4Transformer.pretrained("phi_4_mini_instruct_int8_openvino") \
... .setInputCols(["documents"]) \
... .setMaxOutputLength(60) \
... .setOutputCol("generation")
Expand Down Expand Up @@ -365,13 +365,13 @@ def loadSavedModel(folder, spark_session, use_openvino = False):
return Phi4Transformer(java_model=jModel)

@staticmethod
def pretrained(name="phi-4", lang="en", remote_loc=None):
def pretrained(name="phi_4_mini_instruct_int8_openvino", lang="en", remote_loc=None):
"""Downloads and loads a pretrained model.

Parameters
----------
name : str, optional
Name of the pretrained model, by default "phi-4"
Name of the pretrained model, by default "phi_4_mini_instruct_int8_openvino"
lang : str, optional
Language of the pretrained model, by default "en"
remote_loc : str, optional
Expand All @@ -384,4 +384,4 @@ def pretrained(name="phi-4", lang="en", remote_loc=None):
The restored model
"""
from sparknlp.pretrained import ResourceDownloader
return ResourceDownloader.downloadModel(Phi4Transformer, name, lang, remote_loc)
return ResourceDownloader.downloadModel(Phi4Transformer, name, lang, remote_loc)
23 changes: 23 additions & 0 deletions python/sparknlp/common/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,29 @@ def getEngine(self):
"""
return self.getOrDefault(self.engine)

@classmethod
def pretrainedEngine(cls, name: str = "default", lang: str = "en", remote_loc: str = None, engine="onnx"):
    """Downloads and loads a pretrained model, requesting a specific engine.

    The call is forwarded to ``ResourceDownloader.downloadModel`` with the
    calling class as the reader, so it can be used on any annotator class
    that inherits this method.

    Parameters
    ----------
    name : str, optional
        The name of the pretrained model, by default "default"
    lang : str, optional
        The language of the pretrained model, by default "en"
    remote_loc : str, optional
        Remote location of the model, by default None
    engine : str, optional
        The Deep Learning engine requested for this model (for example
        "onnx" or "tensorflow"), by default "onnx"

    Returns
    -------
    AnnotatorModel
        Pretrained model
    """
    from sparknlp.pretrained import ResourceDownloader

    # ``engine`` goes by keyword so it cannot be mistaken for the trailing
    # ``j_dwn`` parameter of ``downloadModel``.
    downloaded = ResourceDownloader.downloadModel(cls, name, lang, remote_loc, engine=engine)
    return downloaded


class HasCandidateLabelsProperties:
candidateLabels = Param(Params._dummy(), "candidateLabels",
Expand Down
7 changes: 5 additions & 2 deletions python/sparknlp/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,13 +713,14 @@ def __init__(self, name, remote_loc="public/models", unzip=True):


class _DownloadModel(ExtendedJavaWrapper):
def __init__(self, reader, name, language, remote_loc, validator):
def __init__(self, reader, name, language, remote_loc, engine, validator):
    """Bind the wrapper to ``<validator>.downloadModel`` on the JVM side.

    ``validator`` is the simple name of the downloader object inside
    ``com.johnsnowlabs.nlp.pretrained``; the remaining arguments are
    forwarded to the parent wrapper in the order
    ``reader, name, language, remote_loc, engine``.
    """
    # Fully qualified name of the JVM method this wrapper targets.
    java_target = "com.johnsnowlabs.nlp.pretrained." + validator + ".downloadModel"
    forwarded = (reader, name, language, remote_loc, engine)
    super(_DownloadModel, self).__init__(java_target, *forwarded)


Expand Down Expand Up @@ -775,12 +776,14 @@ def __init__(self):


class _GetResourceSize(ExtendedJavaWrapper):
def __init__(self, name, language, remote_loc):
def __init__(self, name, language, remote_loc, annotator, engine):
    """Bind the wrapper to ``PythonResourceDownloader.getDownloadSize``.

    The arguments are forwarded to the parent wrapper in the order
    ``name, language, remote_loc, annotator, engine``.
    """
    java_target = "com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader.getDownloadSize"
    forwarded = (name, language, remote_loc, annotator, engine)
    super(_GetResourceSize, self).__init__(java_target, *forwarded)


Expand Down
6 changes: 3 additions & 3 deletions python/sparknlp/pretrained/resource_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class ResourceDownloader(object):
"""

@staticmethod
def downloadModel(reader, name, language, remote_loc=None, j_dwn='PythonResourceDownloader'):
def downloadModel(reader, name, language, remote_loc=None, engine = "onnx", j_dwn='PythonResourceDownloader'):
"""Downloads and loads a model with the default downloader. Usually this method
does not need to be called directly, as it is called by the `pretrained()`
method of the annotator.
Expand All @@ -83,7 +83,7 @@ def downloadModel(reader, name, language, remote_loc=None, j_dwn='PythonResource
Loaded pretrained annotator/pipeline
"""
print(name + " download started this may take some time.")
file_size = _internal._GetResourceSize(name, language, remote_loc).apply()
file_size = _internal._GetResourceSize(name, language, remote_loc, reader.name, engine).apply()
if file_size == "-1":
print("Can not find the model to download please check the name!")
else:
Expand All @@ -92,7 +92,7 @@ def downloadModel(reader, name, language, remote_loc=None, j_dwn='PythonResource
t1 = threading.Thread(target=printProgress, args=(lambda: stop_threads,))
t1.start()
try:
j_obj = _internal._DownloadModel(reader.name, name, language, remote_loc, j_dwn).apply()
j_obj = _internal._DownloadModel(reader.name, name, language, remote_loc, engine, j_dwn).apply()
except Py4JJavaError as e:
sys.stdout.write("\n" + str(e))
raise e
Expand Down
6 changes: 6 additions & 0 deletions python/test/annotator/embeddings/camembert_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,9 @@ def test_run(self):

model = pipeline.fit(self.data)
model.transform(self.data).show()

def test_perferred_engine(self):
    """Each requested engine should be the one reported by the loaded model."""
    requested_engines = ("onnx", "tensorflow")

    # Load every model first, then assert, so the order of downloads and
    # checks stays: onnx download, tensorflow download, onnx check,
    # tensorflow check.
    loaded_models = [
        CamemBertEmbeddings.pretrainedEngine("camembert_base", "fr", engine=requested)
        for requested in requested_engines
    ]

    for loaded, requested in zip(loaded_models, requested_engines):
        self.assertEqual(loaded.getEngine(), requested)
30 changes: 25 additions & 5 deletions src/main/scala/com/johnsnowlabs/nlp/HasPretrained.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ trait HasPretrained[M <: PipelineStage] {

val defaultLang: String = "en"

val defaultPreferredEngine: String = "onnx"

lazy val defaultLoc: String = ResourceDownloader.publicLoc

implicit private val companion: DefaultParamsReadable[M] =
Expand All @@ -38,17 +40,35 @@ trait HasPretrained[M <: PipelineStage] {
s"${this.getClass.getName} does not have a default pretrained model. Please provide a model name."

/** Java default argument interoperability */
def pretrained(name: String, lang: String, remoteLoc: String): M = {
def pretrained(
name: String,
lang: String,
remoteLoc: String,
preferredEngine: String = "onnx"): M = {
if (Option(name).isEmpty)
throw new NotImplementedError(errorMsg)
ResourceDownloader.downloadModel(companion, name, Option(lang), remoteLoc)
ResourceDownloader.downloadModel(companion, name, Option(lang), remoteLoc, preferredEngine)
}

def pretrained(name: String, lang: String): M = pretrained(name, lang, defaultLoc)
def pretrained(name: String, lang: String): M =
pretrained(name, lang, defaultLoc, defaultPreferredEngine)

def pretrained(name: String): M = pretrained(name, defaultLang, defaultLoc)
def pretrained(name: String): M =
pretrained(name, defaultLang, defaultLoc, defaultPreferredEngine)

def pretrained(): M =
pretrained(defaultModelName.getOrElse(throw new Exception(errorMsg)), defaultLang, defaultLoc)
pretrained(
defaultModelName.getOrElse(throw new Exception(errorMsg)),
defaultLang,
defaultLoc,
defaultPreferredEngine)

/** Three-argument overload: delegates to the four-argument `pretrained`, supplying
  * `defaultPreferredEngine` as the engine.
  */
def pretrained(name: String, lang: String, remoteLoc: String): M = {
  val engine: String = defaultPreferredEngine
  pretrained(name, lang, remoteLoc, engine)
}
/** Loads `name` with the requested engine, using `defaultLang` and `defaultLoc` for the
  * language and location.
  */
def pretrainedEngine(name: String, preferredEngine: String): M = {
  val lang: String = defaultLang
  val location: String = defaultLoc
  pretrained(name, lang, location, preferredEngine)
}

/** Loads `name` for language `lang` with the requested engine, using `defaultLoc` for the
  * location.
  */
def pretrainedEngine(name: String, lang: String, preferredEngine: String): M = {
  val location: String = defaultLoc
  pretrained(name, lang, location, preferredEngine)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class Florence2Transformer(override val uid: String)
trait ReadablePretrainedFlorence2TransformerModel
extends ParamsAndFeaturesReadable[Florence2Transformer]
with HasPretrained[Florence2Transformer] {
override val defaultModelName: Some[String] = Some("florence2_base_ft_int4")
override val defaultModelName: Some[String] = Some("florence_2_base_ft_int4")

/** Java compliant-overrides */
override def pretrained(): Florence2Transformer = super.pretrained()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,13 @@ class Phi4Transformer(override val uid: String)
/** Reader trait for [[Phi4Transformer]] that mixes in [[HasPretrained]], so the `pretrained`
  * overloads below are available with `Phi4Transformer` as their return type.
  */
trait ReadablePretrainedPhi4TransformerModel
extends ParamsAndFeaturesReadable[Phi4Transformer]
with HasPretrained[Phi4Transformer] {
// Model name that the no-argument `pretrained()` falls back to.
override val defaultModelName: Some[String] = Some("phi-4")
override val defaultModelName: Some[String] = Some("phi_4_mini_instruct_int8_openvino")

/** Java compliant-overrides: each overload is re-declared with the concrete return type and
  * simply delegates to the implementation in HasPretrained.
  */
override def pretrained(): Phi4Transformer = super.pretrained()
override def pretrained(name: String): Phi4Transformer = super.pretrained(name)
override def pretrained(name: String, lang: String): Phi4Transformer =
super.pretrained(name, lang)

override def pretrained(name: String, lang: String, remoteLoc: String): Phi4Transformer =
super.pretrained(name, lang, remoteLoc)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ trait ReadablePretrainedE5VEmbeddings
extends ParamsAndFeaturesReadable[E5VEmbeddings]
with HasPretrained[E5VEmbeddings] {

override val defaultModelName: Some[String] = Some("e5v_1_5_7b_int4")
override val defaultModelName: Some[String] = Some("e5v_int4")

/** Java compliant-overrides */
override def pretrained(): E5VEmbeddings = super.pretrained()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,18 @@ object ResourceDownloader {
reader: DefaultParamsReadable[TModel],
name: String,
language: Option[String] = None,
folder: String = publicLoc): TModel = {
downloadModel(reader, ResourceRequest(name, language, folder))
folder: String = publicLoc,
preferredEngine: String = "onnx"): TModel = {

val annotator = reader.getClass.getSimpleName.replace("$", "")
downloadModel(
reader,
ResourceRequest(
name,
language,
folder,
annotator = Some(annotator),
engine = Some(preferredEngine)))
}

def downloadModel[TModel <: PipelineStage](
Expand All @@ -517,7 +527,7 @@ object ResourceDownloader {
name: String,
language: Option[String] = None,
folder: String = publicLoc): PipelineModel = {
downloadPipeline(ResourceRequest(name, language, folder))
downloadPipeline(ResourceRequest(name, language, folder, annotator = Some("PipelineModel")))
}

def downloadPipeline(request: ResourceRequest): PipelineModel = {
Expand Down Expand Up @@ -575,7 +585,9 @@ case class ResourceRequest(
language: Option[String] = None,
folder: String = ResourceDownloader.publicLoc,
libVersion: Version = ResourceDownloader.libVersion,
sparkVersion: Version = ResourceDownloader.sparkVersion)
sparkVersion: Version = ResourceDownloader.sparkVersion,
annotator: Option[String] = None,
engine: Option[String] = None)

/* convenience accessor for Py4J calls */
object PythonResourceDownloader {
Expand Down Expand Up @@ -632,7 +644,6 @@ object PythonResourceDownloader {
"XlnetForTokenClassification" -> XlnetForTokenClassification,
"AlbertForSequenceClassification" -> AlbertForSequenceClassification,
"BertForSequenceClassification" -> BertForSequenceClassification,
"DeBertaForSequenceClassification" -> DeBertaForSequenceClassification,
"DistilBertForSequenceClassification" -> DistilBertForSequenceClassification,
"LongformerForSequenceClassification" -> LongformerForSequenceClassification,
"RoBertaForSequenceClassification" -> RoBertaForSequenceClassification,
Expand All @@ -644,7 +655,6 @@ object PythonResourceDownloader {
"Word2VecModel" -> Word2VecModel,
"DeBertaEmbeddings" -> DeBertaEmbeddings,
"DeBertaForSequenceClassification" -> DeBertaForSequenceClassification,
"DeBertaForTokenClassification" -> DeBertaForTokenClassification,
"CamemBertEmbeddings" -> CamemBertEmbeddings,
"AlbertForQuestionAnswering" -> AlbertForQuestionAnswering,
"BertForQuestionAnswering" -> BertForQuestionAnswering,
Expand Down Expand Up @@ -722,7 +732,8 @@ object PythonResourceDownloader {
readerStr: String,
name: String,
language: String = null,
remoteLoc: String = null): PipelineStage = {
remoteLoc: String = null,
preferredEngine: String): PipelineStage = {

val reader = keyToReader.getOrElse(
if (typeMapper.contains(readerStr)) typeMapper(readerStr) else readerStr,
Expand All @@ -734,7 +745,8 @@ object PythonResourceDownloader {
reader.asInstanceOf[DefaultParamsReadable[PipelineStage]],
name,
Option(language),
correctedFolder)
correctedFolder,
preferredEngine)

// Cast the model to the required type. This has to be done for each entry in the typeMapper map
if (typeMapper.contains(readerStr) && readerStr == "ZeroShotNerModel")
Expand Down Expand Up @@ -800,8 +812,19 @@ object PythonResourceDownloader {
ResourceDownloader.listAvailableAnnotators().mkString("\n")
}

/** Py4J-facing helper that asks [[ResourceDownloader]] for the download size of a resource.
  *
  * @param name
  *   Name of the resource
  * @param language
  *   Language of the resource; a null value is treated as "not specified"
  * @param remoteLoc
  *   Remote folder; when null, `ResourceDownloader.publicLoc` is used
  * @param annotator
  *   Annotator name attached to the request; a null value is treated as "not specified"
  * @param engine
  *   Engine attached to the request; a null value is treated as "not specified"
  * @return
  *   The value returned by `ResourceDownloader.getDownloadSize` for the built request
  */
def getDownloadSize(
    name: String,
    language: String = "en",
    remoteLoc: String = null,
    annotator: String,
    engine: String): String = {
  val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)

  // Option(...) instead of Some(...): these values come from Python through Py4J, where None
  // becomes null. Some(null) would be reported as a present value, whereas Option(null) is
  // None, which matches how `language` and `remoteLoc` are already handled here.
  val request = ResourceRequest(
    name,
    Option(language),
    correctedFolder,
    annotator = Option(annotator),
    engine = Option(engine))

  ResourceDownloader.getDownloadSize(request)
}
}
Loading
Loading