from transformers import pipeline
from langchain.graphs import Neo4jGraph
from langchain.chains import GraphCypherQAChain
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from neo4j import GraphDatabase
import re
import textwrap
import json
import os

# Neo4j connection settings, read from the environment so credentials are not
# committed to source control. NOTE(review): a password was previously
# hard-coded on this line in plain text; treat it as compromised and rotate it.
URI = os.environ.get("NEO4J_URI", "neo4j+s://08398cf2.databases.neo4j.io")
USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
PASSWORD = os.environ.get("NEO4J_PASSWORD")  # None when the variable is unset
import textwrap
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

################################################################################
# 1) LOCAL KNOWLEDGE GRAPH
################################################################################

# In-memory knowledge graph: maps a diagnosis name to its inclusion criteria
# and to the treatment regimen of each study group.
local_kg = {
    "diagnoses": {
        # Diagnosis keys are compared verbatim (case-sensitive) with the
        # "Diagnosis:" value parsed from a patient profile.
        "Head and Neck Squamous Cell Carcinoma": {
            # These strings must match run_local_query's required_criteria list
            # exactly, otherwise no regimens are returned.
            # NOTE(review): they are descriptive only; the numeric thresholds
            # actually applied to a patient are hard-coded in
            # process_patient_profile.
            "inclusion_criteria": [
                "Age 18-70 years",
                "ECOG performance status 0 or 1",
                "no prior PD-1, PD-L1 or EGFR inhibition",
                "No immunotherapy"
            ],
            # One entry per study group; run_local_query returns each group's
            # "regimen" dict.
            "study_groups": {
                "TPEx group": {
                    "regimen": {
                        "regimenId": "TPEx",
                        # Free-text dosing strings, shown to the LLM as-is.
                        "includesAgents": [
                            "Docetaxel 75 mg/m^2 (day 1)",
                            "Cisplatin 75 mg/m^2 (day 1)",
                            "Cetuximab 400 mg/m^2 (cycle 1, day 1), then 250 mg/m^2 weekly"
                        ],
                        "cycleLength": "21 days",
                        "numberOfCycles": "Up to 4 cycles",
                        "requiresGCSF": True,
                        "maintenanceTherapy": True
                    }
                },
                "EXTREME group": {
                    "regimen": {
                        "regimenId": "EXTREME",
                        # Free-text dosing strings, shown to the LLM as-is.
                        "includesAgents": [
                            "Fluorouracil 4000 mg/m^2 (days 1-4)",
                            "Cisplatin 100 mg/m^2 (day 1)",
                            "Cetuximab 400 mg/m^2 (cycle 1, day 1), then 250 mg/m^2 weekly"
                        ],
                        "cycleLength": "21 days",
                        "numberOfCycles": "Up to 6 cycles",
                        "requiresGCSF": False,
                        "maintenanceTherapy": True
                    }
                }
            }
        }
    }
}

################################################################################
# 2) MODEL LOADING (SEQUENCE-TO-SEQUENCE)
################################################################################

# Checkpoint used for summarisation. The tokenizer and model are loaded once,
# at import time, and kept as module-level globals for generate_llm_response.
model_name = "google/flan-t5-base"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

################################################################################
# 3) PROFILE PROCESSING & LOCAL QUERY
################################################################################

def process_patient_profile(profile: str) -> dict:
    """Parse a ``Key: value`` patient profile and decide basic eligibility.

    Each line is stripped of surrounding whitespace and matched by prefix.
    Lines with unrecognised prefixes are ignored. Recognised prefixes and the
    keys they populate:

    * ``Age:`` -> ``age`` (int, or None unless a plain non-negative integer)
    * ``Gender:`` -> ``gender`` (str)
    * ``Diagnosis:`` -> ``diagnosis`` (str)
    * ``ECOGPerformanceStatus:`` -> ``ecog`` (int, or None as for age)
    * ``TreatmentHistory:`` -> ``treatment_history`` (bool; True when the
      text contains "no prior" or "never received", case-insensitively)

    Args:
        profile: Multi-line profile text.

    Returns:
        The parsed fields plus an always-present ``eligible`` bool. It is True
        only when age, ECOG, diagnosis and treatment history are all present
        and 18 <= age <= 70, ECOG < 2, ``treatment_history`` is True, and the
        diagnosis is a key of ``local_kg["diagnoses"]``.
    """
    patient_data = {}

    for raw_line in profile.split("\n"):
        line = raw_line.strip()
        if line.startswith("Age:"):
            age_value = line.split(":", 1)[1].strip()
            # isdecimal() rather than isdigit(): isdigit() also accepts
            # characters such as "²" for which int() raises ValueError.
            patient_data["age"] = int(age_value) if age_value.isdecimal() else None
        elif line.startswith("Gender:"):
            patient_data["gender"] = line.split(":", 1)[1].strip()
        elif line.startswith("Diagnosis:"):
            patient_data["diagnosis"] = line.split(":", 1)[1].strip()
        elif line.startswith("ECOGPerformanceStatus:"):
            ecog_value = line.split(":", 1)[1].strip()
            patient_data["ecog"] = int(ecog_value) if ecog_value.isdecimal() else None
        elif line.startswith("TreatmentHistory:"):
            treatment_history = line.split(":", 1)[1].strip().lower()
            # NOTE(review): crude keyword heuristic. It does not check which
            # therapies the negation refers to, so e.g. "no prior radiation;
            # received pembrolizumab" still yields True.
            patient_data["treatment_history"] = (
                "no prior" in treatment_history or "never received" in treatment_history
            )

    has_required_fields = (
        patient_data.get("ecog") is not None
        and patient_data.get("age") is not None
        and patient_data.get("diagnosis") is not None
    )

    # The inclusion criteria exclude prior PD-1/PD-L1/EGFR inhibition and
    # immunotherapy, so the treatment-history flag must gate eligibility too.
    # A missing history is treated like any other missing required field.
    patient_data["eligible"] = (
        has_required_fields
        and patient_data.get("treatment_history") is True
        and patient_data["ecog"] < 2
        and 18 <= patient_data["age"] <= 70
        and patient_data["diagnosis"] in local_kg["diagnoses"]
    )

    return patient_data


def run_local_query(profile_data: dict) -> list:
    """Return the regimens recommended for an eligible patient.

    Args:
        profile_data: Dict as produced by ``process_patient_profile``. Only
            the ``eligible`` and ``diagnosis`` keys are read.

    Returns:
        A new list holding one regimen dict per study group of the patient's
        diagnosis in ``local_kg``. The list is empty in any of these cases:

        * the patient is not eligible;
        * the diagnosis is not in the knowledge graph;
        * the diagnosis entry lacks any of the required inclusion criteria.
    """
    if not profile_data.get("eligible"):
        return []

    diag_info = local_kg["diagnoses"].get(profile_data.get("diagnosis"))
    if diag_info is None:
        # A caller may set "eligible" itself for a diagnosis the knowledge
        # graph does not contain; previously this raised TypeError.
        return []

    # Sanity check that the diagnosis entry declares the criteria this
    # pipeline's eligibility logic is built around.
    required_criteria = (
        "Age 18-70 years",
        "ECOG performance status 0 or 1",
        "no prior PD-1, PD-L1 or EGFR inhibition",
        "No immunotherapy",
    )
    criteria = diag_info.get("inclusion_criteria", [])
    if not all(req in criteria for req in required_criteria):
        return []

    return [group["regimen"] for group in diag_info.get("study_groups", {}).values()]

################################################################################
# 4) LLM-BASED RESPONSE
################################################################################

def _format_regimen(regimen: dict) -> str:
    """Render one regimen dict as a single bullet line for the LLM prompt."""
    # Join the agent list explicitly. Interpolating the list directly yields
    # its Python repr (brackets and quotes), which can end up copied verbatim
    # into the generated text.
    agents = "; ".join(regimen["includesAgents"])
    return (
        f"- Regimen ID: {regimen['regimenId']}, Agents: {agents}, "
        f"Cycle Length: {regimen['cycleLength']}, "
        f"Cycles: {regimen['numberOfCycles']}, "
        f"GCSF Required: {'Yes' if regimen['requiresGCSF'] else 'No'}, "
        f"Maintenance Therapy: {'Yes' if regimen['maintenanceTherapy'] else 'No'}"
    )


def generate_llm_response(query_results: list) -> str:
    """Summarise regimen dicts in natural language using the seq2seq model.

    Args:
        query_results: Regimen dicts as returned by ``run_local_query``. Each
            must have the keys ``regimenId``, ``includesAgents`` (list of
            str), ``cycleLength``, ``numberOfCycles``, ``requiresGCSF`` and
            ``maintenanceTherapy``.

    Returns:
        The decoded model output. A fixed message is returned when
        ``query_results`` is empty. If the model produces an empty string,
        the structured regimen summary that was fed to it is returned
        instead.
    """
    if not query_results:
        return "No results found or the patient is ineligible for treatment recommendations."

    result_str = "\n".join(_format_regimen(r) for r in query_results)

    # Dedent the template *before* inserting the results, so the prompt does
    # not carry source-code indentation and every result line starts at the
    # same column.
    prompt = textwrap.dedent(
        """\
        You are a clinical data expert. Based on the following structured query results, write a natural language response
        summarizing the recommended treatment regimens for a patient:

        Results:
        {results}

        Response:
        """
    ).format(results=result_str)

    # Calling the tokenizer (rather than .encode) also returns the attention
    # mask, which is passed on to generate().
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
    outputs = model.generate(**inputs, max_length=200, num_beams=5, early_stopping=True)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Never hand back an empty recommendation; fall back to the raw summary.
    return response or result_str

################################################################################
# 5) MAIN PIPELINE
################################################################################

def main():
    """Run the demo pipeline on a hard-coded sample patient and print results.

    The steps are:

    1. Parse the sample profile.
    2. Print the parsed profile as indented JSON.
    3. Look up matching regimens.
    4. Print the text returned by ``generate_llm_response``, wrapped to 100
       columns.
    """
    patient_profile = """
    Age: 58
    Gender: Female
    Diagnosis: Head and Neck Squamous Cell Carcinoma
    ECOGPerformanceStatus: 1
    TreatmentHistory: Oral Tongue Tumor (Stage II): Partial glossectomy; Adjuvant radiation (54 Gy) ended 14 months ago; No chemotherapy; Never received Cetuximab or IO therapy
    """

    profile_data = process_patient_profile(patient_profile)
    print("Processed Patient Profile:\n")
    # Print the JSON as-is: textwrap.fill() would collapse its newlines and
    # indentation into one run-on paragraph.
    print(json.dumps(profile_data, indent=2))

    results = run_local_query(profile_data)

    response = generate_llm_response(results)
    print("\nHere is the treatment recommendation:\n")
    print(textwrap.fill(response, width=100))


if __name__ == "__main__":
    main()
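# Sample console output captured from an earlier run of this script, kept as a
# comment so the file remains valid Python. It does not exactly match the
# current data (e.g. "mg/m2" below vs "mg/m^2" in local_kg).
#
# Processed Patient Profile:
#
# {   "age": 58,   "gender": "Female",   "diagnosis": "Head and Neck Squamous Cell Carcinoma",
# "ecog": 1,   "treatment_history": true,   "eligible": true }
#
# Here is the treatment recommendation:
#
# ['Docetaxel 75 mg/m2 (day 1)', 'Cisplatin 75 mg/m2 (day 1)', 'Cetuximab 400 mg/m2 (day 1), then 250
# mg/m2 weekly']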