Source code for QhX.tests.test_parallel_solver

import os
import gc
import threading
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import unittest
from QhX.parallelization_solver import ParallelSolver
from QhX import DataManagerDynamical, process1_new_dyn

[docs] class TestParallelSolver(unittest.TestCase):
[docs] def setUp(self): print("Running setUp...") # Debugging print agn_dc_mapping = { 'column_mapping': {'flux': 'psMag', 'time': 'mjd', 'band': 'filter'}, 'group_by_key': 'objectId', 'filter_mapping': {0: 0, 1: 1, 2: 2, 3: 3} } self.data_manager = DataManagerDynamical( column_mapping=agn_dc_mapping['column_mapping'], group_by_key=agn_dc_mapping['group_by_key'], filter_mapping=agn_dc_mapping['filter_mapping'] ) self.synthetic_data = self.create_synthetic_data() self.synthetic_data_file = 'synthetic_test_data.parquet' self.synthetic_data.to_parquet(self.synthetic_data_file) self.data_manager.load_data(self.synthetic_data_file) self.data_manager.group_data() self.solver = ParallelSolver( delta_seconds=12.0, num_workers=2, data_manager=self.data_manager, log_time=True, log_files=False, save_results=True, process_function=process1_new_dyn, parallel_arithmetic=True, ntau=80, ngrid=100, provided_minfq=500, provided_maxfq=10, mode='dynamical' ) self.setids = ['1']
[docs] def create_synthetic_data(self): np.random.seed(42) object_id = '1' num_measurements = 50 mjd_values = np.linspace(50000, 50500, num=num_measurements) psMag_values = np.random.normal(loc=20.0, scale=0.5, size=num_measurements) psMagErr_values = np.random.uniform(0.02, 0.1, size=num_measurements) filter_values = np.tile([0, 1, 2, 3], int(num_measurements / 4) + 1)[:num_measurements] data = { 'objectId': [object_id] * num_measurements, 'mjd': mjd_values, 'psMag': psMag_values, 'psMagErr': psMagErr_values, 'filter': filter_values } return pd.DataFrame(data)
[docs] def test_parallel_solver_process_and_merge(self): print("Running test_parallel_solver_process_and_merge...") # Debugging print try: self.solver.process_ids(set_ids=self.setids, results_file='1-reslut.csv') print("Solver processed IDs successfully.") # Debugging print except Exception as e: self.fail(f"Error processing/saving data: {e}") if not os.path.exists('1-reslut.csv'): self.fail("Processed result file missing or cannot be read") # Read the processing result and check structure actual_df = pd.read_csv('1-reslut.csv') print("Actual DataFrame read successfully.") # Debugging print # Check that the DataFrame has the expected columns expected_columns = [ "ID", "Sampling_1", "Sampling_2", "Common period (Band1 & Band2)", "Upper error bound", "Lower error bound", "Significance", "Band1-Band2" ] self.assertListEqual(list(actual_df.columns), expected_columns) # Optional: Check if numerical values fall within expected ranges self.assertTrue((actual_df["Sampling_1"] > 0).all()) self.assertTrue((actual_df["Sampling_2"] > 0).all()) self.assertTrue((actual_df["Significance"].fillna(0) >= 0).all()) # Allow NaN, otherwise check non-negative # Print the result DataFrame for inspection print("\nContents of 1-reslut.csv:") print(actual_df.to_string(index=False)) # Print DataFrame without row indices
[docs] def tearDown(self): print("Cleaning up...") # Debugging print if hasattr(self.solver, 'executor') and self.solver.executor: try: self.solver.executor.shutdown(wait=True) print("Executor shutdown successfully.") # Debugging print except Exception as e: print(f"Error during executor shutdown: {e}") if os.path.isfile(self.synthetic_data_file): os.remove(self.synthetic_data_file) if os.path.isfile('1-reslut.csv'): os.remove('1-reslut.csv') gc.collect() for thread in threading.enumerate(): if thread.name != "MainThread": print(f"Thread {thread.name} is still active.") # Debugging print
if __name__ == '__main__': unittest.main()