231 lines
7.3 KiB
Python
Executable File
231 lines
7.3 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Query LLM for individual gradient values and update CSV.
|
|
|
|
This script generates prompts for each gradient, queries the LLM conversation,
|
|
and updates the CSV with the returned values.
|
|
"""
|
|
|
|
import csv
|
|
import json
|
|
import sys
|
|
import argparse
|
|
import subprocess
|
|
import re
|
|
from pathlib import Path
|
|
|
|
|
|
def load_bicorder_config(bicorder_path):
    """Read the bicorder.json configuration file and return the parsed object."""
    with open(bicorder_path, 'r') as config_file:
        raw_text = config_file.read()
    return json.loads(raw_text)
|
|
|
|
|
|
def extract_gradients(bicorder_data):
    """Flatten every gradient across all diagnostic sets into a flat list.

    Each entry is a dict carrying the derived CSV column name
    ("<set>_<left>_vs_<right>") alongside the set name, both gradient
    terms, and their descriptions.
    """
    flattened = []
    for dset in bicorder_data['diagnostic']:
        set_name = dset['set_name']
        for grad in dset['gradients']:
            left = grad['term_left']
            right = grad['term_right']
            flattened.append({
                'column_name': f"{set_name}_{left}_vs_{right}",
                'set_name': set_name,
                'term_left': left,
                'term_left_description': grad['term_left_description'],
                'term_right': right,
                'term_right_description': grad['term_right_description'],
            })
    return flattened
|
|
|
|
|
|
def get_protocol_by_row(csv_path, row_number):
    """Return the Descriptor/Description of the 1-indexed data row, or None.

    Row numbering starts at 1 for the first data row (the header does not
    count). Missing columns yield empty strings; values are whitespace-trimmed.
    """
    with open(csv_path, 'r', encoding='utf-8') as handle:
        for index, record in enumerate(csv.DictReader(handle), start=1):
            if index != row_number:
                continue
            return {
                'descriptor': record.get('Descriptor', '').strip(),
                'description': record.get('Description', '').strip(),
            }
    # Fewer data rows than requested.
    return None
|
|
|
|
|
|
def generate_gradient_prompt(protocol_descriptor, protocol_description, gradient):
    """Build the evaluation prompt for a single gradient.

    The prompt anchors the 1-9 scale: 1 = strongly the left term,
    9 = strongly the right term, 5 = neutral.
    """
    left = gradient['term_left']
    right = gradient['term_right']
    left_desc = gradient['term_left_description']
    right_desc = gradient['term_right_description']
    lines = (
        f'Analyze this protocol: "{protocol_descriptor}"',
        '',
        f'Description: {protocol_description}',
        '',
        'Evaluate the protocol on this gradient:',
        '',
        f'**{left}** (1) vs **{right}** (9)',
        '',
        f'- **{left}**: {left_desc}',
        f'- **{right}**: {right_desc}',
        '',
        'Provide a rating from 1 to 9, where:',
        f'- 1 = strongly {left}',
        '- 5 = neutral/balanced/not applicable',
        f'- 9 = strongly {right}',
        '',
        'Respond with ONLY the number (1-9), optionally followed by a brief explanation.',
        '',  # trailing newline, matching the original triple-quoted string
    )
    return '\n'.join(lines)
|
|
|
|
|
|
def query_llm(prompt, model=None):
    """Send *prompt* to the `llm` CLI and return its stdout, or None on failure.

    Each call is a fresh, stateless invocation of the CLI. When *model* is
    given it is forwarded via `-m`; otherwise the CLI's default model is used.
    Failures (non-zero exit, or the `llm` executable missing from PATH) are
    reported to stderr and yield None rather than raising.
    """
    cmd = ['llm']
    if model:
        cmd.extend(['-m', model])

    try:
        result = subprocess.run(
            cmd,
            input=prompt,
            text=True,
            capture_output=True,
            check=True
        )
        return result.stdout.strip()
    except FileNotFoundError:
        # Previously this propagated as an unhandled traceback when the
        # `llm` CLI was not installed; callers expect None on any failure.
        print("  Error: 'llm' command not found (is it installed?)", file=sys.stderr)
        return None
    except subprocess.CalledProcessError as e:
        print(f"  Error calling llm: {e.stderr}", file=sys.stderr)
        return None
|
|
|
|
|
|
def extract_value(llm_response):
    """Extract a rating 1-9 from the start of an LLM response, or None.

    Accepts forms like "7", "7 - because ...", "7.". Returns None for
    responses that do not begin with a valid single-digit rating. A leading
    multi-digit number (e.g. "10") is rejected outright — the previous
    `^(\\d)` pattern silently misread "10" as a rating of 1.
    """
    # Exactly one digit 1-9 at the very start, not followed by another digit.
    match = re.match(r'([1-9])(?!\d)', llm_response.strip())
    if match:
        return int(match.group(1))
    return None
|
|
|
|
|
|
def update_csv_cell(csv_path, row_number, column_name, value):
    """Set one cell (1-indexed data row, named column) and rewrite the CSV.

    Returns True when the cell was updated, False when *row_number* is out of
    range or *column_name* is not a header of the file. The file is left
    untouched in every failure case.
    """
    # Slurp the whole file into memory; these analysis CSVs are small.
    with open(csv_path, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        rows = list(reader)

    # Lower bound matters: the old `row_number <= len(rows)` check let
    # row_number=0 slip through and silently update the LAST row via
    # Python's negative indexing.
    if not 1 <= row_number <= len(rows):
        return False
    # An unknown column would make DictWriter raise ValueError mid-write;
    # fail cleanly instead.
    if fieldnames is None or column_name not in fieldnames:
        return False

    rows[row_number - 1][column_name] = str(value)

    # Write back the full file with the original header order.
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    return True
|
|
|
|
|
|
def main():
    """CLI entry point: query the LLM for every gradient of one CSV row.

    Parses arguments, validates inputs, then for each gradient in the
    bicorder config sends one stateless prompt to the `llm` CLI and writes
    the extracted 1-9 rating back into the matching CSV column. With
    --dry-run, prints the prompts and changes nothing.
    """
    parser = argparse.ArgumentParser(
        description='Query LLM for gradient values and update CSV',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example usage:
# Query all gradients for protocol in row 1
python3 bicorder_query.py analysis_output.csv 1

# Query specific model
python3 bicorder_query.py analysis_output.csv 1 -m mistral

# Dry run (show prompts without calling LLM)
python3 bicorder_query.py analysis_output.csv 1 --dry-run
"""
    )

    parser.add_argument('csv_path', help='CSV file to update')
    parser.add_argument('row_number', type=int, help='Row number to analyze (1-indexed)')
    parser.add_argument('-b', '--bicorder',
        default='../bicorder.json',
        help='Path to bicorder.json (default: ../bicorder.json)')
    parser.add_argument('-m', '--model', help='LLM model to use')
    parser.add_argument('--dry-run', action='store_true',
        help='Show prompts without calling LLM or updating CSV')

    args = parser.parse_args()

    # Validate files exist before doing any work.
    if not Path(args.csv_path).exists():
        print(f"Error: CSV file '{args.csv_path}' not found", file=sys.stderr)
        sys.exit(1)

    if not Path(args.bicorder).exists():
        print(f"Error: Bicorder config '{args.bicorder}' not found", file=sys.stderr)
        sys.exit(1)

    # Load protocol data for the requested row (None means row out of range).
    protocol = get_protocol_by_row(args.csv_path, args.row_number)
    if protocol is None:
        print(f"Error: Row {args.row_number} not found in CSV", file=sys.stderr)
        sys.exit(1)

    # Load bicorder config and flatten its diagnostic sets into gradients.
    bicorder_data = load_bicorder_config(args.bicorder)
    gradients = extract_gradients(bicorder_data)

    if args.dry_run:
        print(f"DRY RUN: Row {args.row_number}, {len(gradients)} gradients")
        print(f"Protocol: {protocol['descriptor']}\n")
    else:
        print(f"Protocol: {protocol['descriptor']}")
        print(f"Loaded {len(gradients)} gradients, starting queries...")

    # Process each gradient independently; a failure on one gradient
    # never aborts the run — it is reported and skipped.
    for i, gradient in enumerate(gradients, 1):
        gradient_short = gradient['column_name'].replace('_', ' ')

        if not args.dry_run:
            print(f"[{i}/{len(gradients)}] Querying: {gradient_short}...", flush=True)

        # Generate prompt (including protocol context)
        prompt = generate_gradient_prompt(
            protocol['descriptor'],
            protocol['description'],
            gradient
        )

        if args.dry_run:
            print(f"[{i}/{len(gradients)}] {gradient_short}")
            print(f"Prompt:\n{prompt}\n")
            continue

        # Query LLM (new chat each time; query_llm returns None on failure)
        response = query_llm(prompt, args.model)

        if response is None:
            print(f"[{i}/{len(gradients)}] {gradient_short}: FAILED")
            continue

        # Extract the 1-9 value from the free-text response.
        value = extract_value(response)
        if value is None:
            print(f"[{i}/{len(gradients)}] {gradient_short}: WARNING - no valid value")
            continue

        # Write the value into this gradient's CSV column for the row.
        if update_csv_cell(args.csv_path, args.row_number, gradient['column_name'], value):
            print(f"[{i}/{len(gradients)}] {gradient_short}: {value}")
        else:
            print(f"[{i}/{len(gradients)}] {gradient_short}: ERROR updating CSV")

    if not args.dry_run:
        print(f"\n✓ CSV updated: {args.csv_path}")
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|