Status: Open
Labels: bug, enhancement, help wanted
Description
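`_predict_proba` depends on `_calibrate` working properly. It raises a `RuntimeError` unless the model has been calibrated, and it reads the Platt scaling weight and bias from `self.calibration_parameters`, which calibration is expected to set.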
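For context, here is a minimal sketch of what calibration is expected to produce, assuming `_calibrate` performs standard Platt scaling: it fits a scalar weight `w` and bias `b` so that `sigmoid(-(w * score + b))` matches the labels of positive and negative triples. The helper names `sigmoid` and `fit_platt`, the plain gradient-descent optimizer, the learning rate, and the toy scores are all hypothetical. Emgraph's actual `_calibrate` may use a different optimizer or negative-sampling scheme.

```python
import math
from typing import List, Tuple


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def fit_platt(scores: List[float], labels: List[int],
              lr: float = 0.1, epochs: int = 2000) -> Tuple[float, float]:
    """Hypothetical helper: fit (w, b) so that sigmoid(-(w * s + b)) ~ P(y=1 | s).

    Uses batch gradient descent on the mean binary cross-entropy, with the
    same sign convention as `_predict_proba` (logits = -(w * scores + b)).
    """
    w, b = 0.0, 0.0
    n = len(scores)
    for _ in range(epochs):
        grad_w = grad_b = 0.0
        for s, y in zip(scores, labels):
            p = sigmoid(-(w * s + b))
            # With z = -(w*s + b): dL/dz = p - y, dz/dw = -s, dz/db = -1
            grad_w += -(p - y) * s
            grad_b += -(p - y)
        w -= lr * grad_w / n
        b -= lr * grad_b / n
    return w, b


# Toy data (hypothetical): positives score higher than negatives.
w, b = fit_platt([2.0, 1.5, 3.0, -1.0, 0.0, -2.0], [1, 1, 1, 0, 0, 0])
```

With this sign convention, a higher score should map to a higher probability, so for data like the toy scores above the fitted `w` comes out negative.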
`Emgraph/emgraph/models/EmbeddingModel.py`, lines 2137 to 2171 at commit 3926ad7:
```python
def _predict_proba(self, X):
    """Predicts probabilities using the Platt scaling model (after calibration).

    Model must be calibrated beforehand with the ``calibrate`` method.

    :param X: Numpy array of triples to be evaluated.
    :type X: ndarray, shape [n, 3]
    :return: Probability of each triple to be true according to the Platt scaling calibration.
    :rtype: ndarray, shape [n]
    """
    if not self.is_calibrated:
        msg = "Model has not been calibrated. Please call `model.calibrate(...)` before predicting probabilities."
        logger.error(msg)
        raise RuntimeError(msg)

    # tf.reset_default_graph()
    self._load_model_from_trained_params()

    w = tf.Variable(self.calibration_parameters[0], dtype=tf.float32, trainable=False)
    b = tf.Variable(self.calibration_parameters[1], dtype=tf.float32, trainable=False)

    x_idx = to_idx(X, ent_to_idx=self.ent_to_idx, rel_to_idx=self.rel_to_idx)
    x_tf = tf.Variable(x_idx, dtype=tf.int32, trainable=False)

    e_s, e_p, e_o = self._lookup_embeddings(x_tf)
    scores = self._fn(e_s, e_p, e_o)

    logits = -(w * scores + b)
    probas = tf.sigmoid(logits)

    # with tf.Session(config=self.tf_config) as sess:
    #     sess.run(tf.global_variables_initializer())
    #     return sess.run(probas)
    return probas
```