Set up analysis scripts
This commit is contained in:
175
analysis/bicorder_batch.py
Normal file
175
analysis/bicorder_batch.py
Normal file
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Batch process all protocols in a CSV using the Bicorder framework.
|
||||
|
||||
This script orchestrates the entire analysis workflow:
|
||||
1. Creates output CSV with gradient columns
|
||||
2. For each protocol row:
|
||||
- Queries all 23 gradients (each in a new chat)
|
||||
- Updates CSV with results
|
||||
"""
|
||||
|
||||
import csv
|
||||
import json
|
||||
import sys
|
||||
import argparse
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def count_csv_rows(csv_path):
    """Count the number of data rows in a CSV file.

    Args:
        csv_path: Path to a CSV file whose first row is a header.

    Returns:
        Number of data rows (header row excluded).
    """
    # newline='' is required by the csv module so that quoted fields
    # containing embedded newlines are parsed as one record, not several.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.DictReader(f)
        return sum(1 for _ in reader)
|
||||
|
||||
|
||||
def run_bicorder_analyze(input_csv, output_csv, bicorder_path, analyst=None, standpoint=None):
    """Run bicorder_analyze.py to create the output CSV with gradient columns.

    Args:
        input_csv: Path to the source protocol CSV.
        output_csv: Path where the analysis CSV will be written.
        bicorder_path: Path to bicorder.json.
        analyst: Optional analyst name, forwarded via -a.
        standpoint: Optional analyst standpoint, forwarded via -s.

    Returns:
        True if the helper script exited successfully, False otherwise.
    """
    # Use the running interpreter instead of a hardcoded 'python3' so the
    # helper runs under the same Python (venv-safe, works on Windows too).
    cmd = [sys.executable, 'bicorder_analyze.py', input_csv, '-o', output_csv, '-b', bicorder_path]

    if analyst:
        cmd.extend(['-a', analyst])
    if standpoint:
        cmd.extend(['-s', standpoint])

    print(f"Creating analysis CSV: {output_csv}")
    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode != 0:
        print(f"Error creating CSV: {result.stderr}", file=sys.stderr)
        return False

    print(result.stdout)
    return True
|
||||
|
||||
|
||||
def query_gradients(output_csv, row_num, bicorder_path, model=None):
    """Query all gradients for one protocol row via bicorder_query.py.

    Args:
        output_csv: Analysis CSV that bicorder_query.py updates in place.
        row_num: 1-indexed row number to process.
        bicorder_path: Path to bicorder.json.
        model: Optional LLM model name, forwarded via -m.

    Returns:
        True if the helper script exited successfully, False otherwise.
    """
    # sys.executable instead of hardcoded 'python3' (venv/Windows safe).
    cmd = [sys.executable, 'bicorder_query.py', output_csv, str(row_num),
           '-b', bicorder_path]

    if model:
        cmd.extend(['-m', model])

    print("Starting gradient queries...")

    # Don't capture output - let it print in real-time for progress visibility
    result = subprocess.run(cmd)

    if result.returncode != 0:
        print("Error querying gradients", file=sys.stderr)
        return False

    return True
|
||||
|
||||
|
||||
def process_protocol_row(input_csv, output_csv, row_num, total_rows, bicorder_path, model=None):
    """Process a single protocol row through the complete workflow."""
    banner = '=' * 60
    # Progress header so long batch runs are easy to follow in the console.
    print(f"\n{banner}")
    print(f"Row {row_num}/{total_rows}")
    print(f"{banner}")

    # Query all gradients (each gradient gets a new chat)
    ok = query_gradients(output_csv, row_num, bicorder_path, model)
    if not ok:
        print(f"[FAILED] Could not query gradients")
        return False

    print(f"✓ Row {row_num} complete")
    return True
|
||||
|
||||
|
||||
def main():
    """Parse arguments and run the batch analysis workflow.

    Workflow:
      1. Create the output CSV via bicorder_analyze.py (unless resuming).
      2. For each requested row, query all gradients via bicorder_query.py.
      3. Print a summary of successes and failures.
    """
    parser = argparse.ArgumentParser(
        description='Batch process protocols through Bicorder analysis (each gradient uses a new chat)',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example usage:
  # Process all protocols
  python3 bicorder_batch.py protocols_edited.csv -o analysis_output.csv

  # Process specific rows
  python3 bicorder_batch.py protocols_edited.csv -o analysis_output.csv --start 1 --end 5

  # With specific model
  python3 bicorder_batch.py protocols_edited.csv -o analysis_output.csv -m mistral

  # With metadata
  python3 bicorder_batch.py protocols_edited.csv -o analysis_output.csv -a "Your Name" -s "Your standpoint"
"""
    )

    parser.add_argument('input_csv', help='Input CSV file with protocol data')
    parser.add_argument('-o', '--output', required=True, help='Output CSV file')
    parser.add_argument('-b', '--bicorder',
                        default='../bicorder.json',
                        help='Path to bicorder.json (default: ../bicorder.json)')
    parser.add_argument('-m', '--model', help='LLM model to use')
    parser.add_argument('-a', '--analyst', help='Analyst name')
    parser.add_argument('-s', '--standpoint', help='Analyst standpoint')
    parser.add_argument('--start', type=int, default=1,
                        help='Start row number (1-indexed, default: 1)')
    parser.add_argument('--end', type=int,
                        help='End row number (1-indexed, default: all rows)')
    parser.add_argument('--resume', action='store_true',
                        help='Resume from existing output CSV (skip rows with values)')

    args = parser.parse_args()

    # Validate input file exists
    if not Path(args.input_csv).exists():
        print(f"Error: Input file '{args.input_csv}' not found", file=sys.stderr)
        sys.exit(1)

    # Validate bicorder.json exists
    if not Path(args.bicorder).exists():
        print(f"Error: Bicorder config '{args.bicorder}' not found", file=sys.stderr)
        sys.exit(1)

    # Count rows in input CSV
    total_rows = count_csv_rows(args.input_csv)
    # Compare against None (not truthiness) so an explicit --end 0 is
    # rejected by the range check below instead of silently meaning "all".
    end_row = args.end if args.end is not None else total_rows

    # Rows are 1-indexed; reject out-of-range requests up front.
    if args.start < 1:
        print("Error: --start must be >= 1", file=sys.stderr)
        sys.exit(1)
    if args.start > total_rows or end_row > total_rows:
        print(f"Error: Row range exceeds CSV size ({total_rows} rows)", file=sys.stderr)
        sys.exit(1)

    print("Bicorder Batch Analysis")
    print(f"Input: {args.input_csv} ({total_rows} protocols)")
    print(f"Output: {args.output}")
    print(f"Processing rows: {args.start} to {end_row}")
    if args.model:
        print(f"Model: {args.model}")
    print()

    # Step 1: Create output CSV (unless resuming)
    if not args.resume or not Path(args.output).exists():
        if not run_bicorder_analyze(args.input_csv, args.output, args.bicorder,
                                    args.analyst, args.standpoint):
            sys.exit(1)
    else:
        print(f"Resuming from existing CSV: {args.output}")

    # Step 2: Process each protocol row; keep going past individual failures
    # so one bad row doesn't abort the whole batch.
    success_count = 0
    fail_count = 0

    for row_num in range(args.start, end_row + 1):
        if process_protocol_row(args.input_csv, args.output, row_num, end_row,
                                args.bicorder, args.model):
            success_count += 1
        else:
            fail_count += 1
            print(f"[WARNING] Row {row_num} failed, continuing...")

    # Summary
    print(f"\n{'='*60}")
    print("BATCH COMPLETE")
    print('='*60)
    print(f"Successful: {success_count}")
    print(f"Failed: {fail_count}")
    print(f"Output: {args.output}")
||||
|
||||
|
||||
# Script entry point: run the batch workflow only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()
|
||||
Reference in New Issue
Block a user