Coverage for sherlock/tests/test_transient_classifier.py: 93%
73 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-10-10 13:58 +0000
1from __future__ import print_function
2from builtins import str
3import os
4import unittest
5import shutil
6import yaml
7from sherlock.utKit import utKit
8from fundamentals import tools
9from os.path import expanduser
# ---------------------------------------------------------------------------
# Module-level fixture setup shared by every test in this file.
# Runs once at import time: loads the test settings, resets the output
# directory, opens the database connections and runs the SQL scripts found
# in the `resources/transient_database` directory.
# ---------------------------------------------------------------------------
home = expanduser("~")

# The YAML settings file sits at the path returned by utKit's project root.
packageDirectory = utKit("").get_project_root()
settingsFile = packageDirectory + "/test_settings.yaml"

su = tools(
    arguments={"settingsFile": settingsFile},
    docString=__doc__,
    logLevel="DEBUG",
    options_first=False,
    projectName=None,
    defaultSettingsFile=False
)
arguments, settings, log, dbConn = su.setup()

# SETUP PATHS TO COMMON DIRECTORIES FOR TEST DATA
moduleDirectory = os.path.dirname(__file__)
pathToInputDir = moduleDirectory + "/input/"
pathToOutputDir = moduleDirectory + "/output/"

# Best-effort removal of any output left by a previous run. `ignore_errors`
# tolerates a missing directory (the expected case on a first run) without a
# bare `except:` that would also swallow KeyboardInterrupt / SystemExit.
shutil.rmtree(pathToOutputDir, ignore_errors=True)

# COPY INPUT TO OUTPUT DIR
shutil.copytree(pathToInputDir, pathToOutputDir)

# Recursively create missing directories
if not os.path.exists(pathToOutputDir):
    os.makedirs(pathToOutputDir)

# Point the "static catalogues" settings at the "static catalogues2" entry
# for the duration of the tests.
settings["database settings"]["static catalogues"] = settings[
    "database settings"]["static catalogues2"]

# SETUP ALL DATABASE CONNECTIONS
from sherlock import database
db = database(
    log=log,
    settings=settings
)
dbConns, dbVersions = db.connect()
transientsDbConn = dbConns["transients"]
cataloguesDbConn = dbConns["catalogues"]

# Run the scripts in `<module>/resources/transient_database` against the
# transients database connection before any test executes.
from fundamentals.mysql import directory_script_runner
directory_script_runner(
    log=log,
    pathToScriptDirectory=pathToInputDir.replace(
        "/input", "/resources") + "/transient_database",
    dbConn=transientsDbConn
)
class test_transient_classifier(unittest.TestCase):
    """Tests exercising ``sherlock.transient_classifier``.

    Every test relies on the module-level ``log`` and ``settings`` objects
    created when this file is imported. Apart from the exception test, the
    tests make no explicit assertions: they pass when the exercised calls
    complete without raising.
    """

    def test_transient_update_classified_annotations_function(self):
        """Call ``update_classification_annotations_and_summaries`` with ``update=True``."""
        from sherlock import transient_classifier
        this = transient_classifier(
            log=log,
            settings=settings,
            update=True
        )
        # this.update_peak_magnitudes()
        this.update_classification_annotations_and_summaries()

    def test_transient_classifier_function(self):
        """Call ``classify`` with ``update=True``, ``updateNed=False`` and ``oneRun=True``."""
        from sherlock import transient_classifier
        this = transient_classifier(
            log=log,
            settings=settings,
            update=True,
            updateNed=False,
            oneRun=True
        )
        this.classify()

    def test_transient_classifier_single_source_function(self):
        """Classify a single source given explicitly by name, RA and Dec."""
        from sherlock import transient_classifier
        this = transient_classifier(
            log=log,
            settings=settings,
            ra="08:57:57.19",
            dec="+43:25:44.1",
            name="PS17gx",
            updateNed=False,
            verbose=0
        )
        # Unpacking doubles as a check that exactly two values come back.
        classifications, crossmatches = this.classify()

    def test_get_transient_metadata_from_database_list(self):
        """Call the private ``_get_transient_metadata_from_database_list`` helper."""
        from sherlock import transient_classifier
        classifier = transient_classifier(
            log=log,
            settings=settings,
            updateNed=False
        )
        transientsMetadataList = classifier._get_transient_metadata_from_database_list()
        # classifier._update_ned_stream(
        #     transientsMetadataList=transientsMetadataList
        # )

    def test_full_classifier(self):
        """Call ``classify`` with ``verbose=2``, ``update=True`` and ``updatePeakMags=True``."""
        from sherlock import transient_classifier
        classifier = transient_classifier(
            log=log,
            settings=settings,
            verbose=2,
            update=True,
            updateNed=False,
            updatePeakMags=True
        )
        classifier.classify()

    def test_classification_annotations(self):
        """Open fresh connections, build the column map, then call ``classification_annotations``."""
        from sherlock import database
        db = database(
            log=log,
            settings=settings
        )
        dbConns, dbVersions = db.connect()
        transientsDbConn = dbConns["transients"]
        cataloguesDbConn = dbConns["catalogues"]

        from sherlock.commonutils import get_crossmatch_catalogues_column_map
        colMaps = get_crossmatch_catalogues_column_map(
            log=log,
            dbConn=cataloguesDbConn
        )

        from sherlock import transient_classifier
        classifier = transient_classifier(
            log=log,
            settings=settings
        )
        classifier.classification_annotations()

    def test_transient_classifier_function_exception(self):
        """An unexpected keyword argument must make the classifier raise.

        The previous ``try: ...; assert False / except Exception: assert True``
        form could never fail: the ``AssertionError`` from ``assert False`` is
        itself an ``Exception`` and was swallowed by the handler.
        ``assertRaises`` fails the test when nothing is raised.
        """
        from sherlock import transient_classifier
        with self.assertRaises(Exception) as raised:
            this = transient_classifier(
                log=log,
                settings=settings,
                fakeKey="break the code"
            )
            this.get()
        print(str(raised.exception))

    # x-class-to-test-named-worker-function