debeir.core.pipeline

  1import abc
  2from typing import List
  3
  4import debeir
  5from debeir.core.config import Config, GenericConfig
  6from debeir.core.executor import GenericElasticsearchExecutor
  7from debeir.core.results import Results
  8from debeir.datasets.factory import factory_fn, get_nir_config
  9from debeir.engines.client import Client
 10from loguru import logger
 11
 12
 13class Pipeline:
 14    pipeline_structure = ["parser", "query", "engine", "evaluator"]
 15    cannot_disable = ["parser", "query", "engine"]
 16    callbacks: List['debeir.core.callbacks.Callback']
 17    output_file = None
 18
 19    def __init__(self, engine: GenericElasticsearchExecutor,
 20                 engine_name: str,
 21                 metrics_config,
 22                 engine_config,
 23                 nir_config,
 24                 run_config: Config,
 25                 callbacks=None):
 26
 27        self.engine = engine
 28        self.engine_name = engine_name
 29        self.run_config = run_config
 30        self.metrics_config = metrics_config
 31        self.engine_config = engine_config
 32        self.nir_config = nir_config
 33        self.output_file = None
 34        self.disable = {}
 35
 36        if callbacks is None:
 37            self.callbacks = []
 38        else:
 39            self.callbacks = callbacks
 40
 41    @classmethod
 42    def build_from_config(cls, nir_config_fp, engine, config_fp) -> 'Pipeline':
 43        query_cls, config, parser, executor_cls = factory_fn(config_fp)
 44
 45        nir_config, search_engine_config, metrics_config = get_nir_config(nir_config_fp,
 46                                                                          engine=engine,
 47                                                                          ignore_errors=False)
 48
 49        client = Client.build_from_config(engine, search_engine_config)
 50        topics = parser._get_topics(config.topics_path)
 51
 52        query = query_cls(topics=topics, query_type=config.query_type, config=config)
 53
 54        executor = executor_cls.build_from_config(
 55            topics,
 56            query,
 57            client.get_client(engine),
 58            config,
 59            nir_config
 60        )
 61
 62        return cls(
 63            executor,
 64            engine,
 65            metrics_config,
 66            search_engine_config,
 67            nir_config,
 68            config
 69        )
 70
 71    def disable(self, parts: list):
 72        for part in parts:
 73            if part in self.pipeline_structure and part not in self.cannot_disable:
 74                self.disable[part] = True
 75            else:
 76                logger.warning(f"Cannot disable {part} because it doesn't exist or is integral to the pipeline")
 77
 78    @abc.abstractmethod
 79    async def run_pipeline(self, *args,
 80                           **kwargs):
 81        raise NotImplementedError()
 82
 83
 84class NIRPipeline(Pipeline):
 85    run_config: GenericConfig
 86
 87    def __init__(self, *args, **kwargs):
 88        super().__init__(*args, **kwargs)
 89
 90    async def prehook(self):
 91        if self.run_config.automatic or self.run_config.norm_weight == "automatic":
 92            logger.info(f"Running initial BM25 for query adjustment")
 93            await self.engine.run_automatic_adjustment()
 94
 95    async def run_engine(self, *args, **kwargs):
 96        # Run bm25 nir adjustment
 97        logger.info(f"Running {self.run_config.query_type} queries")
 98
 99        return await self.engine.run_all_queries(*args, return_results=True, **kwargs)
100
101    async def posthook(self, *args, **kwargs):
102        pass
103
104    async def run_pipeline(self, *args, return_results=False, **kwargs):
105        for cb in self.callbacks:
106            cb.before(self)
107
108        await self.prehook()
109        results = await self.run_engine(*args, **kwargs)
110        results = Results(results, self.engine.query, self.engine_name)
111
112        for cb in self.callbacks:
113            cb.after(results)
114
115        return results
116
117    def register_callback(self, cb):
118        self.callbacks.append(cb)
119
120
class BM25Pipeline(NIRPipeline):
    """
    Pipeline variant that always runs the engine with ``query_type="query"``
    and does not run the prehook step.
    """

    async def run_pipeline(self, *args, return_results=False, **kwargs):
        """
        Run all queries with ``query_type="query"`` and return them as ``Results``.

        Every callback's ``before(self)`` is called first and ``after(results)``
        last. The positional arguments, ``return_results`` and extra keyword
        arguments are accepted but not used.
        """
        for callback in self.callbacks:
            callback.before(self)

        raw_results = await self.engine.run_all_queries(query_type="query", return_results=True)
        wrapped = Results(raw_results, self.engine.query, self.engine_name)

        for callback in self.callbacks:
            callback.after(wrapped)

        return wrapped
