1 change: 1 addition & 0 deletions bindings/node/index.d.ts
@@ -77,6 +77,7 @@ export function lowercase(): Normalizer
export function replace(pattern: string, content: string): Normalizer
export function nmt(): Normalizer
export function precompiled(bytes: Array<number>): Normalizer
export function unicodeFilter(filterUnassigned?: boolean | undefined | null, filterPrivateUse?: boolean | undefined | null): Normalizer
export const enum JsSplitDelimiterBehavior {
Removed = 'Removed',
Isolated = 'Isolated',
2 changes: 2 additions & 0 deletions bindings/node/index.js
@@ -304,6 +304,7 @@ const {
replace,
nmt,
precompiled,
unicodeFilter,
JsSplitDelimiterBehavior,
PreTokenizer,
byteLevelPreTokenizer,
@@ -363,6 +364,7 @@ module.exports.lowercase = lowercase
module.exports.replace = replace
module.exports.nmt = nmt
module.exports.precompiled = precompiled
module.exports.unicodeFilter = unicodeFilter
module.exports.JsSplitDelimiterBehavior = JsSplitDelimiterBehavior
module.exports.PreTokenizer = PreTokenizer
module.exports.byteLevelPreTokenizer = byteLevelPreTokenizer
35 changes: 34 additions & 1 deletion bindings/node/lib/bindings/normalizers.test.ts
@@ -1,4 +1,4 @@
import { prependNormalizer, stripAccentsNormalizer, stripNormalizer } from '../../'
import { prependNormalizer, stripAccentsNormalizer, stripNormalizer, unicodeFilter } from '../../'

describe('stripNormalizer', () => {
it('instantiates with no parameters', () => {
@@ -42,3 +42,36 @@ describe('stripAccentsNormalizer', () => {
expect(normalizer.constructor.name).toEqual('Normalizer')
})
})

describe('unicodeFilter', () => {
it('instantiates with defaults', () => {
const normalizer = unicodeFilter()
expect(normalizer.constructor.name).toEqual('Normalizer')
})

it('handles default filtering', () => {
const normalizer = unicodeFilter() // Default filters out Unassigned, PrivateUse
const input = 'Hello' + String.fromCharCode(0xE000) + String.fromCodePoint(0xF0000) + String.fromCodePoint(0x10FFFF)
expect(normalizer.normalizeString(input)).toEqual('Hello')
})

it('accepts custom filter options', () => {
// Only filter private use areas
const normalizer = unicodeFilter(false, true)
const input = 'Hello' + String.fromCharCode(0xE000) + String.fromCodePoint(0xF0000) + String.fromCodePoint(0x10FFFF)
const expected = 'Hello' + String.fromCodePoint(0x10FFFF)
expect(normalizer.normalizeString(input)).toEqual(expected)
})

it('accepts undefined options', () => {
const normalizer = unicodeFilter(undefined, undefined)
const input = 'Hello' + String.fromCharCode(0xE000) + String.fromCodePoint(0xF0000) + String.fromCodePoint(0x10FFFF)
expect(normalizer.normalizeString(input)).toEqual('Hello')
})

it('can disable all filtering', () => {
const normalizer = unicodeFilter(false, false)
const input = 'Hello' + String.fromCharCode(0xE000) + String.fromCodePoint(0xF0000) + String.fromCodePoint(0x10FFFF)
expect(normalizer.normalizeString(input)).toEqual(input)
})
})
12 changes: 12 additions & 0 deletions bindings/node/src/normalizers.rs
@@ -186,3 +186,15 @@ pub fn precompiled(bytes: Vec<u8>) -> Result<Normalizer> {
))),
})
}

#[napi]
pub fn unicode_filter(filter_unassigned: Option<bool>, filter_private_use: Option<bool>) -> Normalizer {
let filter = tk::normalizers::UnicodeFilter::new(
filter_unassigned.unwrap_or(true),
filter_private_use.unwrap_or(true),
);

Normalizer {
normalizer: Some(Arc::new(RwLock::new(filter.into()))),
}
}
1 change: 1 addition & 0 deletions bindings/python/py_src/tokenizers/normalizers/__init__.py
@@ -16,6 +16,7 @@
Precompiled = normalizers.Precompiled
Replace = normalizers.Replace
ByteLevel = normalizers.ByteLevel
UnicodeFilter = normalizers.UnicodeFilter

NORMALIZERS = {"nfc": NFC, "nfd": NFD, "nfkc": NFKC, "nfkd": NFKD}

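
For orientation, here is a minimal usage sketch of the Python class exported above, assuming only the constructor arguments visible in this diff (the sample code points are illustrative):

from tokenizers.normalizers import UnicodeFilter

# Defaults mirror the binding: both Unassigned and Private Use code points are filtered
normalizer = UnicodeFilter()
print(normalizer.normalize_str("Hello\uE000"))  # -> "Hello" (U+E000 is Private Use)

# Keep unassigned code points, drop only private-use ones
lenient = UnicodeFilter(filter_unassigned=False, filter_private_use=True)
print(lenient.normalize_str("Hello\U0010FFFF"))  # -> "Hello\U0010FFFF"
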
24 changes: 23 additions & 1 deletion bindings/python/src/normalizers.rs
@@ -9,7 +9,7 @@ use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::normalizers::{
BertNormalizer, ByteLevel, Lowercase, Nmt, NormalizerWrapper, Precompiled, Prepend, Replace,
Strip, StripAccents, NFC, NFD, NFKC, NFKD,
Strip, StripAccents, NFC, NFD, NFKC, NFKD, UnicodeFilter,
};
use tk::{NormalizedString, Normalizer};
use tokenizers as tk;
@@ -126,6 +126,10 @@ impl PyNormalizer {
.into_pyobject(py)?
.into_any()
.into(),
NormalizerWrapper::UnicodeFilter(_) => Py::new(py, (PyUnicodeFilter {}, base))?
.into_pyobject(py)?
.into_any()
.into(),
},
},
})
@@ -794,6 +798,23 @@ impl Normalizer for PyNormalizerWrapper {
}
}

/// UnicodeFilter normalizer that filters out characters in unwanted Unicode general categories
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "UnicodeFilter")]
pub struct PyUnicodeFilter {}

#[pymethods]
impl PyUnicodeFilter {
#[new]
#[pyo3(signature = (filter_unassigned = None, filter_private_use = None), text_signature = "(self, filter_unassigned = True, filter_private_use = True)")]
fn new(filter_unassigned: Option<bool>, filter_private_use: Option<bool>) -> (Self, PyNormalizer) {
let filter = tk::normalizers::UnicodeFilter::new(
filter_unassigned.unwrap_or(true),
filter_private_use.unwrap_or(true),
);
(PyUnicodeFilter {}, filter.into())
}
}

/// Normalizers Module
#[pymodule]
pub fn normalizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
@@ -812,6 +833,7 @@ pub fn normalizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyNmt>()?;
m.add_class::<PyPrecompiled>()?;
m.add_class::<PyReplace>()?;
m.add_class::<PyUnicodeFilter>()?;
Ok(())
}

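
Since PyUnicodeFilter extends PyNormalizer, it should compose like any other normalizer; a short sketch under that assumption:

from tokenizers.normalizers import NFC, Sequence, UnicodeFilter

# Apply NFC first, then strip unassigned and private-use code points
norm = Sequence([NFC(), UnicodeFilter()])
print(norm.normalize_str("Hello\uE000"))  # -> "Hello"
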
1 change: 1 addition & 0 deletions bindings/python/src/processors.rs
@@ -8,6 +8,7 @@ use pyo3::exceptions;
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::Python;
use serde::ser::SerializeStruct;
use serde::Deserializer;
use serde::Serializer;
35 changes: 35 additions & 0 deletions bindings/python/tests/bindings/test_normalizers.py
@@ -12,6 +12,7 @@
Strip,
Prepend,
Replace,
UnicodeFilter,
)


@@ -201,6 +202,40 @@ def test_can_modify(self):
assert normalizer.prepend == "-"


class TestUnicodeFilter:
def test_instantiate(self):
assert isinstance(UnicodeFilter(), Normalizer)
assert isinstance(UnicodeFilter(), UnicodeFilter)
assert isinstance(pickle.loads(pickle.dumps(UnicodeFilter())), UnicodeFilter)

def test_default_filtering(self):
normalizer = UnicodeFilter() # Default filters out Unassigned, PrivateUse
output = normalizer.normalize_str("Hello\uE000\U000F0000\U0010FFFF")  # Hello + Private Use + Supplementary Private Use Area-A + Unassigned
assert output == "Hello" # Only valid chars remain

def test_custom_filtering(self):
# Only filter private use areas
normalizer = UnicodeFilter(
filter_unassigned=False,
filter_private_use=True,
)
output = normalizer.normalize_str("Hello\uE000\U000F0000\U0010FFFF")
assert output == "Hello\U0010FFFF" # Private use removed, others kept

def test_can_modify(self):
normalizer = UnicodeFilter()
output = normalizer.normalize_str("Hello\uE000\U000F0000\U0010FFFF")
assert output == "Hello" # All filtered by default

# Disable all filtering
normalizer = UnicodeFilter(
filter_unassigned=False,
filter_private_use=False,
)
output = normalizer.normalize_str("Hello\uE000\U000F0000\U0010FFFF")
assert output == "Hello\uE000\U000F0000\U0010FFFF" # Nothing filtered


class TestCustomNormalizer:
class BadCustomNormalizer:
def normalize(self, normalized, wrong):
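
Beyond normalize_str, the new normalizer can be attached to a full tokenizer; a rough sketch with a placeholder empty BPE model (the model choice is arbitrary here):

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import UnicodeFilter

tok = Tokenizer(BPE())            # empty model, just to show the wiring
tok.normalizer = UnicodeFilter()  # filtering now runs before pre-tokenization
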
1 change: 1 addition & 0 deletions tokenizers/Cargo.toml
@@ -52,6 +52,7 @@ serde_json = "1.0"
unicode-normalization-alignments = "0.1"
unicode_categories = "0.1"
unicode-segmentation = "1.11"
unicode-general-category = "0.6.0"
indicatif = {version = "0.17", optional = true}
itertools = "0.14"
log = "0.4"
11 changes: 10 additions & 1 deletion tokenizers/src/normalizers/mod.rs
@@ -12,7 +12,7 @@ pub use crate::normalizers::precompiled::Precompiled;
pub use crate::normalizers::prepend::Prepend;
pub use crate::normalizers::replace::Replace;
pub use crate::normalizers::strip::{Strip, StripAccents};
pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD};
pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD, UnicodeFilter};
pub use crate::normalizers::utils::{Lowercase, Sequence};
use serde::{Deserialize, Deserializer, Serialize};

@@ -36,6 +36,7 @@ pub enum NormalizerWrapper {
Replace(Replace),
Prepend(Prepend),
ByteLevel(ByteLevel),
UnicodeFilter(UnicodeFilter),
}

impl<'de> Deserialize<'de> for NormalizerWrapper {
@@ -66,6 +67,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
Replace,
Prepend,
ByteLevel,
UnicodeFilter,
}

#[derive(Deserialize)]
@@ -92,6 +94,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
Replace(Replace),
Prepend(Prepend),
ByteLevel(ByteLevel),
UnicodeFilter(UnicodeFilter),
}

let helper = NormalizerHelper::deserialize(deserializer)?;
@@ -151,6 +154,9 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
EnumType::ByteLevel => NormalizerWrapper::ByteLevel(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::UnicodeFilter => NormalizerWrapper::UnicodeFilter(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
}
}

@@ -175,6 +181,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper {
NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe),
NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe),
NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe),
NormalizerUntagged::UnicodeFilter(uf) => NormalizerWrapper::UnicodeFilter(uf),
}
}
})
@@ -198,6 +205,7 @@ impl Normalizer for NormalizerWrapper {
Self::Replace(lc) => lc.normalize(normalized),
Self::Prepend(lc) => lc.normalize(normalized),
Self::ByteLevel(lc) => lc.normalize(normalized),
Self::UnicodeFilter(uf) => uf.normalize(normalized),
}
}
}
@@ -216,6 +224,7 @@ impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled);
impl_enum_from!(Replace, NormalizerWrapper, Replace);
impl_enum_from!(Prepend, NormalizerWrapper, Prepend);
impl_enum_from!(ByteLevel, NormalizerWrapper, ByteLevel);
impl_enum_from!(UnicodeFilter, NormalizerWrapper, UnicodeFilter);

#[cfg(test)]
mod tests {
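
The NormalizerWrapper plumbing above is what lets UnicodeFilter round-trip through tokenizer serialization; a sketch of inspecting the saved form (the exact JSON field names are an assumption based on the Rust struct):

import json

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import UnicodeFilter

tok = Tokenizer(BPE())
tok.normalizer = UnicodeFilter()
config = json.loads(tok.to_str())
# Expected shape, assuming a "type"-tagged representation like the other normalizers:
# {"type": "UnicodeFilter", "filter_unassigned": true, "filter_private_use": true}
print(config["normalizer"])
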
57 changes: 57 additions & 0 deletions tokenizers/src/normalizers/unicode.rs
@@ -82,10 +82,67 @@ impl Normalizer for Nmt {
}
}

use unicode_general_category::{get_general_category, GeneralCategory};
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub struct UnicodeFilter {
filter_unassigned: bool,
filter_private_use: bool,
}

impl Default for UnicodeFilter {
fn default() -> Self {
Self {
filter_unassigned: true,
filter_private_use: true,
}
}
}

impl UnicodeFilter {
/// Creates a `UnicodeFilter` normalizer that removes characters based on
/// their Unicode general category.
///
/// * `filter_unassigned`: whether to filter out unassigned code points
/// * `filter_private_use`: whether to filter out private-use code points
pub fn new(filter_unassigned: bool, filter_private_use: bool) -> Self {
Self {
filter_unassigned,
filter_private_use,
}
}
}

impl Normalizer for UnicodeFilter {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
normalized.filter(|c| {
let category = get_general_category(c);
!((self.filter_unassigned && category == GeneralCategory::Unassigned)
|| (self.filter_private_use && category == GeneralCategory::PrivateUse))
});
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_unicode_filter() {
// Test with default settings (both categories filtered)
let original = "A\u{20AC}\u{10FFFF}\u{E000}".to_string(); // Regular + Euro + Unassigned + Private Use
let normalized = "A\u{20AC}".to_string(); // Keep only valid chars
let mut n = NormalizedString::from(original.clone());
UnicodeFilter::default().normalize(&mut n).unwrap();
assert_eq!(n.get(), normalized);

// Test with only filtering unassigned
let mut n = NormalizedString::from(original);
UnicodeFilter::new(true, false).normalize(&mut n).unwrap();
assert_eq!(n.get(), "A\u{20AC}\u{E000}"); // Keep private use, filter unassigned
}

#[test]
fn test_nfkc() {
let original = "\u{fb01}".to_string();
Expand Down