Set up analysis scripts
This commit is contained in:
230
analysis/bicorder_query.py
Normal file
230
analysis/bicorder_query.py
Normal file
@@ -0,0 +1,230 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Query LLM for individual gradient values and update CSV.
|
||||
|
||||
This script generates prompts for each gradient, queries the LLM conversation,
|
||||
and updates the CSV with the returned values.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import json
|
||||
import sys
|
||||
import argparse
|
||||
import subprocess
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_bicorder_config(bicorder_path):
    """Read the bicorder.json configuration file and return the parsed object."""
    config_file = Path(bicorder_path)
    with config_file.open('r') as handle:
        return json.load(handle)
|
||||
|
||||
|
||||
def extract_gradients(bicorder_data):
    """Flatten all gradients from the diagnostic sets into a single list.

    Each entry carries the set name, both gradient terms with their
    descriptions, and a derived CSV column name of the form
    "<set>_<left>_vs_<right>".
    """
    return [
        {
            'column_name': f"{ds['set_name']}_{grad['term_left']}_vs_{grad['term_right']}",
            'set_name': ds['set_name'],
            'term_left': grad['term_left'],
            'term_left_description': grad['term_left_description'],
            'term_right': grad['term_right'],
            'term_right_description': grad['term_right_description'],
        }
        for ds in bicorder_data['diagnostic']
        for grad in ds['gradients']
    ]
|
||||
|
||||
|
||||
def get_protocol_by_row(csv_path, row_number):
    """Look up one protocol in the CSV by 1-indexed data row.

    Returns a dict with stripped 'descriptor' and 'description' values,
    or None when the row does not exist.
    """
    with open(csv_path, 'r', encoding='utf-8') as handle:
        for index, record in enumerate(csv.DictReader(handle), start=1):
            if index != row_number:
                continue
            return {
                'descriptor': record.get('Descriptor', '').strip(),
                'description': record.get('Description', '').strip(),
            }
    return None
|
||||
|
||||
|
||||
def generate_gradient_prompt(protocol_descriptor: str, protocol_description: str, gradient: dict) -> str:
    """Generate a prompt for a single gradient evaluation.

    Args:
        protocol_descriptor: Short name of the protocol under analysis.
        protocol_description: Longer description giving the LLM context.
        gradient: One entry from extract_gradients(); must provide
            'term_left', 'term_right', and their '*_description' keys.

    Returns:
        A self-contained prompt asking the LLM for a 1-9 rating, where 1
        means strongly term_left and 9 means strongly term_right.
    """
    return f"""Analyze this protocol: "{protocol_descriptor}"

Description: {protocol_description}

Evaluate the protocol on this gradient:

**{gradient['term_left']}** (1) vs **{gradient['term_right']}** (9)

- **{gradient['term_left']}**: {gradient['term_left_description']}
- **{gradient['term_right']}**: {gradient['term_right_description']}

Provide a rating from 1 to 9, where:
- 1 = strongly {gradient['term_left']}
- 5 = neutral/balanced
- 9 = strongly {gradient['term_right']}

Respond with ONLY the number (1-9), optionally followed by a brief explanation.
"""
|
||||
|
||||
|
||||
def query_llm(prompt, model=None):
    """Send a prompt to the `llm` CLI and return its response.

    Args:
        prompt: Text piped to the llm subprocess via stdin.
        model: Optional model name passed with `-m`; None uses the CLI default.

    Returns:
        The stripped stdout of the llm call, or None when the call fails
        (non-zero exit status) or the `llm` executable is not installed.
    """
    cmd = ['llm']
    if model:
        cmd.extend(['-m', model])

    try:
        result = subprocess.run(
            cmd,
            input=prompt,
            text=True,
            capture_output=True,
            check=True
        )
        return result.stdout.strip()
    except FileNotFoundError:
        # `llm` is not on PATH: report it instead of crashing with a traceback,
        # so main() can mark the gradient FAILED and continue.
        print(" Error: 'llm' executable not found on PATH", file=sys.stderr)
        return None
    except subprocess.CalledProcessError as e:
        print(f" Error calling llm: {e.stderr}", file=sys.stderr)
        return None
|
||||
|
||||
|
||||
def extract_value(llm_response):
    """Extract a numeric rating (1-9) from an LLM response.

    Tries a standalone digit at the start of the response first (the format
    the prompt requests), then falls back to the first standalone 1-9 digit
    anywhere in the text, tolerating prefixes such as "Rating: 7".

    Multi-digit numbers are rejected rather than truncated ("10" is not
    misread as 1, which the previous single-character match allowed).

    Returns:
        The integer rating 1-9, or None when no valid rating is found.
    """
    text = llm_response.strip()
    # \b keeps the match to a standalone digit; [1-9] makes the range
    # check part of the pattern itself.
    match = re.search(r'^([1-9])\b', text) or re.search(r'\b([1-9])\b', text)
    if match:
        return int(match.group(1))
    return None
