Files
protocol-bicorder/analysis/bicorder_query.py
Nathan Schneider fa527bd1f1 Initial analysis
2025-11-21 19:34:33 -07:00

231 lines
7.3 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Query LLM for individual gradient values and update CSV.
This script generates a prompt for each gradient, sends each prompt to the LLM
as a separate `llm` CLI call, and writes the returned values into the CSV.
"""
import csv
import json
import sys
import argparse
import subprocess
import re
from pathlib import Path
def load_bicorder_config(bicorder_path):
    """Load and parse the bicorder.json configuration file.

    Args:
        bicorder_path: Path (str or path-like) to the JSON config.

    Returns:
        The decoded JSON document. ``extract_gradients`` expects a dict with a
        ``'diagnostic'`` list of sets.
    """
    # Decode explicitly as UTF-8, like the CSV helpers in this script. The
    # platform default encoding (e.g. cp1252 on Windows) would mis-decode
    # non-ASCII gradient terms and descriptions.
    with open(bicorder_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def extract_gradients(bicorder_data):
    """Flatten the gradients of every diagnostic set into one list.

    Each returned dict holds the gradient's two poles and their descriptions,
    the owning ``set_name``, and a ``column_name`` of the form
    ``"<set_name>_<term_left>_vs_<term_right>"``. ``main`` uses that name as
    the CSV column the rating is written to.
    """
    return [
        {
            'column_name': '{}_{}_vs_{}'.format(
                diag['set_name'], grad['term_left'], grad['term_right']),
            'set_name': diag['set_name'],
            'term_left': grad['term_left'],
            'term_left_description': grad['term_left_description'],
            'term_right': grad['term_right'],
            'term_right_description': grad['term_right_description'],
        }
        for diag in bicorder_data['diagnostic']
        for grad in diag['gradients']
    ]
def get_protocol_by_row(csv_path, row_number):
    """Get protocol data from CSV by row number (1-indexed, header excluded).

    Args:
        csv_path: UTF-8 CSV with ``Descriptor`` and ``Description`` columns.
        row_number: 1-indexed data row to return.

    Returns:
        ``{'descriptor': str, 'description': str}`` with surrounding
        whitespace stripped, or ``None`` if the CSV has no such data row.
        Missing or empty cells yield ``''``.
    """
    if row_number < 1:
        # No data row can match; don't scan the whole file to find that out.
        return None
    # newline='' is what the csv module requires so quoted fields containing
    # newlines are parsed correctly.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        for i, row in enumerate(csv.DictReader(f), start=1):
            if i == row_number:
                # DictReader fills cells missing from a short row with None
                # (its restval default). row.get(key, '') would still return
                # that None and .strip() would fail, so coalesce with `or`.
                return {
                    'descriptor': (row.get('Descriptor') or '').strip(),
                    'description': (row.get('Description') or '').strip(),
                }
    return None
def generate_gradient_prompt(protocol_descriptor, protocol_description, gradient):
    """Build the prompt asking the LLM to rate one protocol on one gradient.

    The left pole of ``gradient`` maps to 1 and the right pole to 9, with 5
    meaning neutral/not applicable. The model is told to answer with the
    number first. The returned text ends with a newline.
    """
    left = gradient['term_left']
    right = gradient['term_right']
    prompt_lines = [
        f'Analyze this protocol: "{protocol_descriptor}"',
        f"Description: {protocol_description}",
        "Evaluate the protocol on this gradient:",
        f"**{left}** (1) vs **{right}** (9)",
        f"- **{left}**: {gradient['term_left_description']}",
        f"- **{right}**: {gradient['term_right_description']}",
        "Provide a rating from 1 to 9, where:",
        f"- 1 = strongly {left}",
        "- 5 = neutral/balanced/not applicable",
        f"- 9 = strongly {right}",
        "Respond with ONLY the number (1-9), optionally followed by a brief explanation.",
    ]
    return "\n".join(prompt_lines) + "\n"
def query_llm(prompt, model=None, timeout=None):
    """Send ``prompt`` to the ``llm`` CLI on stdin and return its reply.

    Each call is an independent ``llm`` invocation. No continue/conversation
    flag is passed, so every prompt starts a fresh chat.

    Args:
        prompt: Text written to the CLI's standard input.
        model: Optional model name, passed as ``llm -m <model>``.
        timeout: Optional number of seconds to wait for the CLI. The default
            ``None`` waits indefinitely, as before.

    Returns:
        The CLI's stdout with surrounding whitespace stripped, or ``None`` if
        the command could not be started, exited non-zero, or timed out. The
        reason is printed to stderr.
    """
    cmd = ['llm']
    if model:
        cmd.extend(['-m', model])
    try:
        result = subprocess.run(
            cmd,
            input=prompt,
            text=True,
            capture_output=True,
            check=True,
            timeout=timeout,
        )
    except subprocess.CalledProcessError as e:
        print(f" Error calling llm: {e.stderr}", file=sys.stderr)
        return None
    except subprocess.TimeoutExpired:
        print(f" Error calling llm: timed out after {timeout}s", file=sys.stderr)
        return None
    except OSError as e:
        # E.g. FileNotFoundError when `llm` is not on PATH. Report it like
        # other failures instead of letting it abort the whole run.
        print(f" Error calling llm: {e}", file=sys.stderr)
        return None
    return result.stdout.strip()
def extract_value(llm_response):
    """Extract the 1-9 rating that should open an LLM response.

    The prompt asks for the number first, so only a number at the very start
    of the response is accepted. Leading whitespace and common Markdown or
    quoting marks are skipped, so ``"**7**"``, ``"# 7"`` and ``'"7"'`` are
    accepted. The whole leading number is read, so ``"10"`` is rejected
    rather than misread as 1.

    Returns:
        The rating as an int in 1..9, or ``None`` if the response is empty,
        ``None``, or does not start with a valid rating.
    """
    if not llm_response:
        return None
    # re.match anchors at the start. Skip whitespace/markup, then take all
    # consecutive digits.
    match = re.match(r"""[\s*_#>"'`(\[]*(\d+)""", llm_response)
    if match is None:
        return None
    value = int(match.group(1))
    return value if 1 <= value <= 9 else None
def update_csv_cell(csv_path, row_number, column_name, value):
    """Set one cell of a UTF-8 CSV file and rewrite the file.

    Args:
        csv_path: CSV file with a header row.
        row_number: 1-indexed data row (the header is not counted).
        column_name: Header of the column to set. It must already exist.
        value: Stored as ``str(value)``.

    Returns:
        True if the cell was written. False, with the file left untouched, if
        the row or column does not exist.
    """
    import os
    import shutil
    import tempfile

    # newline='' on both read and write, as the csv module requires, so
    # newlines embedded in quoted fields survive the round trip unchanged.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        rows = list(reader)

    # Check the lower bound explicitly. For row_number <= 0,
    # rows[row_number - 1] would silently index from the end of the list.
    if not 1 <= row_number <= len(rows):
        return False
    # An unknown column would make DictWriter raise ValueError mid-run.
    if column_name not in fieldnames:
        return False
    rows[row_number - 1][column_name] = str(value)

    # Write a temporary file in the same directory, then atomically swap it
    # in. An interrupt during the write then cannot leave a truncated CSV.
    directory = os.path.dirname(os.path.abspath(csv_path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
    try:
        with os.fdopen(fd, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
        # mkstemp creates the file owner-only; keep the CSV's original mode.
        shutil.copymode(csv_path, tmp_path)
        os.replace(tmp_path, csv_path)
    except BaseException:
        # Don't leave a partial temp file behind.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise
    return True
def main():
    """Command-line entry point: rate one CSV row on every bicorder gradient.

    For the protocol in ``row_number`` of ``csv_path``, this builds one prompt
    per gradient from the bicorder config. It queries the LLM with a separate
    ``llm`` call per gradient, parses the 1-9 answer, and writes it into the
    gradient's CSV column. ``--dry-run`` only prints the prompts.

    Exits with status 1 if the CSV, the config, or the requested row is
    missing.
    """
    parser = argparse.ArgumentParser(
        description='Query LLM for gradient values and update CSV',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example usage:
# Query all gradients for protocol in row 1
python3 bicorder_query.py analysis_output.csv 1
# Query specific model
python3 bicorder_query.py analysis_output.csv 1 -m mistral
# Dry run (show prompts without calling LLM)
python3 bicorder_query.py analysis_output.csv 1 --dry-run
"""
    )
    parser.add_argument('csv_path', help='CSV file to update')
    parser.add_argument('row_number', type=int, help='Row number to analyze (1-indexed)')
    parser.add_argument('-b', '--bicorder',
                        default='../bicorder.json',
                        help='Path to bicorder.json (default: ../bicorder.json)')
    parser.add_argument('-m', '--model', help='LLM model to use')
    parser.add_argument('--dry-run', action='store_true',
                        help='Show prompts without calling LLM or updating CSV')
    args = parser.parse_args()

    # Validate files exist
    if not Path(args.csv_path).exists():
        print(f"Error: CSV file '{args.csv_path}' not found", file=sys.stderr)
        sys.exit(1)
    if not Path(args.bicorder).exists():
        print(f"Error: Bicorder config '{args.bicorder}' not found", file=sys.stderr)
        sys.exit(1)

    # Load protocol data
    protocol = get_protocol_by_row(args.csv_path, args.row_number)
    if protocol is None:
        print(f"Error: Row {args.row_number} not found in CSV", file=sys.stderr)
        sys.exit(1)

    # Load bicorder config
    bicorder_data = load_bicorder_config(args.bicorder)
    gradients = extract_gradients(bicorder_data)
    total = len(gradients)

    if args.dry_run:
        print(f"DRY RUN: Row {args.row_number}, {total} gradients")
        print(f"Protocol: {protocol['descriptor']}\n")
    else:
        print(f"Protocol: {protocol['descriptor']}")
        print(f"Loaded {total} gradients, starting queries...")

    # Count cells actually stored, so the final summary reflects reality.
    written = 0
    for i, gradient in enumerate(gradients, 1):
        gradient_short = gradient['column_name'].replace('_', ' ')
        label = f"[{i}/{total}] {gradient_short}"
        if not args.dry_run:
            print(f"[{i}/{total}] Querying: {gradient_short}...", flush=True)

        # Generate prompt (including protocol context)
        prompt = generate_gradient_prompt(
            protocol['descriptor'],
            protocol['description'],
            gradient
        )

        if args.dry_run:
            print(label)
            print(f"Prompt:\n{prompt}\n")
            continue

        # Query LLM (new chat each time)
        response = query_llm(prompt, args.model)
        if response is None:
            print(f"{label}: FAILED")
            continue

        value = extract_value(response)
        if value is None:
            # Show the start of the reply so an unparseable answer can be
            # diagnosed without re-running the query.
            snippet = ' '.join(response.split())[:80]
            print(f"{label}: WARNING - no valid value (response: {snippet!r})")
            continue

        if update_csv_cell(args.csv_path, args.row_number, gradient['column_name'], value):
            written += 1
            print(f"{label}: {value}")
        else:
            print(f"{label}: ERROR updating CSV")

    if not args.dry_run:
        # Report success only if something was actually written.
        if written:
            print(f"\n✓ CSV updated: {args.csv_path} ({written}/{total} values written)")
        else:
            print(f"\n✗ No values written to {args.csv_path}", file=sys.stderr)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()