class Pipeline:
class Pipeline:
    """
    Base class for a retrieval pipeline built around a search executor.

    Holds the executor ("engine"), the engine name, the configuration objects
    for a run and an optional list of callbacks. Subclasses supply the actual
    execution logic by overriding :meth:`run_pipeline`.
    """
    # Names of the parts a pipeline consists of.
    pipeline_structure = ["parser", "query", "engine", "evaluator"]
    # Parts that disable() refuses to switch off.
    cannot_disable = ["parser", "query", "engine"]
    callbacks: List['debeir.core.callbacks.Callback']
    output_file = None

    def __init__(self, engine: 'GenericElasticsearchExecutor',
                 engine_name: str,
                 metrics_config,
                 engine_config,
                 nir_config,
                 run_config: 'Config',
                 callbacks=None):
        """
        :param engine: Executor object that runs the queries
        :param engine_name: Name of the search engine
        :param metrics_config: Metrics configuration
        :param engine_config: Search engine configuration
        :param nir_config: NIR configuration
        :param run_config: Configuration of this particular run
        :param callbacks: Optional list of callbacks; a fresh empty list is
            created when omitted
        """
        self.engine = engine
        self.engine_name = engine_name
        self.run_config = run_config
        self.metrics_config = metrics_config
        self.engine_config = engine_config
        self.nir_config = nir_config
        self.output_file = None
        # Kept under a different name than the disable() method: an instance
        # attribute called "disable" would shadow the method and make
        # pipeline.disable(...) fail with "'dict' object is not callable".
        self.disabled = {}

        if callbacks is None:
            self.callbacks = []
        else:
            self.callbacks = callbacks

    @classmethod
    def build_from_config(cls, nir_config_fp, engine, config_fp) -> 'Pipeline':
        """
        Construct a pipeline from a NIR config file and a run config file.

        :param nir_config_fp: Path of the NIR config, passed to get_nir_config
        :param engine: Engine name, used to select the engine config and client
        :param config_fp: Path of the run config, passed to factory_fn
        :return: An instance of ``cls`` wired with the built executor and configs
        """
        query_cls, config, parser, executor_cls = factory_fn(config_fp)

        nir_config, search_engine_config, metrics_config = get_nir_config(nir_config_fp,
                                                                          engine=engine,
                                                                          ignore_errors=False)

        client = Client.build_from_config(engine, search_engine_config)
        topics = parser._get_topics(config.topics_path)

        query = query_cls(topics=topics, query_type=config.query_type, config=config)

        executor = executor_cls.build_from_config(
            topics,
            query,
            client.get_client(engine),
            config,
            nir_config
        )

        return cls(
            executor,
            engine,
            metrics_config,
            search_engine_config,
            nir_config,
            config
        )

    def disable(self, parts: list):
        """
        Flag optional parts of the pipeline as disabled.

        A part is accepted only when it appears in ``pipeline_structure`` and
        not in ``cannot_disable``; anything else is skipped with a warning.
        Accepted parts are recorded as ``self.disabled[part] = True``.

        :param parts: Names of the pipeline parts to disable
        """
        for part in parts:
            if part in self.pipeline_structure and part not in self.cannot_disable:
                self.disabled[part] = True
            else:
                logger.warning(f"Cannot disable {part} because it doesn't exist or is integral to the pipeline")

    @abc.abstractmethod
    async def run_pipeline(self, *args,
                           **kwargs):
        """
        Run the pipeline; must be overridden by subclasses.

        The class does not use ``abc.ABCMeta``, so the abstractmethod marker is
        not enforced at instantiation time; the exception below is the guard.

        :raises NotImplementedError: Always, in this base implementation
        """
        raise NotImplementedError()
Pipeline( engine: debeir.core.executor.GenericElasticsearchExecutor, engine_name: str, metrics_config, engine_config, nir_config, run_config: debeir.core.config.Config, callbacks=None)
20    def __init__(self, engine: GenericElasticsearchExecutor,
21                 engine_name: str,
22                 metrics_config,
23                 engine_config,
24                 nir_config,
25                 run_config: Config,
26                 callbacks=None):
27
28        self.engine = engine
29        self.engine_name = engine_name
30        self.run_config = run_config
31        self.metrics_config = metrics_config
32        self.engine_config = engine_config
33        self.nir_config = nir_config
34        self.output_file = None
35        self.disable = {}
36
37        if callbacks is None:
38            self.callbacks = []
39        else:
40            self.callbacks = callbacks
def disable(self, parts: list):
72    def disable(self, parts: list):
73        for part in parts:
74            if part in self.pipeline_structure and part not in self.cannot_disable:
75                self.disable[part] = True
76            else:
77                logger.warning(f"Cannot disable {part} because it doesn't exist or is integral to the pipeline")
@classmethod
def build_from_config(cls, nir_config_fp, engine, config_fp) -> debeir.core.pipeline.Pipeline:
42    @classmethod
43    def build_from_config(cls, nir_config_fp, engine, config_fp) -> 'Pipeline':
44        query_cls, config, parser, executor_cls = factory_fn(config_fp)
45
46        nir_config, search_engine_config, metrics_config = get_nir_config(nir_config_fp,
47                                                                          engine=engine,
48                                                                          ignore_errors=False)
49
50        client = Client.build_from_config(engine, search_engine_config)
51        topics = parser._get_topics(config.topics_path)
52
53        query = query_cls(topics=topics, query_type=config.query_type, config=config)
54
55        executor = executor_cls.build_from_config(
56            topics,
57            query,
58            client.get_client(engine),
59            config,
60            nir_config
61        )
62
63        return cls(
64            executor,
65            engine,
66            metrics_config,
67            search_engine_config,
68            nir_config,
69            config
70        )
@abc.abstractmethod
async def run_pipeline(self, *args, **kwargs):
79    @abc.abstractmethod
80    async def run_pipeline(self, *args,
81                           **kwargs):
82        raise NotImplementedError()
class NIRPipeline(Pipeline):
 85class NIRPipeline(Pipeline):
 86    run_config: GenericConfig
 87
 88    def __init__(self, *args, **kwargs):
 89        super().__init__(*args, **kwargs)
 90
 91    async def prehook(self):
 92        if self.run_config.automatic or self.run_config.norm_weight == "automatic":
 93            logger.info(f"Running initial BM25 for query adjustment")
 94            await self.engine.run_automatic_adjustment()
 95
 96    async def run_engine(self, *args, **kwargs):
 97        # Run bm25 nir adjustment
 98        logger.info(f"Running {self.run_config.query_type} queries")
 99