|
||||
|
||||
|
||||
def update_csv_cell(csv_path, row_number, column_name, value):
    """Update a specific cell in the CSV and rewrite the file.

    Args:
        csv_path: Path to the CSV file (read fully into memory, then rewritten).
        row_number: 1-indexed data row (header excluded).
        column_name: Name of an existing CSV column.
        value: New cell value; stored via str().

    Returns:
        True when the cell was updated and the file rewritten; False when
        the row is out of range or the column does not exist.
    """
    # Read all rows up front so the file can be rewritten atomically below.
    with open(csv_path, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        rows = list(reader)

    # Reject 0/negative rows explicitly: `rows[row_number - 1]` would
    # otherwise negative-index and silently update the wrong (last) row.
    if not 1 <= row_number <= len(rows):
        return False
    # Reject unknown columns: an extra key would make DictWriter raise
    # ValueError mid-write instead of signalling failure to the caller.
    if fieldnames is None or column_name not in fieldnames:
        return False

    rows[row_number - 1][column_name] = str(value)

    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    return True
|
||||
|
||||
|
||||
def main():
    """CLI entry point: rate one CSV row's protocol on every gradient.

    Parses arguments, validates the input files, then queries the LLM once
    per gradient (a fresh conversation each time) and writes each returned
    value back into the protocol's row of the CSV.
    """
    parser = argparse.ArgumentParser(
        description='Query LLM for gradient values and update CSV',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example usage:
# Query all gradients for protocol in row 1
python3 bicorder_query.py analysis_output.csv 1

# Query specific model
python3 bicorder_query.py analysis_output.csv 1 -m mistral

# Dry run (show prompts without calling LLM)
python3 bicorder_query.py analysis_output.csv 1 --dry-run
"""
    )
    parser.add_argument('csv_path', help='CSV file to update')
    parser.add_argument('row_number', type=int, help='Row number to analyze (1-indexed)')
    parser.add_argument('-b', '--bicorder',
                        default='../bicorder.json',
                        help='Path to bicorder.json (default: ../bicorder.json)')
    parser.add_argument('-m', '--model', help='LLM model to use')
    parser.add_argument('--dry-run', action='store_true',
                        help='Show prompts without calling LLM or updating CSV')
    args = parser.parse_args()

    # Fail fast if either input file is missing.
    for label, path in (('CSV file', args.csv_path),
                        ('Bicorder config', args.bicorder)):
        if not Path(path).exists():
            print(f"Error: {label} '{path}' not found", file=sys.stderr)
            sys.exit(1)

    protocol = get_protocol_by_row(args.csv_path, args.row_number)
    if protocol is None:
        print(f"Error: Row {args.row_number} not found in CSV", file=sys.stderr)
        sys.exit(1)

    gradients = extract_gradients(load_bicorder_config(args.bicorder))
    total = len(gradients)

    if args.dry_run:
        print(f"DRY RUN: Row {args.row_number}, {total} gradients")
        print(f"Protocol: {protocol['descriptor']}\n")
    else:
        print(f"Protocol: {protocol['descriptor']}")
        print(f"Loaded {total} gradients, starting queries...")

    for index, gradient in enumerate(gradients, 1):
        label = gradient['column_name'].replace('_', ' ')
        progress = f"[{index}/{total}]"

        if not args.dry_run:
            print(f"{progress} Querying: {label}...", flush=True)

        # Each prompt carries the full protocol context, so every query
        # can run as an independent conversation.
        prompt = generate_gradient_prompt(
            protocol['descriptor'],
            protocol['description'],
            gradient
        )

        if args.dry_run:
            print(f"{progress} {label}")
            print(f"Prompt:\n{prompt}\n")
            continue

        response = query_llm(prompt, args.model)
        if response is None:
            print(f"{progress} {label}: FAILED")
            continue

        value = extract_value(response)
        if value is None:
            print(f"{progress} {label}: WARNING - no valid value")
            continue

        if update_csv_cell(args.csv_path, args.row_number, gradient['column_name'], value):
            print(f"{progress} {label}: {value}")
        else:
            print(f"{progress} {label}: ERROR updating CSV")

    if not args.dry_run:
        print(f"\n✓ CSV updated: {args.csv_path}")


if __name__ == '__main__':
    main()
|
||||
Reference in New Issue
Block a user