diff --git a/bindings/node/src/normalizers.rs b/bindings/node/src/normalizers.rs index f3009ed1f..d22ce8cac 100644 --- a/bindings/node/src/normalizers.rs +++ b/bindings/node/src/normalizers.rs @@ -44,6 +44,15 @@ impl tk::Normalizer for Normalizer { } } +#[napi] +pub fn append_normalizer(append: String) -> Normalizer { + Normalizer { + normalizer: Some(Arc::new(RwLock::new( + tk::normalizers::append::Append::new(append).into(), + ))), + } +} + #[napi] pub fn prepend_normalizer(prepend: String) -> Normalizer { Normalizer { diff --git a/bindings/python/py_src/tokenizers/normalizers/__init__.py b/bindings/python/py_src/tokenizers/normalizers/__init__.py index 86d233bd2..9bb6e3ce4 100644 --- a/bindings/python/py_src/tokenizers/normalizers/__init__.py +++ b/bindings/python/py_src/tokenizers/normalizers/__init__.py @@ -9,6 +9,7 @@ NFKC = normalizers.NFKC Sequence = normalizers.Sequence Lowercase = normalizers.Lowercase +Append = normalizers.Append Prepend = normalizers.Prepend Strip = normalizers.Strip StripAccents = normalizers.StripAccents diff --git a/bindings/python/py_src/tokenizers/normalizers/__init__.pyi b/bindings/python/py_src/tokenizers/normalizers/__init__.pyi index 1f5555104..b86252d1c 100644 --- a/bindings/python/py_src/tokenizers/normalizers/__init__.pyi +++ b/bindings/python/py_src/tokenizers/normalizers/__init__.pyi @@ -428,6 +428,47 @@ class Precompiled(Normalizer): """ pass +class Append(Normalizer): + """ + Append normalizer + """ + def __init__(self, append): + pass + + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. 
If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + class Prepend(Normalizer): """ Prepend normalizer diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index c5697f75b..a066cb1e8 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -8,7 +8,7 @@ use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern}; use serde::ser::SerializeStruct; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tk::normalizers::{ - BertNormalizer, ByteLevel, Lowercase, Nmt, NormalizerWrapper, Precompiled, Prepend, Replace, + BertNormalizer, ByteLevel, Lowercase, Nmt, NormalizerWrapper, Precompiled, Append, Prepend, Replace, Strip, StripAccents, NFC, NFD, NFKC, NFKD, }; use tk::{NormalizedString, Normalizer}; @@ -82,6 +82,10 @@ impl PyNormalizer { .into_pyobject(py)? .into_any() .into(), + NormalizerWrapper::Append(_) => Py::new(py, (PyAppend {}, base))? + .into_pyobject(py)? + .into_any() + .into(), NormalizerWrapper::Prepend(_) => Py::new(py, (PyPrepend {}, base))? .into_pyobject(py)? 
.into_any() @@ -512,6 +516,28 @@ impl PyStrip { } } +/// Append normalizer +#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Append")] +pub struct PyAppend {} +#[pymethods] +impl PyAppend { + #[getter] + fn get_append(self_: PyRef<Self>) -> String { + getter!(self_, Append, append) + } + + #[setter] + fn set_append(self_: PyRef<Self>, append: String) { + setter!(self_, Append, append, append) + } + + #[new] + #[pyo3(signature = (append="▁".to_string()), text_signature = "(self, append)")] + fn new(append: String) -> (Self, PyNormalizer) { + (PyAppend {}, Append::new(append).into()) + } +} + /// Prepend normalizer #[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Prepend")] pub struct PyPrepend {} @@ -807,6 +833,7 @@ pub fn normalizers(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::<PyAppend>()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/tokenizers/src/normalizers/append.rs b/tokenizers/src/normalizers/append.rs new file mode 100644 index 000000000..e7c266224 --- /dev/null +++ b/tokenizers/src/normalizers/append.rs @@ -0,0 +1,40 @@ +use crate::tokenizer::{NormalizedString, Normalizer, Result}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(tag = "type")] +pub struct Append { + pub append: String, +} + +impl Append { + pub fn new(append: String) -> Self { + Self { append } + } +} + +impl Normalizer for Append { + /// Append the normalized string inplace + fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> { + if !normalized.is_empty() { + normalized.append(&self.append); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_append() { + let original = "Hello"; + let normalized = "Hello▁"; + assert_ne!(original, normalized); + let mut n = NormalizedString::from(original); + let append = Append::new("▁".to_string()); + append.normalize(&mut 
n).unwrap(); + assert_eq!(&n.get(), &normalized); + } +} diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index f400f13da..d33dc202a 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -1,6 +1,7 @@ pub mod bert; pub mod byte_level; pub mod precompiled; +pub mod append; pub mod prepend; pub mod replace; pub mod strip; @@ -9,6 +10,7 @@ pub mod utils; pub use crate::normalizers::bert::BertNormalizer; pub use crate::normalizers::byte_level::ByteLevel; pub use crate::normalizers::precompiled::Precompiled; +pub use crate::normalizers::append::Append; pub use crate::normalizers::prepend::Prepend; pub use crate::normalizers::replace::Replace; pub use crate::normalizers::strip::{Strip, StripAccents}; @@ -34,6 +36,7 @@ pub enum NormalizerWrapper { Nmt(Nmt), Precompiled(Precompiled), Replace(Replace), + Append(Append), Prepend(Prepend), ByteLevel(ByteLevel), } @@ -64,6 +67,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper { Nmt, Precompiled, Replace, + Append, Prepend, ByteLevel, } @@ -90,6 +94,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper { Nmt(Nmt), Precompiled(Precompiled), Replace(Replace), + Append(Append), Prepend(Prepend), ByteLevel(ByteLevel), } @@ -145,6 +150,9 @@ impl<'de> Deserialize<'de> for NormalizerWrapper { EnumType::Replace => NormalizerWrapper::Replace( serde_json::from_value(values).map_err(serde::de::Error::custom)?, ), + EnumType::Append => NormalizerWrapper::Append( + serde_json::from_value(values).map_err(serde::de::Error::custom)?, + ), EnumType::Prepend => NormalizerWrapper::Prepend( serde_json::from_value(values).map_err(serde::de::Error::custom)?, ), @@ -173,6 +181,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper { NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe), NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe), NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe), + NormalizerUntagged::Append(bpe) => 
NormalizerWrapper::Append(bpe), NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe), NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe), } @@ -196,6 +205,7 @@ impl Normalizer for NormalizerWrapper { Self::Nmt(lc) => lc.normalize(normalized), Self::Precompiled(lc) => lc.normalize(normalized), Self::Replace(lc) => lc.normalize(normalized), + Self::Append(lc) => lc.normalize(normalized), Self::Prepend(lc) => lc.normalize(normalized), Self::ByteLevel(lc) => lc.normalize(normalized), } @@ -214,6 +224,7 @@ impl_enum_from!(Lowercase, NormalizerWrapper, Lowercase); impl_enum_from!(Nmt, NormalizerWrapper, Nmt); impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled); impl_enum_from!(Replace, NormalizerWrapper, Replace); +impl_enum_from!(Append, NormalizerWrapper, Append); impl_enum_from!(Prepend, NormalizerWrapper, Prepend); impl_enum_from!(ByteLevel, NormalizerWrapper, ByteLevel); @@ -239,6 +250,13 @@ mod tests { _ => panic!("Expected an error here"), } + let json = r#"{"append":"a"}"#; + let reconstructed = serde_json::from_str::<NormalizerWrapper>(json); + assert!(matches!( + reconstructed.unwrap(), + NormalizerWrapper::Append(_) + )); + let json = r#"{"prepend":"a"}"#; let reconstructed = serde_json::from_str::<NormalizerWrapper>(json); assert!(matches!(