import json
import textwrap
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
################################################################################
# 1) LOCAL KNOWLEDGE GRAPH
################################################################################
local_kg = {
"diagnoses": {
"Head and Neck Squamous Cell Carcinoma": {
"inclusion_criteria": [
"Age 18-70 years",
"ECOG performance status 0 or 1",
"no prior PD-1, PD-L1 or EGFR inhibition",
"No immunotherapy"
],
"study_groups": {
"TPEx group": {
"regimen": {
"regimenId": "TPEx",
"includesAgents": [
"Docetaxel 75 mg/m^2 (day 1)",
"Cisplatin 75 mg/m^2 (day 1)",
"Cetuximab 400 mg/m^2 (cycle 1, day 1), then 250 mg/m^2 weekly"
],
"cycleLength": "21 days",
"numberOfCycles": "Up to 4 cycles",
"requiresGCSF": True,
"maintenanceTherapy": True
}
},
"EXTREME group": {
"regimen": {
"regimenId": "EXTREME",
"includesAgents": [
"Fluorouracil 4000 mg/m^2 (days 1-4)",
"Cisplatin 100 mg/m^2 (day 1)",
"Cetuximab 400 mg/m^2 (cycle 1, day 1), then 250 mg/m^2 weekly"
],
"cycleLength": "21 days",
"numberOfCycles": "Up to 6 cycles",
"requiresGCSF": False,
"maintenanceTherapy": True
}
}
}
}
}
}
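
# Illustrative helper (an assumption, not part of the original pipeline): flattens the
# nested knowledge graph into (diagnosis, study group, regimen ID) tuples, which makes
# it easy to eyeball what run_local_query will later look up.
def list_kg_regimens(kg: dict = local_kg) -> list:
    rows = []
    for diagnosis, info in kg["diagnoses"].items():
        for group_name, group in info["study_groups"].items():
            rows.append((diagnosis, group_name, group["regimen"]["regimenId"]))
    return rows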
################################################################################
# 2) MODEL LOADING (SEQUENCE-TO-SEQUENCE)
################################################################################
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
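
# Optional smoke test (a hedged sketch, not invoked by the pipeline): confirms the
# tokenizer/model pair round-trips a trivial prompt, which is useful when swapping
# google/flan-t5-base for a different seq2seq checkpoint.
def smoke_test_model() -> str:
    test_inputs = tokenizer.encode("Summarize: cisplatin is given on day 1.", return_tensors="pt")
    test_outputs = model.generate(test_inputs, max_length=32, num_beams=2)
    return tokenizer.decode(test_outputs[0], skip_special_tokens=True)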
################################################################################
# 3) PROFILE PROCESSING & LOCAL QUERY
################################################################################
def process_patient_profile(profile: str) -> dict:
    """Parse a free-text patient profile into structured fields and flag trial eligibility."""
profile_lines = profile.split("\n")
patient_data = {}
for line in profile_lines:
line = line.strip()
if line.startswith("Age:"):
age_value = line.split(":", 1)[1].strip()
patient_data["age"] = int(age_value) if age_value.isdigit() else None
elif line.startswith("Gender:"):
patient_data["gender"] = line.split(":", 1)[1].strip()
elif line.startswith("Diagnosis:"):
patient_data["diagnosis"] = line.split(":", 1)[1].strip()
elif line.startswith("ECOGPerformanceStatus:"):
ecog_value = line.split(":", 1)[1].strip()
patient_data["ecog"] = int(ecog_value) if ecog_value.isdigit() else None
elif line.startswith("TreatmentHistory:"):
treatment_history = line.split(":", 1)[1].strip().lower()
patient_data["treatment_history"] = "no prior" in treatment_history or "never received" in treatment_history
has_required_fields = (
patient_data.get("ecog") is not None and
patient_data.get("age") is not None and
patient_data.get("diagnosis") is not None
)
    if has_required_fields:
        # Mirror the knowledge graph's inclusion criteria: age 18-70, ECOG 0-1,
        # no prior relevant therapy, and a diagnosis covered by the local graph.
        patient_data["eligible"] = (
            (patient_data["ecog"] < 2) and
            (18 <= patient_data["age"] <= 70) and
            patient_data.get("treatment_history", False) and
            (patient_data["diagnosis"] in local_kg["diagnoses"])
        )
else:
patient_data["eligible"] = False
return patient_data
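
# Quick parsing check (illustrative only; the profile text below is hypothetical). A
# profile without a Diagnosis line still parses, but eligibility defaults to False
# because the required fields are incomplete.
def example_incomplete_profile() -> dict:
    minimal_profile = "Age: 65\nECOGPerformanceStatus: 1"
    return process_patient_profile(minimal_profile)  # {"age": 65, "ecog": 1, "eligible": False}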
def run_local_query(profile_data: dict) -> list:
    """Look up candidate regimens in the local knowledge graph for an eligible patient."""
results = []
diagnosis = profile_data.get("diagnosis")
if not profile_data.get("eligible"):
return results
diag_info = local_kg["diagnoses"].get(diagnosis)
criteria = diag_info["inclusion_criteria"]
    # Only recommend regimens when the graph encodes the full set of inclusion criteria
    required_criteria = [
"Age 18-70 years",
"ECOG performance status 0 or 1",
"no prior PD-1, PD-L1 or EGFR inhibition",
"No immunotherapy"
]
if all(req in criteria for req in required_criteria):
for group, data in diag_info["study_groups"].items():
results.append(data["regimen"])
return results
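
# Illustrative check (not invoked by the pipeline; the profile dicts are hand-built
# assumptions). An ineligible profile yields no regimens, while an eligible HNSCC
# profile returns the regimens of both study groups.
def example_local_query() -> list:
    ineligible = {"eligible": False, "diagnosis": "Head and Neck Squamous Cell Carcinoma"}
    assert run_local_query(ineligible) == []
    eligible = {"eligible": True, "diagnosis": "Head and Neck Squamous Cell Carcinoma"}
    return run_local_query(eligible)  # [TPEx regimen dict, EXTREME regimen dict]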
################################################################################
# 4) LLM-BASED RESPONSE
################################################################################
def generate_llm_response(query_results: list) -> str:
    """Verbalize the structured regimen results with the seq2seq model."""
if not query_results:
return "No results found or the patient is ineligible for treatment recommendations."
result_str = "\n".join(
f"- Regimen ID: {r['regimenId']}, Agents: {r['includesAgents']}, Cycle Length: {r['cycleLength']}, "
f"Cycles: {r['numberOfCycles']}, GCSF Required: {'Yes' if r['requiresGCSF'] else 'No'}, "
f"Maintenance Therapy: {'Yes' if r['maintenanceTherapy'] else 'No'}"
for r in query_results
)
prompt = f"""
You are a clinical data expert. Based on the following structured query results, write a natural language response
summarizing the recommended treatment regimens for a patient:
Results:
{result_str}
Response:
"""
inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
outputs = model.generate(inputs, max_length=200, num_beams=5, early_stopping=True)
return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
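
# Alternative decoding sketch (hedged; the sampling values are assumptions, not settings
# from the source). Nucleus sampling can be swapped in for beam search when the
# beam-search summary comes out too terse or repetitive.
def generate_sampled(prompt: str) -> str:
    inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
    outputs = model.generate(inputs, max_new_tokens=200, do_sample=True, top_p=0.9, temperature=0.7)
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()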
################################################################################
# 5) MAIN PIPELINE
################################################################################
def main():
    """End-to-end demo: parse a profile, query the local knowledge graph, and verbalize the results."""
patient_profile = """
Age: 58
Gender: Female
Diagnosis: Head and Neck Squamous Cell Carcinoma
ECOGPerformanceStatus: 1
TreatmentHistory: Oral Tongue Tumor (Stage II): Partial glossectomy; Adjuvant radiation (54 Gy) ended 14 months ago; No chemotherapy; Never received Cetuximab or IO therapy
"""
profile_data = process_patient_profile(patient_profile)
formatted_profile = textwrap.fill(
json.dumps(profile_data, indent=2), width=100)
print("Processed Patient Profile:\n")
print(formatted_profile)
results = run_local_query(profile_data)
# print("\nStructured Query Results:")
# for r in results:
# print(r)
response = generate_llm_response(results)
formatted_response = textwrap.fill(response, width=100)
print("\nHere is the treatment recommendation:\n")
print(formatted_response)
if __name__ == "__main__":
main()
################################################################################
# 6) SAMPLE OUTPUT
################################################################################
# Processed Patient Profile:
# { "age": 58, "gender": "Female", "diagnosis": "Head and Neck Squamous Cell Carcinoma",
#   "ecog": 1, "treatment_history": true, "eligible": true }
#
# Here is the treatment recommendation:
# ['Docetaxel 75 mg/m2 (day 1)', 'Cisplatin 75 mg/m2 (day 1)', 'Cetuximab 400 mg/m2 (day 1), then 250
#  mg/m2 weekly']