#!/usr/bin/env python3
"""
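Gemini Vision OCR test — pump photos.

Sends each pump photo listed in TESTS to the Gemini API and compares the
extracted values with the expected readings.

Requires the GEMINI_API_KEY environment variable; the photos must sit in the
same directory as this script.

Run: python3 test_pump_ocr.py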
"""
import base64, json, urllib.request, urllib.error, ssl, os, sys

# API key comes from the environment; an unset variable yields '' so the
# script can detect the missing key with a simple truthiness check.
GEMINI_KEY = os.environ.get('GEMINI_API_KEY', '')

# generateContent endpoint of the gemini-2.0-flash model (v1beta REST API).
GEMINI_URL = (
    'https://generativelanguage.googleapis.com/v1beta/'
    'models/gemini-2.0-flash:generateContent'
)

# Each case names a photo (looked up next to this script), the numeric values
# the test expects Gemini to extract from it, and the prompt sent with the image.
#
# The JSON template inside each prompt deliberately uses zero placeholders.
# If it carried the same numbers as 'expected', the model could pass the test
# by echoing the template instead of reading the display.
# NOTE(review): if these prompts mirror prompts used elsewhere, keep the
# wording in sync — confirm with the owner of that code.
TESTS = [
    {
        'file':     'Pump photo 1.jpeg',
        'type':     'fuel_pump_diesel',
        'expected': {'amount_rs': 2390.15, 'quantity_litres': 24.31, 'rate_per_litre': 98.32, 'density': 827.5},
        'prompt':   (
            "This is a petrol/diesel fuel pump dispenser display at an Indian petrol station.\n"
            "The display shows: total amount in rupees (top large number),\n"
            "quantity in litres (second number), rate per litre.\n"
            "Extract carefully. Decimal points matter.\n"
            "Return ONLY raw JSON, no markdown, no explanation, in this shape\n"
            "(the numbers below are placeholders; report the values actually shown):\n"
            '{"amount_rs":0.00,"quantity_litres":0.00,"rate_per_litre":0.00,"density":null,"confidence":"high"}\n'
            "If density is not shown, set it null."
        ),
    },
    {
        'file':     'Pump photo 2.jpeg',
        'type':     'fuel_pump_cng',
        'expected': {'amount_rs': 810.66, 'quantity_kg': 9.186, 'rate_per_kg': 88.25},
        'prompt':   (
            "This is a CNG (Compressed Natural Gas) fuel pump display at an Indian CNG station.\n"
            "The display shows: total amount in rupees, quantity in kg, and rate per kg.\n"
            "Extract carefully. Decimal points matter.\n"
            "Return ONLY raw JSON, no markdown, no explanation, in this shape\n"
            "(the numbers below are placeholders; report the values actually shown):\n"
            '{"amount_rs":0.00,"quantity_kg":0.000,"rate_per_kg":0.00,"confidence":"high"}'
        ),
    },
]

def _parse_model_json(raw):
    """Parse the JSON value contained in the model's text reply.

    Markdown code fences are stripped first. If the remaining text still is
    not valid JSON (e.g. the model added a sentence around it), the span from
    the first '{' to the last '}' is tried before giving up.

    Raises:
        json.JSONDecodeError: when no JSON can be parsed from the reply.
    """
    import re
    cleaned = re.sub(r'```json|```', '', raw).strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        start, end = cleaned.find('{'), cleaned.rfind('}')
        if start == -1 or end <= start:
            raise
        return json.loads(cleaned[start:end + 1])


def call_gemini(b64, prompt):
    """Send one JPEG image plus a text prompt to Gemini and parse the reply.

    Args:
        b64: Base64-encoded JPEG bytes, as a str.
        prompt: Instruction text sent alongside the image.

    Returns:
        A ``(data, raw)`` tuple: the parsed JSON reply and the model's
        unmodified (stripped) text.

    Raises:
        RuntimeError: the API answered with an HTTP error status; the message
            carries the status code and the start of the response body.
        ValueError: the response contained no text, or the text was not
            parseable JSON (``json.JSONDecodeError`` is a ``ValueError``).
        urllib.error.URLError: the request could not be completed.
    """
    payload = json.dumps({
        'contents': [{'parts': [
            {'inline_data': {'mime_type': 'image/jpeg', 'data': b64}},
            {'text': prompt}
        ]}],
        'generationConfig': {'temperature': 0.1, 'maxOutputTokens': 512}
    }).encode()
    # The key travels in a header rather than the query string so it does not
    # end up in URLs recorded by proxies or logs.
    req = urllib.request.Request(
        GEMINI_URL, data=payload,
        headers={'Content-Type': 'application/json', 'x-goog-api-key': GEMINI_KEY},
        method='POST')
    ctx = ssl.create_default_context()
    try:
        with urllib.request.urlopen(req, timeout=30, context=ctx) as r:
            result = json.loads(r.read())
    except urllib.error.HTTPError as err:
        # The response body holds the API's own error description, which is
        # far more useful than the bare "HTTP Error 400: Bad Request".
        detail = err.read().decode('utf-8', errors='replace').strip()
        raise RuntimeError(
            f"Gemini API HTTP {err.code}: {detail[:500] or err.reason}") from err

    # `or` fallbacks cover both a missing key and an empty list/None value,
    # either of which would otherwise raise IndexError/AttributeError.
    candidates = result.get('candidates') or [{}]
    first = candidates[0]
    parts = (first.get('content') or {}).get('parts') or []
    # A reply may be split over several parts; join every text fragment.
    raw = ''.join(p.get('text') or '' for p in parts).strip()
    if not raw:
        reason = (first.get('finishReason') or result.get('promptFeedback')
                  or 'no reason given')
        raise ValueError(f"Gemini returned no text ({reason})")
    return _parse_model_json(raw), raw

def compare(got, expected, *, tolerance=0.01):
    """Check extracted values against the expected ones.

    Args:
        got: Parsed model reply; expected to be a dict of field -> value.
        expected: Dict of field -> expected numeric value.
        tolerance: Largest absolute difference still counted as a match.

    Returns:
        ``(ok, issues)`` where ``ok`` is True only when every expected field
        is present and within ``tolerance``, and ``issues`` is a list of
        indented, human-readable lines describing each problem.
    """
    if not isinstance(got, dict):
        return False, [
            f"  INVALID response: expected a JSON object, got {type(got).__name__}"
        ]

    issues = []
    for k, ev in expected.items():
        gv = got.get(k)
        if gv is None:
            issues.append(f"  MISSING {k} (expected {ev})")
            continue
        try:
            diff = abs(float(gv) - float(ev))
        except (TypeError, ValueError):
            # e.g. "n/a" or a nested object: report it instead of crashing.
            issues.append(f"  MISMATCH {k}: got {gv!r} (not numeric), expected {ev}")
            continue
        # Phrased as "not <=" so a NaN difference (every comparison with NaN
        # is False) counts as a mismatch rather than slipping through.
        if not diff <= tolerance:
            issues.append(f"  MISMATCH {k}: got {gv}, expected {ev}")
    return not issues, issues

def main():
    """Run every case in TESTS against Gemini and print a pass/fail report.

    Returns:
        Process exit status: 0 when every case passed, 1 otherwise
        (including when GEMINI_API_KEY is not set).
    """
    if not GEMINI_KEY:
        print("ERROR: set GEMINI_API_KEY env var")
        return 1

    # Photos are looked up next to this script, not in the current directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    passed = 0
    for t in TESTS:
        path = os.path.join(script_dir, t['file'])
        print(f"\n{'='*60}")
        print(f"TEST: {t['type']} — {t['file']}")
        print(f"Expected: {t['expected']}")
        try:
            # Reading the photo happens inside the try so that a missing or
            # unreadable file fails this case only and the summary still prints.
            with open(path, 'rb') as f:
                b64 = base64.b64encode(f.read()).decode()
            data, _raw = call_gemini(b64, t['prompt'])
            print(f"Got:      {data}")
            ok, issues = compare(data, t['expected'])
            if ok:
                print(f"RESULT:   ✅ PASS (confidence={data.get('confidence','?')})")
                passed += 1
            else:
                print("RESULT:   ⚠️  PARTIAL")
                for issue in issues:
                    print(issue)
        except Exception as e:
            # Top-level boundary: any failure is reported and the run continues.
            print(f"RESULT:   ❌ ERROR — {e}")

    print(f"\n{'='*60}")
    print(f"SUMMARY: {passed}/{len(TESTS)} passed")
    return 0 if passed == len(TESTS) else 1


# Guarded so that importing this file (for instance a test collector picking
# up a test_*.py name) does not fire network requests or exit the interpreter.
if __name__ == '__main__':
    sys.exit(main())