100        return await self.engine.run_all_queries(*args, return_results=True, **kwargs)
101
102    async def posthook(self, *args, **kwargs):
103        pass
104
105    async def run_pipeline(self, *args, return_results=False, **kwargs):
106        for cb in self.callbacks:
107            cb.before(self)
108
109        await self.prehook()
110        results = await self.run_engine(*args, **kwargs)
111        results = Results(results, self.engine.query, self.engine_name)
112
113        for cb in self.callbacks:
114            cb.after(results)
115
116        return results
117
118    def register_callback(self, cb):
119        self.callbacks.append(cb)
NIRPipeline(*args, **kwargs)
    def __init__(self, *args, **kwargs):
        """Forward all positional and keyword arguments unchanged to the parent initialiser."""
        super().__init__(*args, **kwargs)
async def prehook(self):
91    async def prehook(self):
92        if self.run_config.automatic or self.run_config.norm_weight == "automatic":
93            logger.info(f"Running initial BM25 for query adjustment")
94            await self.engine.run_automatic_adjustment()
async def run_engine(self, *args, **kwargs):
 96    async def run_engine(self, *args, **kwargs):
 97        # Run bm25 nir adjustment
 98        logger.info(f"Running {self.run_config.query_type} queries")
 99
100        return await self.engine.run_all_queries(*args, return_results=True, **kwargs)
async def posthook(self, *args, **kwargs):
    async def posthook(self, *args, **kwargs):
        """
        Post-run hook; this implementation does nothing and ignores its arguments.

        NOTE(review): run_pipeline in this class calls prehook but never
        posthook -- confirm whether that is intended.
        """
        pass
async def run_pipeline(self, *args, return_results=False, **kwargs):
105    async def run_pipeline(self, *args, return_results=False, **kwargs):
106        for cb in self.callbacks:
107            cb.before(self)
108
109        await self.prehook()
110        results = await self.run_engine(*args, **kwargs)
111        results = Results(results, self.engine.query, self.engine_name)
112
113        for cb in self.callbacks:
114            cb.after(results)
115
116        return results
def register_callback(self, cb):
    def register_callback(self, cb):
        """
        Append a callback to this pipeline's callback list.

        :param cb: Callback object; run_pipeline calls its ``before`` method
            with the pipeline and its ``after`` method with the results
        """
        self.callbacks.append(cb)
Inherited Members
Pipeline
disable
build_from_config
class BM25Pipeline(NIRPipeline):
class BM25Pipeline(NIRPipeline):
    """
    Pipeline variant that always runs the engine with ``query_type="query"``
    and does not run the prehook step.
    """

    async def run_pipeline(self, *args, return_results=False, **kwargs):
        """
        Run all queries with ``query_type="query"`` and return them as ``Results``.

        Every callback's ``before(self)`` is called first and ``after(results)``
        last. The positional arguments, ``return_results`` and extra keyword
        arguments are accepted but not used.
        """
        for callback in self.callbacks:
            callback.before(self)

        raw_results = await self.engine.run_all_queries(query_type="query", return_results=True)
        wrapped = Results(raw_results, self.engine.query, self.engine_name)

        for callback in self.callbacks:
            callback.after(wrapped)

        return wrapped
async def run_pipeline(self, *args, return_results=False, **kwargs):
123    async def run_pipeline(self, *args, return_results=False, **kwargs):
124        for cb in self.callbacks:
125            cb.before(self)
126
127        results = await self.engine.run_all_queries(query_type="query",
128                                                    return_results=True)
129
130        results = Results(results, self.engine.query, self.engine_name)
131
132        for cb in self.callbacks:
133            cb.after(results)
134
135        return results