forked from hqucms/NNTools
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathhelper.py
More file actions
106 lines (96 loc) · 3.8 KB
/
helper.py
File metadata and controls
106 lines (96 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
'''
Helper functions for conversion.
@author: hqu
'''
import numpy as np
import logging
def xrd(filepath):
    """Turn an EOS path into an XRootD URL; return other paths unchanged.

    # Arguments
        filepath: str, path of the file.

    # Returns
        For paths starting with '/eos/cms' or '/eos/uscms', the matching
        'root://...' server prefix, an extra '/', and the original path
        concatenated as-is. Any other path is returned untouched.
    """
    # (path prefix, XRootD server prefix) pairs, checked in order.
    redirectors = (
        ('/eos/cms', 'root://eoscms.cern.ch/'),
        ('/eos/uscms', 'root://cmseos.fnal.gov/'),
    )
    for mount_point, server in redirectors:
        if filepath.startswith(mount_point):
            return server + '/' + filepath
    return filepath
def get_num_events(filepath, treename, selection=None):
    """Count the entries of a tree stored in a ROOT file.

    # Arguments
        filepath: path or URL passed to `ROOT.TFile.Open`.
        treename: name of the tree to look up in the file.
        selection: optional selection expression; when given it is passed
            to `GetEntries` so only matching entries are counted.

    # Returns
        The entry count, or None if the file or the tree could not be read.
        In the failure case the traceback is logged at error level.
    """
    import ROOT as rt
    rt.gROOT.SetBatch(True)
    import traceback
    f = None
    try:
        f = rt.TFile.Open(filepath)
        # Fail with a clear message instead of a null-pointer error on Get().
        if not f or f.IsZombie():
            raise RuntimeError('Cannot open file %s' % filepath)
        tree = f.Get(str(treename))
        if not tree:
            raise RuntimeError('Cannot find tree %s in file %s' % (treename, filepath))
        if selection is None:
            return tree.GetEntries()
        return tree.GetEntries(selection)
    except Exception:
        # Best-effort by design: report the problem and let the caller see None.
        # `Exception` (not a bare except) keeps Ctrl-C / SystemExit working.
        logging.error('Error reading %s:\n%s' % (filepath, traceback.format_exc()))
        return None
    finally:
        # The count has already been evaluated here, so the file handle can
        # be released instead of leaking one per call.
        if f:
            f.Close()
# borrowed from keras
# https://github.com/fchollet/keras/blob/master/keras/preprocessing/sequence.py
def pad_sequences(sequences, maxlen=None, dtype='int32',
                  padding='pre', truncating='pre', value=0.):
    """Pads each sequence to the same length (length of the longest sequence).

    If maxlen is provided, any sequence longer than maxlen is truncated to
    maxlen. Truncation happens off either the beginning (default) or the end
    of the sequence. Supports post-padding and pre-padding (default).

    # Arguments
        sequences: list of lists where each element is a sequence
        maxlen: int, maximum length. Defaults to the length of the longest
            sequence (0 if `sequences` is empty).
        dtype: type to cast the resulting sequence.
        padding: 'pre' or 'post', pad either before or after each sequence.
        truncating: 'pre' or 'post', remove values from sequences larger than
            maxlen either in the beginning or in the end of the sequence
        value: float, value to pad the sequences to the desired value.

    # Returns
        x: numpy array with dimensions (number_of_sequences, maxlen),
            followed by the per-element sample shape if elements are arrays.

    # Raises
        ValueError: in case of invalid values for `truncating` or `padding`,
            or in case of invalid shape for a `sequences` entry.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    # Validate the modes up front so a bad argument is reported even when
    # every sequence is empty (the main loop skips those).
    if truncating not in ('pre', 'post'):
        raise ValueError('Truncating type "%s" not understood' % truncating)
    if padding not in ('pre', 'post'):
        raise ValueError('Padding type "%s" not understood' % padding)

    num_samples = len(sequences)

    # Single pass: validate entries, record lengths, and take the sample
    # shape from the first non-empty sequence (consistency is checked in the
    # main loop below).
    lengths = []
    sample_shape = tuple()
    shape_known = False
    for s in sequences:
        if not hasattr(s, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(s))
        lengths.append(len(s))
        if not shape_known and len(s) > 0:
            sample_shape = np.asarray(s).shape[1:]
            shape_known = True

    if maxlen is None:
        # An empty `sequences` has no longest element; use zero width.
        maxlen = max(lengths) if lengths else 0

    # np.full fills directly in the target dtype (no float64 temporary).
    x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
    for idx, s in enumerate(sequences):
        if not len(s):
            continue  # empty list/array was found
        if truncating == 'post':
            trunc = s[:maxlen]
        else:
            # Not s[-maxlen:]: for maxlen == 0 that is the whole sequence.
            trunc = s[max(len(s) - maxlen, 0):]
        trunc = np.asarray(trunc, dtype=dtype)
        if not len(trunc):
            continue  # nothing left to copy (maxlen == 0)

        # check `trunc` has expected shape
        if trunc.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
                             (trunc.shape[1:], idx, sample_shape))

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        else:
            x[idx, -len(trunc):] = trunc
    return x