diff --git a/bindings/node/index.d.ts b/bindings/node/index.d.ts
index 1252a2623..46b3e779a 100644
--- a/bindings/node/index.d.ts
+++ b/bindings/node/index.d.ts
@@ -77,6 +77,7 @@ export function lowercase(): Normalizer
 export function replace(pattern: string, content: string): Normalizer
 export function nmt(): Normalizer
 export function precompiled(bytes: Array<number>): Normalizer
+export function unicodeFilter(filterUnassigned?: boolean | undefined | null, filterPrivateUse?: boolean | undefined | null): Normalizer
 export const enum JsSplitDelimiterBehavior {
   Removed = 'Removed',
   Isolated = 'Isolated',
diff --git a/bindings/node/index.js b/bindings/node/index.js
index 8832d9c04..36537f20f 100644
--- a/bindings/node/index.js
+++ b/bindings/node/index.js
@@ -304,6 +304,7 @@ const {
   replace,
   nmt,
   precompiled,
+  unicodeFilter,
   JsSplitDelimiterBehavior,
   PreTokenizer,
   byteLevelPreTokenizer,
@@ -363,6 +364,7 @@ module.exports.lowercase = lowercase
 module.exports.replace = replace
 module.exports.nmt = nmt
 module.exports.precompiled = precompiled
+module.exports.unicodeFilter = unicodeFilter
 module.exports.JsSplitDelimiterBehavior = JsSplitDelimiterBehavior
 module.exports.PreTokenizer = PreTokenizer
 module.exports.byteLevelPreTokenizer = byteLevelPreTokenizer
diff --git a/bindings/node/lib/bindings/normalizers.test.ts b/bindings/node/lib/bindings/normalizers.test.ts
index e9f11fe92..f46e9b0e8 100644
--- a/bindings/node/lib/bindings/normalizers.test.ts
+++ b/bindings/node/lib/bindings/normalizers.test.ts
@@ -1,4 +1,4 @@
-import { prependNormalizer, stripAccentsNormalizer, stripNormalizer } from '../../'
+import { prependNormalizer, stripAccentsNormalizer, stripNormalizer, unicodeFilter } from '../../'
 
 describe('stripNormalizer', () => {
   it('instantiates with no parameters', () => {
@@ -42,3 +42,36 @@ describe('stripAccentsNormalizer', () => {
     expect(normalizer.constructor.name).toEqual('Normalizer')
   })
 })
+
+describe('unicodeFilter', () => {
+  it('instantiates with defaults', () => {
+    const normalizer = unicodeFilter()
+    expect(normalizer.constructor.name).toEqual('Normalizer')
+  })
+
+  it('handles default filtering', () => {
+    const normalizer = unicodeFilter() // Default filters out Unassigned and PrivateUse
+    const input = 'Hello' + String.fromCharCode(0xE000) + String.fromCodePoint(0xF0000) + String.fromCodePoint(0x10FFFF)
+    expect(normalizer.normalizeString(input)).toEqual('Hello')
+  })
+
+  it('accepts custom filter options', () => {
+    // Only filter private-use areas
+    const normalizer = unicodeFilter(false, true)
+    const input = 'Hello' + String.fromCharCode(0xE000) + String.fromCodePoint(0xF0000) + String.fromCodePoint(0x10FFFF)
+    const expected = 'Hello' + String.fromCodePoint(0x10FFFF)
+    expect(normalizer.normalizeString(input)).toEqual(expected)
+  })
+
+  it('accepts undefined options', () => {
+    const normalizer = unicodeFilter(undefined, undefined)
+    const input = 'Hello' + String.fromCharCode(0xE000) + String.fromCodePoint(0xF0000) + String.fromCodePoint(0x10FFFF)
+    expect(normalizer.normalizeString(input)).toEqual('Hello')
+  })
+
+  it('can disable all filtering', () => {
+    const normalizer = unicodeFilter(false, false)
+    const input = 'Hello' + String.fromCharCode(0xE000) + String.fromCodePoint(0xF0000) + String.fromCodePoint(0x10FFFF)
+    expect(normalizer.normalizeString(input)).toEqual(input)
+  })
+})
diff --git a/bindings/node/src/normalizers.rs b/bindings/node/src/normalizers.rs
index f3009ed1f..50d18c93f 100644
--- a/bindings/node/src/normalizers.rs
+++ b/bindings/node/src/normalizers.rs
@@ -186,3 +186,15 @@ pub fn precompiled(bytes: Vec<u8>) -> Result<Normalizer> {
         ))),
     })
 }
+
+#[napi]
+pub fn unicode_filter(filter_unassigned: Option<bool>, filter_private_use: Option<bool>) -> Normalizer {
+    let filter = tk::normalizers::UnicodeFilter::new(
+        filter_unassigned.unwrap_or(true),
+        filter_private_use.unwrap_or(true),
+    );
+
+    Normalizer {
+        normalizer: Some(Arc::new(RwLock::new(filter.into()))),
+    }
+}
diff --git a/bindings/python/py_src/tokenizers/normalizers/__init__.py b/bindings/python/py_src/tokenizers/normalizers/__init__.py
index 86d233bd2..36e69a274 100644
--- a/bindings/python/py_src/tokenizers/normalizers/__init__.py
+++ b/bindings/python/py_src/tokenizers/normalizers/__init__.py
@@ -16,6 +16,7 @@
 Precompiled = normalizers.Precompiled
 Replace = normalizers.Replace
 ByteLevel = normalizers.ByteLevel
+UnicodeFilter = normalizers.UnicodeFilter
 
 NORMALIZERS = {"nfc": NFC, "nfd": NFD, "nfkc": NFKC, "nfkd": NFKD}
 
diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs
index c5697f75b..5e3153347 100644
--- a/bindings/python/src/normalizers.rs
+++ b/bindings/python/src/normalizers.rs
@@ -9,7 +9,7 @@ use serde::ser::SerializeStruct;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use tk::normalizers::{
     BertNormalizer, ByteLevel, Lowercase, Nmt, NormalizerWrapper, Precompiled, Prepend, Replace,
-    Strip, StripAccents, NFC, NFD, NFKC, NFKD,
+    Strip, StripAccents, NFC, NFD, NFKC, NFKD, UnicodeFilter,
 };
 use tk::{NormalizedString, Normalizer};
 use tokenizers as tk;
@@ -126,6 +126,10 @@ impl PyNormalizer {
                     .into_pyobject(py)?
                     .into_any()
                     .into(),
+                NormalizerWrapper::UnicodeFilter(_) => Py::new(py, (PyUnicodeFilter {}, base))?
+                    .into_pyobject(py)?
+                    .into_any()
+                    .into(),
             },
         },
     })
@@ -794,6 +798,23 @@ impl Normalizer for PyNormalizerWrapper {
     }
 }
 
+/// UnicodeFilter normalizer that filters out unwanted Unicode categories
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "UnicodeFilter")]
+pub struct PyUnicodeFilter {}
+
+#[pymethods]
+impl PyUnicodeFilter {
+    #[new]
+    #[pyo3(signature = (filter_unassigned = None, filter_private_use = None), text_signature = "(self, filter_unassigned = True, filter_private_use = True)")]
+    fn new(filter_unassigned: Option<bool>, filter_private_use: Option<bool>) -> (Self, PyNormalizer) {
+        let filter = tk::normalizers::UnicodeFilter::new(
+            filter_unassigned.unwrap_or(true),
+            filter_private_use.unwrap_or(true),
+        );
+        (PyUnicodeFilter {}, filter.into())
+    }
+}
+
 /// Normalizers Module
 #[pymodule]
 pub fn normalizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
@@ -812,6 +833,7 @@ pub fn normalizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<PyNmt>()?;
     m.add_class::<PyPrecompiled>()?;
     m.add_class::<PyReplace>()?;
+    m.add_class::<PyUnicodeFilter>()?;
     Ok(())
 }
diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs
index 03fa6bdf7..c2f0e8526 100644
--- a/bindings/python/src/processors.rs
+++ b/bindings/python/src/processors.rs
@@ -8,6 +8,7 @@ use pyo3::exceptions;
 use pyo3::exceptions::PyException;
 use pyo3::prelude::*;
 use pyo3::types::*;
+use pyo3::Python;
 use serde::ser::SerializeStruct;
 use serde::Deserializer;
 use serde::Serializer;
diff --git a/bindings/python/tests/bindings/test_normalizers.py b/bindings/python/tests/bindings/test_normalizers.py
index 99ab07d39..45280c361 100644
--- a/bindings/python/tests/bindings/test_normalizers.py
+++ b/bindings/python/tests/bindings/test_normalizers.py
@@ -12,6 +12,7 @@
     Strip,
     Prepend,
     Replace,
+    UnicodeFilter,
 )
 
 
@@ -201,6 +202,40 @@ def test_can_modify(self):
         assert normalizer.prepend == "-"
 
 
+class TestUnicodeFilter:
+    def test_instantiate(self):
+        assert isinstance(UnicodeFilter(), Normalizer)
+        assert isinstance(UnicodeFilter(), UnicodeFilter)
+        assert isinstance(pickle.loads(pickle.dumps(UnicodeFilter())), UnicodeFilter)
+
+    def test_default_filtering(self):
+        normalizer = UnicodeFilter()  # Default filters out Unassigned and PrivateUse
+        output = normalizer.normalize_str("Hello\uE000\U000F0000\U0010FFFF")  # Hello + Private Use + Supplementary Private Use Area-A + Unassigned
+        assert output == "Hello"  # Only assigned, non-private-use chars remain
+
+    def test_custom_filtering(self):
+        # Only filter private-use areas
+        normalizer = UnicodeFilter(
+            filter_unassigned=False,
+            filter_private_use=True,
+        )
+        output = normalizer.normalize_str("Hello\uE000\U000F0000\U0010FFFF")
+        assert output == "Hello\U0010FFFF"  # Private use removed, unassigned kept
+
+    def test_can_modify(self):
+        normalizer = UnicodeFilter()
+        output = normalizer.normalize_str("Hello\uE000\U000F0000\U0010FFFF")
+        assert output == "Hello"  # All filtered by default
+
+        # Disable all filtering
+        normalizer = UnicodeFilter(
+            filter_unassigned=False,
+            filter_private_use=False,
+        )
+        output = normalizer.normalize_str("Hello\uE000\U000F0000\U0010FFFF")
+        assert output == "Hello\uE000\U000F0000\U0010FFFF"  # Nothing filtered
+
+
 class TestCustomNormalizer:
     class BadCustomNormalizer:
         def normalize(self, normalized, wrong):
diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index 9547c4e41..01c93026e 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -52,6 +52,7 @@ serde_json = "1.0"
 unicode-normalization-alignments = "0.1"
 unicode_categories = "0.1"
 unicode-segmentation = "1.11"
+unicode-general-category = "0.6.0"
 indicatif = {version = "0.17", optional = true}
 itertools = "0.14"
 log = "0.4"
diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs
index f400f13da..65d56e906 100644
--- a/tokenizers/src/normalizers/mod.rs
+++ b/tokenizers/src/normalizers/mod.rs
@@ -12,7 +12,7 @@ pub use crate::normalizers::precompiled::Precompiled;
 pub use crate::normalizers::prepend::Prepend;
 pub use crate::normalizers::replace::Replace;
 pub use crate::normalizers::strip::{Strip, StripAccents};
-pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD};
+pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD, UnicodeFilter};
 pub use crate::normalizers::utils::{Lowercase, Sequence};
 
 use serde::{Deserialize, Deserializer, Serialize};
@@ -36,6 +36,7 @@ pub enum NormalizerWrapper {
     Replace(Replace),
     Prepend(Prepend),
     ByteLevel(ByteLevel),
+    UnicodeFilter(UnicodeFilter),
 }
 
 impl<'de> Deserialize<'de> for NormalizerWrapper {
@@ -66,6 +67,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
             Replace,
             Prepend,
             ByteLevel,
+            UnicodeFilter,
         }
 
         #[derive(Deserialize)]
@@ -92,6 +94,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
             Replace(Replace),
             Prepend(Prepend),
             ByteLevel(ByteLevel),
+            UnicodeFilter(UnicodeFilter),
         }
 
         let helper = NormalizerHelper::deserialize(deserializer)?;
@@ -151,6 +154,9 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
                     EnumType::ByteLevel => NormalizerWrapper::ByteLevel(
                         serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                     ),
+                    EnumType::UnicodeFilter => NormalizerWrapper::UnicodeFilter(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
                 }
             }
@@ -175,6 +181,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
             NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe),
             NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe),
             NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe),
+            NormalizerUntagged::UnicodeFilter(uf) => NormalizerWrapper::UnicodeFilter(uf),
             }
         })
@@ -198,6 +205,7 @@ impl Normalizer for NormalizerWrapper {
             Self::Replace(lc) => lc.normalize(normalized),
             Self::Prepend(lc) => lc.normalize(normalized),
             Self::ByteLevel(lc) => lc.normalize(normalized),
+            Self::UnicodeFilter(uf) => uf.normalize(normalized),
         }
     }
 }
@@ -216,6 +224,7 @@ impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled);
 impl_enum_from!(Replace, NormalizerWrapper, Replace);
 impl_enum_from!(Prepend, NormalizerWrapper, Prepend);
 impl_enum_from!(ByteLevel, NormalizerWrapper, ByteLevel);
+impl_enum_from!(UnicodeFilter, NormalizerWrapper, UnicodeFilter);
 
 #[cfg(test)]
 mod tests {
diff --git a/tokenizers/src/normalizers/unicode.rs b/tokenizers/src/normalizers/unicode.rs
index 502b4239b..a2f18f652 100644
--- a/tokenizers/src/normalizers/unicode.rs
+++ b/tokenizers/src/normalizers/unicode.rs
@@ -82,10 +82,68 @@ impl Normalizer for Nmt {
     }
 }
 
+use unicode_general_category::{get_general_category, GeneralCategory};
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub struct UnicodeFilter {
+    filter_unassigned: bool,
+    filter_private_use: bool,
+}
+
+impl Default for UnicodeFilter {
+    fn default() -> Self {
+        Self {
+            filter_unassigned: true,
+            filter_private_use: true,
+        }
+    }
+}
+
+impl UnicodeFilter {
+    /// Filters Unicode characters based on their general category.
+    /// Args:
+    ///     filter_unassigned: Whether to filter out unassigned code points (category `Cn`)
+    ///     filter_private_use: Whether to filter out private-use code points (category `Co`)
+    pub fn new(filter_unassigned: bool, filter_private_use: bool) -> Self {
+        Self {
+            filter_unassigned,
+            filter_private_use,
+        }
+    }
+}
+
+impl Normalizer for UnicodeFilter {
+    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
+        normalized.filter(|c| {
+            let category = get_general_category(c);
+            !((self.filter_unassigned && category == GeneralCategory::Unassigned)
+                || (self.filter_private_use && category == GeneralCategory::PrivateUse))
+        });
+        Ok(())
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
 
+    #[test]
+    fn test_unicode_filter() {
+        // Test with default settings (filter both categories)
+        let original = "A\u{20AC}\u{10FFFF}\u{E000}".to_string(); // Regular + Euro + Unassigned + Private Use
+        let normalized = "A\u{20AC}".to_string(); // Keep only assigned, non-private-use chars
+        let mut n = NormalizedString::from(original.clone());
+        UnicodeFilter::default().normalize(&mut n).unwrap();
+        assert_eq!(n.get(), normalized);
+
+        // Test with only filtering unassigned
+        let mut n = NormalizedString::from(original);
+        UnicodeFilter::new(true, false).normalize(&mut n).unwrap();
+        assert_eq!(n.get(), "A\u{20AC}\u{E000}"); // Keep private use, filter unassigned
+    }
+
     #[test]
     fn test_nfkc() {
         let original = "\u{fb01}".to_string();
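Usage sketch (not part of the diff): a minimal illustration of the new normalizer through the Python API, assuming a wheel built from this branch; it mirrors the tests above.

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import NFC, Sequence, UnicodeFilter

# Drop unassigned and private-use code points before any other normalization.
tokenizer = Tokenizer(BPE())
tokenizer.normalizer = Sequence([UnicodeFilter(), NFC()])
assert tokenizer.normalizer.normalize_str("Hello\uE000") == "Hello"

# Both flags default to True; disable one to keep that category.
keep_unassigned = UnicodeFilter(filter_unassigned=False, filter_private_use=True)
assert keep_unassigned.normalize_str("Hi\U0010FFFF") == "Hi\U0010FFFF"
```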
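The `NormalizerWrapper` changes above are what let a tokenizer carrying this normalizer round-trip through `tokenizer.json`. A sketch of that path (the file name is hypothetical, and the `"type"` assertion assumes the struct is tagged with `#[serde(tag = "type")]` like the other normalizers):

```python
import json

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import UnicodeFilter

tok = Tokenizer(BPE())
tok.normalizer = UnicodeFilter(filter_unassigned=True, filter_private_use=False)
tok.save("unicode_filter_tok.json")  # hypothetical path

# The saved normalizer section should carry the tagged representation.
with open("unicode_filter_tok.json") as f:
    assert json.load(f)["normalizer"]["type"] == "UnicodeFilter"

# Reloading exercises the Deserialize arm added in normalizers/mod.rs.
reloaded = Tokenizer.from_file("unicode_filter_tok.json")
assert reloaded.normalizer.normalize_str("A\U0010FFFF") == "A"
```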