55import h5py
66import numpy as np
77import scipy .io
8+ from selene_sdk .samplers .samples_batch import SamplesBatch
89
910from .file_sampler import FileSampler
1011
@@ -126,8 +127,8 @@ def sample(self, batch_size=1):
126127
127128 Returns
128129 -------
129- sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
130- A tuple containing the numeric representation of the
130+ SamplesBatch
131+ A batch containing the numeric representation of the
131132 sequence examples and their corresponding labels. The
132133 shape of `sequences` will be
133134 :math:`B \\times L \\times N`, where :math:`B` is
@@ -166,8 +167,8 @@ def sample(self, batch_size=1):
166167 targets = self ._sample_tgts [:, use_indices ].astype (float )
167168 targets = np .transpose (
168169 targets , (1 , 0 ))
169- return (sequences , targets )
170- return sequences ,
170+ return SamplesBatch (sequences , target_batch = targets )
171+ return SamplesBatch ( sequences )
171172
172173 def get_data (self , batch_size , n_samples = None ):
173174 """
@@ -190,18 +191,20 @@ def get_data(self, batch_size, n_samples=None):
190191 is `batch_size`, :math:`L` is the sequence length,
191192 and :math:`N` is the size of the sequence type's alphabet.
192193 """
194+ # TODO: Should this method return a collection of samples_batch.inputs()?
195+
193196 if not n_samples :
194197 n_samples = self .n_samples
195198 sequences = []
196199
197200 count = batch_size
198201 while count < n_samples :
199- seqs , = self .sample (batch_size = batch_size )
200- sequences .append (seqs )
202+ samples_batch = self .sample (batch_size = batch_size )
203+ sequences .append (samples_batch . sequence_batch () )
201204 count += batch_size
202205 remainder = batch_size - (count - n_samples )
203- seqs , = self .sample (batch_size = remainder )
204- sequences .append (seqs )
206+ samples_batch = self .sample (batch_size = remainder )
207+ sequences .append (samples_batch . sequence_batch () )
205208 return sequences
206209
207210 def get_data_and_targets (self , batch_size , n_samples = None ):
@@ -218,11 +221,11 @@ def get_data_and_targets(self, batch_size, n_samples=None):
218221
219222 Returns
220223 -------
221- sequences_and_targets , targets_matrix : \
222- tuple(list(tuple(numpy.ndarray, numpy.ndarray) ), numpy.ndarray)
223- Tuple containing the list of sequence-target pairs , as well
224+ batches , targets_matrix : \
225+ tuple(list(SamplesBatch ), numpy.ndarray)
226+ Tuple containing the list of batches , as well
224227 as a single matrix with all targets in the same order.
225- Note that `sequences_and_targets `'s sequence elements are of
228+ Note that `batches`'s sequence elements are of
226229 the shape :math:`B \\times L \\times N` and its target
227230 elements are of the shape :math:`B \\times F`, where
228231 :math:`B` is `batch_size`, :math:`L` is the sequence length,
@@ -237,19 +240,19 @@ def get_data_and_targets(self, batch_size, n_samples=None):
237240 "initialization. Please use `get_data` instead." )
238241 if not n_samples :
239242 n_samples = self .n_samples
240- sequences_and_targets = []
243+ batches = []
241244 targets_mat = []
242245
243246 count = batch_size
244247 while count < n_samples :
245- seqs , tgts = self .sample (batch_size = batch_size )
246- sequences_and_targets .append (( seqs , tgts ) )
247- targets_mat .append (tgts )
248+ samples_batch = self .sample (batch_size = batch_size )
249+ batches .append (samples_batch )
250+ targets_mat .append (samples_batch . targets () )
248251 count += batch_size
249252 remainder = batch_size - (count - n_samples )
250- seqs , tgts = self .sample (batch_size = remainder )
251- sequences_and_targets .append (( seqs , tgts ) )
252- targets_mat .append (tgts )
253+ samples_batch = self .sample (batch_size = remainder )
254+ batches .append (samples_batch )
255+ targets_mat .append (samples_batch . targets () )
253256 # TODO: should not assume targets are always integers
254257 targets_mat = np .vstack (targets_mat ).astype (float )
255- return sequences_and_targets , targets_mat
258+ return batches , targets_mat
0 commit comments