debeir.core.callbacks

Callbacks that run before and after a pipeline run. E.g. `before` is for setup; `after` is for evaluation, serialization, etc.

  1"""
  2Callbacks for before after running.
  3E.g. before is for setup
  4after is for evaluation/serialization etc
  5"""
  6
  7import abc
  8import os
  9import tempfile
 10import uuid
 11from typing import List
 12
 13import loguru
 14from debeir.datasets.factory import query_factory
 15from debeir.evaluation.evaluator import Evaluator
 16from debeir.core.config import GenericConfig, NIRConfig
 17from debeir.core.pipeline import Pipeline
 18
 19
 20class Callback:
 21    def __init__(self):
 22        self.pipeline = None
 23
 24    @abc.abstractmethod
 25    def before(self, pipeline: Pipeline):
 26        pass
 27
 28    @abc.abstractmethod
 29    def after(self, results: List):
 30        pass
 31
 32
 33class SerializationCallback(Callback):
 34    def __init__(self, config: GenericConfig, nir_config: NIRConfig):
 35        super().__init__()
 36        self.config = config
 37        self.nir_config = nir_config
 38        self.output_file = None
 39        self.query_cls = query_factory[self.config.query_fn]
 40
 41    def before(self, pipeline: Pipeline):
 42        """
 43        Check if output file exists
 44
 45        :return:
 46            Output file path
 47        """
 48
 49        self.pipeline = Pipeline
 50
 51        output_file = self.config.output_file
 52        output_dir = os.path.join(self.nir_config.output_directory, self.config.index)
 53
 54        if output_file is None:
 55            os.makedirs(name=output_dir, exist_ok=True)
 56            output_file = os.path.join(output_dir, str(uuid.uuid4()))
 57
 58            loguru.logger.info(f"Output file not specified, writing to: {output_file}")
 59
 60        else:
 61            output_file = os.path.join(output_dir, output_file)
 62
 63        if os.path.exists(output_file):
 64            if not self.config.overwrite_output_if_exists:
 65                raise RuntimeError("Directory exists and isn't explicitly overwritten "
 66                                   "in config with overwrite_output_if_exists=True")
 67
 68            loguru.logger.info(f"Output file exists: {output_file}. Overwriting...")
 69            open(output_file, "w+").close()
 70
 71        pipeline.output_file = output_file
 72        self.output_file = output_file
 73
 74    def after(self, results: List):
 75        """
 76        Serialize results to self.output_file in a TREC-style format
 77        :param topic_num: Topic number to serialize
 78        :param res: Raw elasticsearch result
 79        :param run_name: The run name for TREC-style runs (default: NO_RUN_NAME)
 80        """
 81
 82        self._after(results,
 83                    output_file=self.output_file,
 84                    run_name=self.config.run_name)
 85
 86    @classmethod
 87    def _after(self, results: List, output_file, run_name=None):
 88        if run_name is None:
 89            run_name = "NO_RUN_NAME"
 90
 91        with open(output_file, "a+t") as writer:
 92            for doc in results:
 93                line = f"{doc.topic_num}\t" \
 94                       f"Q0\t" \
 95                       f"{doc.doc_id}\t" \
 96                       f"{doc.scores['rank']}\t" \
 97                       f"{doc.score}\t" \
 98                       f"{run_name}\n"
 99
100                writer.write(line)
101
102
class EvaluationCallback(Callback):
    """Callback that evaluates the pipeline's run file with an ``Evaluator``.

    If the pipeline has no run file yet, the results are first written to a
    fresh temporary file.
    """

    def __init__(self, evaluator: Evaluator, config):
        """
        :param evaluator: Scores a run file via ``evaluate_runs``.
        :param config: Run configuration; ``query_fn`` and ``run_name`` are read from it.
        """
        super().__init__()
        self.evaluator = evaluator
        self.config = config
        self.parsed_run = None

    def before(self, pipeline: Pipeline):
        """Remember the pipeline whose output will be evaluated in ``after``."""
        # Store the instance; the original stored the ``Pipeline`` class, so ``after``
        # read and wrote a class-level ``output_file`` shared by every pipeline.
        self.pipeline = pipeline

    def after(self, results: List, id_field="id"):
        """
        Evaluate the run produced by the pipeline.

        :param results: Scored documents; only serialized when the pipeline has no
            ``output_file`` yet.
        :param id_field: Assigned to the query class's ``id_field`` before serializing.
        :return: The value of ``evaluator.evaluate_runs``; also stored on ``self.parsed_run``.
        """
        output_file = getattr(self.pipeline, "output_file", None)

        if output_file is None:
            # No serialization step produced a file: write results to a throwaway one.
            # NOTE(review): the temporary directory is never removed.
            output_file = os.path.join(tempfile.mkdtemp(), str(uuid.uuid4()))

            # NOTE(review): mutates the shared query class from ``query_factory``;
            # kept for compatibility with existing behavior.
            query_factory[self.config.query_fn].id_field = id_field

            SerializationCallback._after(results,
                                         output_file=output_file,
                                         run_name=self.config.run_name)

            if self.pipeline is not None:
                self.pipeline.output_file = output_file

        self.parsed_run = self.evaluator.evaluate_runs(output_file, disable_cache=True)
        return self.parsed_run
class Callback:
    """Base class for hooks that run around a pipeline execution.

    ``before`` is called ahead of the run (setup work) and ``after`` is called
    with the results (e.g. serialization or evaluation).

    NOTE(review): this class does not use ``abc.ABCMeta``/``abc.ABC``, so the
    ``@abc.abstractmethod`` markers are not enforced; ``Callback()`` can be
    instantiated directly and both hooks are then no-ops.
    """

    def __init__(self):
        # Stays None until assigned, e.g. by a subclass's ``before`` hook.
        self.pipeline = None

    @abc.abstractmethod
    def before(self, pipeline: Pipeline):
        """Hook run before the pipeline executes; the base version does nothing."""

    @abc.abstractmethod
    def after(self, results: List):
        """Hook run with the pipeline's results; the base version does nothing."""
def __init__(self):
    """Start with no pipeline attached."""
    # Stays None until assigned, e.g. by a subclass's ``before`` hook.
    self.pipeline = None
@abc.abstractmethod
def before(self, pipeline: Pipeline):
    """Hook invoked before the pipeline runs.

    The base implementation does nothing and returns None.
    """
@abc.abstractmethod
def after(self, results: List):
    """Hook invoked with the pipeline's results.

    The base implementation does nothing and returns None.
    """
class SerializationCallback(Callback):
    """Callback that writes pipeline results to disk as a TREC-style run file.

    ``before`` resolves the output path (truncating an existing file when the
    config allows it); ``after`` appends one line per result document.
    """

    def __init__(self, config: GenericConfig, nir_config: NIRConfig):
        """
        :param config: Run configuration; ``query_fn``, ``output_file``, ``index``,
            ``overwrite_output_if_exists`` and ``run_name`` are read from it.
        :param nir_config: Provides ``output_directory``, the root for run files.
        """
        super().__init__()
        self.config = config
        self.nir_config = nir_config
        self.output_file = None
        self.query_cls = query_factory[self.config.query_fn]

    def before(self, pipeline: Pipeline):
        """
        Resolve the output file for this run and attach it to the pipeline.

        The path is ``<nir_config.output_directory>/<config.index>/<config.output_file>``,
        or a random UUID file name in that directory when ``config.output_file`` is None.
        The parent directory is created if it does not exist.

        :param pipeline: The pipeline about to run; its ``output_file`` attribute is set.
        :raises RuntimeError: If the file already exists and
            ``config.overwrite_output_if_exists`` is falsy.
        """
        # Keep the pipeline *instance* (the original stored the ``Pipeline`` class).
        self.pipeline = pipeline

        output_file = self.config.output_file
        output_dir = os.path.join(self.nir_config.output_directory, self.config.index)

        if output_file is None:
            output_file = os.path.join(output_dir, str(uuid.uuid4()))
            loguru.logger.info(f"Output file not specified, writing to: {output_file}")
        else:
            output_file = os.path.join(output_dir, output_file)

        # Create the parent directory for explicit file names too; otherwise ``after``
        # fails when it opens a file in a directory that was never created.
        parent_dir = os.path.dirname(output_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        if os.path.exists(output_file):
            if not self.config.overwrite_output_if_exists:
                raise RuntimeError(f"Output file {output_file} exists and isn't explicitly overwritten "
                                   "in config with overwrite_output_if_exists=True")

            loguru.logger.info(f"Output file exists: {output_file}. Overwriting...")
            # Truncate up front because ``after`` appends to the file.
            open(output_file, "w+").close()

        pipeline.output_file = output_file
        self.output_file = output_file

    def after(self, results: List):
        """
        Append ``results`` to ``self.output_file`` in TREC run format.

        :param results: Scored documents exposing ``topic_num``, ``doc_id``,
            ``score`` and ``scores['rank']``.

        The run name is taken from ``config.run_name`` (``NO_RUN_NAME`` when None).
        """
        self._after(results,
                    output_file=self.output_file,
                    run_name=self.config.run_name)

    @classmethod
    def _after(cls, results: List, output_file, run_name=None):
        """
        Append one tab-separated TREC line per document to ``output_file``.

        Line layout: ``topic_num  Q0  doc_id  rank  score  run_name``.

        :param results: Documents exposing ``topic_num``, ``doc_id``, ``score``
            and ``scores['rank']``.
        :param output_file: Path opened in append mode (created if missing).
        :param run_name: Run tag for the last column; defaults to ``NO_RUN_NAME``.
        """
        if run_name is None:
            run_name = "NO_RUN_NAME"

        with open(output_file, "a+t") as writer:
            writer.writelines(
                f"{doc.topic_num}\tQ0\t{doc.doc_id}\t{doc.scores['rank']}\t{doc.score}\t{run_name}\n"
                for doc in results
            )
def __init__(self, config: GenericConfig, nir_config: NIRConfig):
    """Store both configs and look up the query class named by ``config.query_fn``.

    :param config: Run configuration for this serialization callback.
    :param nir_config: Provides ``output_directory``, the root for run files.
    """
    super().__init__()
    self.config, self.nir_config = config, nir_config
    self.output_file = None  # resolved later by ``before``
    self.query_cls = query_factory[config.query_fn]
def before(self, pipeline: Pipeline):
    """
    Resolve the output file for this run and attach it to the pipeline.

    The path is ``<nir_config.output_directory>/<config.index>/<config.output_file>``,
    or a random UUID file name in that directory when ``config.output_file`` is None.
    The parent directory is created if it does not exist.

    :param pipeline: The pipeline about to run; its ``output_file`` attribute is set.
    :raises RuntimeError: If the file already exists and
        ``config.overwrite_output_if_exists`` is falsy.
    """
    # Keep the pipeline *instance* (the original stored the ``Pipeline`` class).
    self.pipeline = pipeline

    output_file = self.config.output_file
    output_dir = os.path.join(self.nir_config.output_directory, self.config.index)

    if output_file is None:
        output_file = os.path.join(output_dir, str(uuid.uuid4()))
        loguru.logger.info(f"Output file not specified, writing to: {output_file}")
    else:
        output_file = os.path.join(output_dir, output_file)

    # Create the parent directory for explicit file names too; otherwise the later
    # append in ``after`` fails for a directory that was never created.
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if os.path.exists(output_file):
        if not self.config.overwrite_output_if_exists:
            raise RuntimeError(f"Output file {output_file} exists and isn't explicitly overwritten "
                               "in config with overwrite_output_if_exists=True")

        loguru.logger.info(f"Output file exists: {output_file}. Overwriting...")
        # Truncate up front because results are appended afterwards.
        open(output_file, "w+").close()

    pipeline.output_file = output_file
    self.output_file = output_file

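Resolve the output file path for this run and check whether it already exists.

If the file exists and overwrite_output_if_exists is not set, a RuntimeError is raised; otherwise the existing file is truncated.

Returns
None. The resolved path is stored on pipeline.output_file and self.output_file.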
def after(self, results: List):
    """
    Append ``results`` to ``self.output_file`` in TREC run format.

    :param results: Scored documents to serialize, one line per document.

    The run name is taken from ``config.run_name`` (``NO_RUN_NAME`` when None).
    """
    self._after(results,
                output_file=self.output_file,
                run_name=self.config.run_name)

Append results to self.output_file in TREC run format (one line per document).

Parameters
  • results: Scored documents to serialize
The run name is taken from config.run_name (default: NO_RUN_NAME).
class EvaluationCallback(Callback):
    """Callback that evaluates the pipeline's run file with an ``Evaluator``.

    If the pipeline has no run file yet, the results are first written to a
    fresh temporary file.
    """

    def __init__(self, evaluator: Evaluator, config):
        """
        :param evaluator: Scores a run file via ``evaluate_runs``.
        :param config: Run configuration; ``query_fn`` and ``run_name`` are read from it.
        """
        super().__init__()
        self.evaluator = evaluator
        self.config = config
        self.parsed_run = None

    def before(self, pipeline: Pipeline):
        """Remember the pipeline whose output will be evaluated in ``after``."""
        # Store the instance; the original stored the ``Pipeline`` class, so ``after``
        # read and wrote a class-level ``output_file`` shared by every pipeline.
        self.pipeline = pipeline

    def after(self, results: List, id_field="id"):
        """
        Evaluate the run produced by the pipeline.

        :param results: Scored documents; only serialized when the pipeline has no
            ``output_file`` yet.
        :param id_field: Assigned to the query class's ``id_field`` before serializing.
        :return: The value of ``evaluator.evaluate_runs``; also stored on ``self.parsed_run``.
        """
        output_file = getattr(self.pipeline, "output_file", None)

        if output_file is None:
            # No serialization step produced a file: write results to a throwaway one.
            # NOTE(review): the temporary directory is never removed.
            output_file = os.path.join(tempfile.mkdtemp(), str(uuid.uuid4()))

            # NOTE(review): mutates the shared query class from ``query_factory``;
            # kept for compatibility with existing behavior.
            query_factory[self.config.query_fn].id_field = id_field

            SerializationCallback._after(results,
                                         output_file=output_file,
                                         run_name=self.config.run_name)

            if self.pipeline is not None:
                self.pipeline.output_file = output_file

        self.parsed_run = self.evaluator.evaluate_runs(output_file, disable_cache=True)
        return self.parsed_run
def __init__(self, evaluator: Evaluator, config):
    """Keep the evaluator and config; no run has been parsed yet.

    :param evaluator: Object whose ``evaluate_runs`` scores a run file in ``after``.
    :param config: Run configuration (``query_fn`` and ``run_name`` are read in ``after``).
    """
    super().__init__()
    self.evaluator, self.config = evaluator, config
    self.parsed_run = None  # filled in by ``after``
def before(self, pipeline: Pipeline):
    """Remember the pipeline whose output will be evaluated in ``after``."""
    # Store the instance; the original stored the ``Pipeline`` class, so ``after``
    # read and wrote a class-level ``output_file`` shared by every pipeline.
    self.pipeline = pipeline
def after(self, results: List, id_field="id"):
    """
    Evaluate the run produced by the pipeline.

    :param results: Scored documents; only serialized when the pipeline has no
        ``output_file`` yet.
    :param id_field: Assigned to the query class's ``id_field`` before serializing.
    :return: The value of ``evaluator.evaluate_runs``; also stored on ``self.parsed_run``.
    """
    output_file = getattr(self.pipeline, "output_file", None)

    if output_file is None:
        # No serialization step produced a file: write results to a throwaway one.
        # NOTE(review): the temporary directory is never removed.
        output_file = os.path.join(tempfile.mkdtemp(), str(uuid.uuid4()))

        # NOTE(review): mutates the shared query class from ``query_factory``;
        # kept for compatibility with existing behavior.
        query_factory[self.config.query_fn].id_field = id_field

        SerializationCallback._after(results,
                                     output_file=output_file,
                                     run_name=self.config.run_name)

        if self.pipeline is not None:
            self.pipeline.output_file = output_file

    self.parsed_run = self.evaluator.evaluate_runs(output_file, disable_cache=True)
    return self.parsed_run