Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions bindings/node/src/normalizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ impl tk::Normalizer for Normalizer {
}
}

#[napi]
pub fn append_normalizer(append: String) -> Normalizer {
    // Build the core `Append` normalizer, then wrap it for shared,
    // thread-safe access from the Node.js side.
    let inner = tk::normalizers::append::Append::new(append);
    Normalizer {
        normalizer: Some(Arc::new(RwLock::new(inner.into()))),
    }
}

#[napi]
pub fn prepend_normalizer(prepend: String) -> Normalizer {
Normalizer {
Expand Down
1 change: 1 addition & 0 deletions bindings/python/py_src/tokenizers/normalizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
NFKC = normalizers.NFKC
Sequence = normalizers.Sequence
Lowercase = normalizers.Lowercase
Append = normalizers.Append
Prepend = normalizers.Prepend
Strip = normalizers.Strip
StripAccents = normalizers.StripAccents
Expand Down
41 changes: 41 additions & 0 deletions bindings/python/py_src/tokenizers/normalizers/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,47 @@ class Precompiled(Normalizer):
"""
pass

class Append(Normalizer):
    """
    Append normalizer

    Appends a fixed string to the end of the input during normalization.

    Args:
        append (:obj:`str`):
            The string to append. Defaults to ``"▁"`` (U+2581) on the
            Rust side when omitted.
    """
    def __init__(self, append):
        pass

    def normalize(self, normalized):
        """
        Normalize a :class:`~tokenizers.NormalizedString` in-place

        This method allows to modify a :class:`~tokenizers.NormalizedString` to
        keep track of the alignment information. If you just want to see the result
        of the normalization on a raw string, you can use
        :meth:`~tokenizers.normalizers.Normalizer.normalize_str`

        Args:
            normalized (:class:`~tokenizers.NormalizedString`):
                The normalized string on which to apply this
                :class:`~tokenizers.normalizers.Normalizer`
        """
        pass

    def normalize_str(self, sequence):
        """
        Normalize the given string

        This method provides a way to visualize the effect of a
        :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment
        information. If you need to get/convert offsets, you can use
        :meth:`~tokenizers.normalizers.Normalizer.normalize`

        Args:
            sequence (:obj:`str`):
                A string to normalize

        Returns:
            :obj:`str`: A string after normalization
        """
        pass

class Prepend(Normalizer):
"""
Prepend normalizer
Expand Down
29 changes: 28 additions & 1 deletion bindings/python/src/normalizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::normalizers::{
BertNormalizer, ByteLevel, Lowercase, Nmt, NormalizerWrapper, Precompiled, Prepend, Replace,
BertNormalizer, ByteLevel, Lowercase, Nmt, NormalizerWrapper, Precompiled, Append, Prepend, Replace,
Strip, StripAccents, NFC, NFD, NFKC, NFKD,
};
use tk::{NormalizedString, Normalizer};
Expand Down Expand Up @@ -82,6 +82,10 @@ impl PyNormalizer {
.into_pyobject(py)?
.into_any()
.into(),
NormalizerWrapper::Append(_) => Py::new(py, (PyAppend {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
NormalizerWrapper::Prepend(_) => Py::new(py, (PyPrepend {}, base))?
.into_pyobject(py)?
.into_any()
Expand Down Expand Up @@ -514,6 +518,28 @@ impl PyStrip {
}
}

/// Append normalizer
///
/// Python-facing wrapper (`tokenizers.normalizers.Append`) around the core
/// `Append` normalizer, which appends a fixed string during normalization.
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Append")]
pub struct PyAppend {}
#[pymethods]
impl PyAppend {
    /// Expose the wrapped normalizer's `append` string as a Python property.
    #[getter]
    fn get_append(self_: PyRef<Self>) -> String {
        getter!(self_, Append, append)
    }

    /// Allow mutating the `append` string from Python.
    #[setter]
    fn set_append(self_: PyRef<Self>, append: String) {
        setter!(self_, Append, append, append)
    }

    /// Construct a new `Append` normalizer; defaults to "▁" (U+2581)
    /// when no argument is given.
    #[new]
    #[pyo3(signature = (append="▁".to_string()), text_signature = "(self, append)")]
    fn new(append: String) -> (Self, PyNormalizer) {
        (PyAppend {}, Append::new(append).into())
    }
}

/// Prepend normalizer
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Prepend")]
pub struct PyPrepend {}
Expand Down Expand Up @@ -810,6 +836,7 @@ pub fn normalizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyLowercase>()?;
m.add_class::<PyStrip>()?;
m.add_class::<PyStripAccents>()?;
m.add_class::<PyAppend>()?;
m.add_class::<PyPrepend>()?;
m.add_class::<PyByteLevel>()?;
m.add_class::<PyNmt>()?;
Expand Down
40 changes: 40 additions & 0 deletions tokenizers/src/normalizers/append.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use crate::tokenizer::{NormalizedString, Normalizer, Result};
use serde::{Deserialize, Serialize};

/// Normalizer that appends a fixed string to the input.
///
/// Serialized with a `"type"` tag so it can round-trip through the
/// `NormalizerWrapper` enum (e.g. `{"type":"Append","append":"▁"}`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
pub struct Append {
    /// The string appended during normalization.
    pub append: String,
}

impl Append {
    /// Creates an `Append` normalizer that appends the given string.
    pub fn new(append: String) -> Self {
        Self { append }
    }
}

impl Normalizer for Append {
    /// Appends `self.append` to the normalized string in place.
    ///
    /// Empty inputs are left untouched, mirroring the behavior of the
    /// `Prepend` normalizer.
    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
        if normalized.is_empty() {
            return Ok(());
        }
        normalized.append(&self.append);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The normalizer appends the configured string to a non-empty input.
    #[test]
    fn test_append() {
        let original = "Hello";
        let normalized = "Hello▁";
        assert_ne!(original, normalized);
        let mut n = NormalizedString::from(original);
        let append = Append::new("▁".to_string());
        append.normalize(&mut n).unwrap();
        assert_eq!(&n.get(), &normalized);
    }

    /// The `is_empty` guard in `Append::normalize` must leave an empty
    /// string untouched — previously this branch had no coverage.
    #[test]
    fn test_append_empty() {
        let mut n = NormalizedString::from("");
        let append = Append::new("▁".to_string());
        append.normalize(&mut n).unwrap();
        assert_eq!(n.get(), "");
    }
}
18 changes: 18 additions & 0 deletions tokenizers/src/normalizers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod bert;
pub mod byte_level;
pub mod precompiled;
pub mod append;
pub mod prepend;
pub mod replace;
pub mod strip;
Expand All @@ -9,6 +10,7 @@ pub mod utils;
pub use crate::normalizers::bert::BertNormalizer;
pub use crate::normalizers::byte_level::ByteLevel;
pub use crate::normalizers::precompiled::Precompiled;
pub use crate::normalizers::append::Append;
pub use crate::normalizers::prepend::Prepend;
pub use crate::normalizers::replace::Replace;
pub use crate::normalizers::strip::{Strip, StripAccents};
Expand All @@ -34,6 +36,7 @@ pub enum NormalizerWrapper {
Nmt(Nmt),
Precompiled(Precompiled),
Replace(Replace),
Append(Append),
Prepend(Prepend),
ByteLevel(ByteLevel),
}
Expand Down Expand Up @@ -64,6 +67,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
Nmt,
Precompiled,
Replace,
Append,
Prepend,
ByteLevel,
}
Expand All @@ -90,6 +94,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
Nmt(Nmt),
Precompiled(Precompiled),
Replace(Replace),
Append(Append),
Prepend(Prepend),
ByteLevel(ByteLevel),
}
Expand Down Expand Up @@ -145,6 +150,9 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
EnumType::Replace => NormalizerWrapper::Replace(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Append => NormalizerWrapper::Append(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Prepend => NormalizerWrapper::Prepend(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
Expand Down Expand Up @@ -173,6 +181,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe),
NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe),
NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe),
NormalizerUntagged::Append(bpe) => NormalizerWrapper::Append(bpe),
NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe),
NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe),
}
Expand All @@ -196,6 +205,7 @@ impl Normalizer for NormalizerWrapper {
Self::Nmt(lc) => lc.normalize(normalized),
Self::Precompiled(lc) => lc.normalize(normalized),
Self::Replace(lc) => lc.normalize(normalized),
Self::Append(lc) => lc.normalize(normalized),
Self::Prepend(lc) => lc.normalize(normalized),
Self::ByteLevel(lc) => lc.normalize(normalized),
}
Expand All @@ -214,6 +224,7 @@ impl_enum_from!(Lowercase, NormalizerWrapper, Lowercase);
impl_enum_from!(Nmt, NormalizerWrapper, Nmt);
impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled);
impl_enum_from!(Replace, NormalizerWrapper, Replace);
impl_enum_from!(Append, NormalizerWrapper, Append);
impl_enum_from!(Prepend, NormalizerWrapper, Prepend);
impl_enum_from!(ByteLevel, NormalizerWrapper, ByteLevel);

Expand All @@ -239,6 +250,13 @@ mod tests {
_ => panic!("Expected an error here"),
}

let json = r#"{"append":"a"}"#;
let reconstructed = serde_json::from_str::<NormalizerWrapper>(json);
assert!(matches!(
reconstructed.unwrap(),
NormalizerWrapper::Append(_)
));

let json = r#"{"prepend":"a"}"#;
let reconstructed = serde_json::from_str::<NormalizerWrapper>(json);
assert!(matches!(
Expand Down