@@ -13,21 +13,24 @@
 import copy
 from itertools import groupby
 from collections import namedtuple
-import six
 from six.moves import zip
 
-from lxml.etree import XPathEvaluator, Comment
+from lxml.etree import Comment, iterwalk
 
 from webstruct.sequence_encoding import IobEncoder
-from webstruct.text_tokenizers import tokenize
+from webstruct.text_tokenizers import tokenize, TextToken
 from webstruct.utils import (
     replace_html_tags,
     kill_html_tags,
-    smart_join,
 )
 
 
-_HtmlToken = namedtuple('HtmlToken', 'index tokens elem is_tail')
+_HtmlToken = namedtuple('HtmlToken', ['index',
+                                      'tokens',
+                                      'elem',
+                                      'is_tail',
+                                      'position',
+                                      'length'])
 
 
 class HtmlToken(_HtmlToken):
@@ -41,6 +44,8 @@ class HtmlToken(_HtmlToken):
     * :attr:`elem` is the current html block (as lxml's Element) - most
       likely you want :attr:`parent` instead of it
     * :attr:`is_tail` flag indicates that token belongs to element tail
+    * :attr:`position` is the logical position (in letters or codepoints) of the token start in the parent text
+    * :attr:`length` is the logical length (in letters or codepoints) of the token in the parent text
 
     Computed properties:
 
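For orientation, a minimal sketch of how the new `position`/`length` fields tie a token back to its parent's text. The top-level import paths and the exact tokenizer output are assumptions for illustration, not part of this commit:

    from webstruct import HtmlLoader, HtmlTokenizer

    html_tokenizer = HtmlTokenizer()
    tree = HtmlLoader().loadbytes(b"<p>hello, John</p>")
    html_tokens, tags = html_tokenizer.tokenize_single(tree)

    tok = html_tokens[0]  # the 'hello' token, taken from <p>'s .text
    source = tok.elem.tail if tok.is_tail else tok.elem.text
    # position/length index into the parent text, so slicing it should
    # recover the token's characters (here: 'hello')
    print(source[tok.position:tok.position + tok.length])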
@@ -64,8 +69,10 @@ def root(self):
         return self.elem.getroottree()
 
     def __repr__(self):
-        return "HtmlToken(token=%r, parent=%r, index=%s)" % (
-            self.token, self.parent, self.index
+        return ("HtmlToken("
+                "token=%r, parent=%r, index=%s, position=%d, length=%d"
+                ")") % (
+            self.token, self.parent, self.index, self.position, self.length
         )
 
 
@@ -85,7 +92,8 @@ class HtmlTokenizer(object):
     ----------
 
     tagset : set, optional
-        A set of entity types to keep. If not passed, all entity types are kept.
+        A set of entity types to keep.
+        If not passed, all entity types are kept.
         Use this argument to discard some entity types from training data.
     sequence_encoder : object, optional
         Sequence encoder object. If not passed,
@@ -142,7 +150,7 @@ def tokenize_single(self, tree):
         >>> tree = loader.loadbytes(b"<p>hello, <PER>John <b>Doe</b></PER> <br> <PER>Mary</PER> said</p>")
         >>> html_tokens, tags = html_tokenizer.tokenize_single(tree)
         >>> html_tokens
-        [HtmlToken(token='hello', parent=<Element p at ...>, index=0), HtmlToken...]
+        [HtmlToken(token='hello', parent=<Element p at ...>, index=0, ...), HtmlToken...]
         >>> tags
         ['O', 'B-PER', 'I-PER', 'B-PER', 'O']
         >>> for tok, iob_tag in zip(html_tokens, tags):
@@ -180,6 +188,8 @@ def detokenize_single(self, html_tokens, tags):
         Build annotated ``lxml.etree.ElementTree`` from
         ``html_tokens`` (a list of :class:`.HtmlToken` instances)
         and ``tags`` (a list of their tags).
+        **ATTENTION**: ``html_tokens`` should be tokenized from a tree
+        without annotation tags (e.g. one returned by :meth:`cleanup_tree`).
 
         Annotations are encoded as ``__START_TAG__`` and ``__END_TAG__``
         text tokens (this is the format :mod:`webstruct.loaders` use).
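Since this version writes the markers into the tokens' own tree (it no longer deep-copies it), the token positions must refer to un-annotated text. A hedged usage sketch, assuming ``loader`` is an annotation-aware loader such as ``GateLoader`` and ``html_tokenizer`` an ``HtmlTokenizer``, as in the ``tokenize_single`` doctest above:

    # take the tags from the annotated tree...
    annotated = loader.loadbytes(b"<p>hello, <PER>John</PER> said</p>")
    _, tags = html_tokenizer.tokenize_single(annotated)

    # ...but take the HtmlToken objects from a cleaned copy, so that their
    # position/length refer to text without __START_*__/__END_*__ markers...
    clean = html_tokenizer.cleanup_tree(annotated)
    html_tokens, _ = html_tokenizer.tokenize_single(clean)

    # ...and rebuild an annotated tree from those tokens plus the tags
    tree = html_tokenizer.detokenize_single(html_tokens, tags)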
@@ -190,9 +200,7 @@ def detokenize_single(self, html_tokens, tags):
         if not html_tokens:
             return None
 
-        orig_tree = html_tokens[0].root
-        tree = copy.deepcopy(orig_tree)
-        xpatheval = XPathEvaluator(tree)
+        tree = html_tokens[0].root
 
         # find starts/ends of token groups
         token_groups = self.sequence_encoder.group(zip(html_tokens, tags))
@@ -206,30 +214,49 @@ def detokenize_single(self, html_tokens, tags):
             pos += n_tokens
 
         # mark starts/ends with special tokens
-        data = zip(html_tokens, tags, range(len(html_tokens)))
-        keyfunc = lambda rec: (rec[0].elem, rec[0].is_tail)
+        data = [(s, True) for s in starts]
+        data.extend((s, False) for s in ends)
+        keyfunc = lambda rec: (id(html_tokens[rec[0]].elem), html_tokens[rec[0]].is_tail)
+        data.sort(key=keyfunc)
 
-        for (orig_elem, is_tail), g in groupby(data, keyfunc):
+        for (_, is_tail), g in groupby(data, keyfunc):
             g = list(g)
-            fix = False
-            tokens = g[0][0].tokens[:]
-            for token, tag, token_idx in g:
-                if token_idx in starts:
-                    text = ' __START_%s__ %s' % (tag[2:], tokens[token.index])
-                    tokens[token.index] = text
-                    fix = True
-                if token_idx in ends:
-                    text = '%s __END_%s__ ' % (tokens[token.index], tag[2:])
-                    tokens[token.index] = text
-                    fix = True
-
-            if fix:
-                xpath = orig_tree.getpath(orig_elem)
-                elem = xpatheval(xpath)[0]
-                if is_tail:
-                    elem.tail = smart_join(tokens)
+            g.sort(key=lambda t: (html_tokens[t[0]].position, not t[1]))
+
+            if not g:
+                continue
+
+            elem = html_tokens[g[0][0]].elem
+
+            pos_in_source = 0
+            source = elem.text
+            if is_tail:
+                source = elem.tail
+
+            mods = list()
+
+            for idx, is_starts in g:
+                token = html_tokens[idx]
+                tag = tags[idx]
+                mods.append(source[pos_in_source:token.position])
+                pos_in_source = token.position
+                if is_starts:
+                    patch = ' __START_%s__ ' % (tag[2:],)
+                    mods.append(patch)
                 else:
-                    elem.text = smart_join(tokens)
+                    end_in_source = pos_in_source + token.length
+                    mods.append(source[pos_in_source:end_in_source])
+                    pos_in_source = pos_in_source + token.length
+                    patch = ' __END_%s__ ' % (tag[2:],)
+                    mods.append(patch)
+
+            mods.append(source[pos_in_source:])
+            modded = ''.join(mods)
+
+            if is_tail:
+                elem.tail = modded
+            else:
+                elem.text = modded
 
         return tree
 
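To make the marker-splicing concrete, here is a self-contained sketch (plain Python with hypothetical values, not webstruct API) of the string manipulation the new loop performs: markers are inserted at character offsets taken from the tokens' position/length, so the original whitespace between tokens is kept verbatim instead of being re-joined:

    source = "hello, John Doe said"
    position, length = 7, 8        # hypothetical span covering "John Doe"
    entity = "PER"

    start, end = position, position + length
    annotated = (source[:start]
                 + ' __START_%s__ ' % entity
                 + source[start:end]
                 + ' __END_%s__ ' % entity
                 + source[end:])
    # annotated == "hello,  __START_PER__ John Doe __END_PER__  said"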
@@ -245,18 +272,35 @@ def _process_tree(self, tree):
             return
 
         head_tokens, head_tags = self._tokenize_and_split(tree.text)
+        char_tokens = [t.chars for t in head_tokens]
         for index, (token, tag) in enumerate(zip(head_tokens, head_tags)):
-            yield HtmlToken(index, head_tokens, tree, False), tag
+            yield HtmlToken(index,
+                            char_tokens,
+                            tree,
+                            False,
+                            token.position,
+                            token.length), tag
 
         for child in tree:  # where is my precious "yield from"?
             for html_token, tag in self._process_tree(child):
                 yield html_token, tag
 
         tail_tokens, tail_tags = self._tokenize_and_split(tree.tail)
+        char_tokens = [t.chars for t in tail_tokens]
         for index, (token, tag) in enumerate(zip(tail_tokens, tail_tags)):
-            yield HtmlToken(index, tail_tokens, tree, True), tag
+            yield HtmlToken(index,
+                            char_tokens,
+                            tree,
+                            True,
+                            token.position,
+                            token.length), tag
+
+    def cleanup_tree(self, tree):
+        cleaned = copy.deepcopy(tree)
+        for _, elem in iterwalk(cleaned):
+            self._cleanup_elem(elem)
 
-        self._cleanup_elem(tree)
+        return cleaned
 
     def _cleanup_elem(self, elem):
         """ Remove special tokens from elem """
@@ -266,16 +310,23 @@ def _cleanup_elem(self, elem):
             elem.tail = self._tag_re.sub("", elem.tail)
 
     def _tokenize_and_split(self, text):
-        input_tokens = self._limit_tags(self.text_tokenize_func(text or ''))
-        input_tokens = map(six.text_type, input_tokens)
-        return self.sequence_encoder.encode_split(input_tokens)
+        text = text or ''
+        input_tokens = [t for t in self.text_tokenize_func(text)]
+        input_tokens = self._limit_tags(input_tokens)
+        input_tokens = [TextToken(chars=t.chars,
+                                  position=t.position,
+                                  length=t.length) for t in input_tokens]
+        chains = self.sequence_encoder.encode(t.chars for t in input_tokens)
+        chains = self.sequence_encoder.from_indices(chains, input_tokens)
+        chains = [l for l in chains]
+        return self.sequence_encoder.split(chains)
 
     def _limit_tags(self, input_tokens):
         if self.tagset is None:
             return input_tokens
 
         proc = self.sequence_encoder.token_processor
-        token_classes = [proc.classify(tok) for tok in input_tokens]
+        token_classes = [proc.classify(tok.chars) for tok in input_tokens]
         return [
             tok for (tok, (typ, value)) in zip(input_tokens, token_classes)
             if not (typ in {'start', 'end'} and value not in self.tagset)
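A brief note on what `_tokenize_and_split` and `_limit_tags` now expect from the tokenizer: `text_tokenize_func` is assumed to yield `TextToken` namedtuples with `chars`/`position`/`length` rather than plain strings, which is why token classification is done on `tok.chars`. A hedged sketch, inferred from this diff rather than documented API:

    from webstruct.text_tokenizers import tokenize

    for tok in tokenize("hello, John"):
        # each token carries its characters plus its offset and length
        # in the input string
        print(tok.chars, tok.position, tok.length)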