
from browser import document, window, aio
import sys
from io import StringIO
import traceback
from urllib.parse import unquote, parse_qs
import json
import base64
import datetime
from pt_urlopen_patch import open as patched_open  # optional, but nice to fail fast if missing


# Backend endpoint that ask_ai_about_code() POSTs the editor contents to.
API_URL = "https://njl9djgkmg.execute-api.us-east-1.amazonaws.com/prod"

# DOM elements, looked up once by id and shared by the handlers below.
code = document["code"]  # editor; its .value holds the source text
gutter = document["gutter"]  # line-number column next to the editor
output_pre = document["output_pre"]  # pane that shows program output
stdin_field = document["stdin_field"]  # text field read by input_async()
stdin_button = document["stdin_submit"]  # click submits stdin_field
tab_spinner = document["tabWidth"]  # tab-width control





################################
#   MAX CODE TEXT AREA HEIGHT  #
################################

# First element matching ".editor-container"; adjust_editor_height() sizes it.
editor_container = document.select(".editor-container")[0]
# Upper bound, in pixels, for the height given to the editor container.
MAX_EDITOR_HEIGHT_PX = 800

def adjust_editor_height():
    """Fit the editor container to its content, capped at the maximum height.

    Both panes are first set to ``auto`` height so their ``scrollHeight``
    can be read; the container then receives the larger of the two values,
    limited to ``MAX_EDITOR_HEIGHT_PX``, and the panes are stretched back
    to fill it.
    """
    panes = (code, gutter)

    for pane in panes:
        pane.style.height = "auto"

    content_height = max(code.scrollHeight, gutter.scrollHeight)
    capped_height = min(content_height, MAX_EDITOR_HEIGHT_PX)
    editor_container.style.height = f"{capped_height}px"

    for pane in panes:
        pane.style.height = "100%"

########################
#   TAB WIDTH CHANGE   #
########################

def on_tabwidth_change(ev):
    """Read the tab-width spinner, clamp it to 2..8 and store it.

    The clamped value is written to the module-level ``TAB_WIDTH`` and back
    into the spinner so the control always shows the value in effect.

    Args:
        ev: The DOM event that triggered the handler (unused).
    """
    global TAB_WIDTH
    try:
        v = int(tab_spinner.value)
    except (TypeError, ValueError):
        # Empty or non-numeric input falls back to the minimum width. Only
        # conversion failures are caught so unrelated errors stay visible.
        v = 2
    TAB_WIDTH = max(2, min(8, v))
    tab_spinner.value = TAB_WIDTH

# Re-validate the tab width on both "input" and "change" events.
tab_spinner.bind("input", on_tabwidth_change)
tab_spinner.bind("change", on_tabwidth_change)

########################
#   LIVE STD OUT       #
########################

class LiveStdout:
    """Minimal stdout stand-in that streams text into the output pane."""

    def write(self, s):
        """Append *s* to the output pane and scroll the pane to the bottom.

        Falsy input (for example an empty string) is ignored.
        """
        if s:
            output_pre.text += s
            output_pre.scrollTop = output_pre.scrollHeight

    def flush(self):
        """Do nothing; written text is shown immediately."""
        return None

# Shared instance used by input_async() and installed as sys.stdout by
# run_code_async().
LIVE_STDOUT = LiveStdout()

###########################
#   ASYNC INPUT FUNCTION  #
###########################

async def input_async(prompt: str = "") -> str:
    """Await one line of input from the stdin field.

    The prompt, when non-empty, is written to the live output. The coroutine
    then waits until the submit button is clicked, takes the field's current
    value, clears the field, echoes the value plus a newline to the live
    output, and returns it.

    Args:
        prompt: Text written to the output before waiting.

    Returns:
        The text that was in the stdin field when the button was clicked.
    """
    if prompt:
        LIVE_STDOUT.write(prompt)

    pending = aio.Future()

    def on_click(event):
        # One-shot handler: it unbinds itself so each awaited input is
        # resolved by exactly one click.
        entered = stdin_field.value
        stdin_field.value = ""
        stdin_button.unbind("click", on_click)
        pending.set_result(entered)

    stdin_button.bind("click", on_click)

    answer = await pending
    LIVE_STDOUT.write(answer + "\n")
    return answer

########################
#   ASK AI ABOUT CODE  #
########################

def ask_ai_about_code(ev):
    """POST the editor contents to ``API_URL`` and show the reply.

    The code is sent as JSON ``{"code": ...}``. A successful response is
    expected to be a JSON object with a ``reply`` field, which is shown in
    the output pane. HTTP errors, unparseable bodies, unexpected JSON
    shapes and fetch failures are each reported in the output pane instead.

    Args:
        ev: The DOM event that triggered the handler (unused).
    """
    source = code.value
    if not source.strip():
        output_pre.text = "No code to analyze."
        return

    output_pre.text = "Asking AI for feedback...\n"

    request = {
        "method": "POST",
        "headers": {"Content-Type": "application/json"},
        "body": window.JSON.stringify({"code": source}),
    }

    def show_reply(raw):
        # Parse the body of a successful response and display its reply.
        try:
            data = json.loads(raw)
        except Exception as e:
            output_pre.text = (
                "Got response from backend, but could not parse JSON.\n\n"
                f"Raw body:\n{raw}\n\n"
                f"Parsing error: {e}"
            )
            return

        if not isinstance(data, dict):
            output_pre.text = f"Unexpected JSON structure:\n{data!r}"
            return

        output_pre.text = data.get("reply", "(no 'reply' field in JSON)")

    def on_response(resp):
        # Capture the status before reading the body so the body callback
        # can tell success from failure.
        status, ok = resp.status, resp.ok

        def on_body(raw):
            if ok:
                show_reply(raw)
                return
            output_pre.text = (
                f"Backend error {status}.\n\n"
                f"Raw response body:\n{raw}"
            )

        return resp.text().then(on_body)

    def on_failure(err):
        output_pre.text = f"Network or fetch error:\n{err}"

    window.fetch(API_URL, request).then(on_response).catch(on_failure)

# Clicking the "ask_ai" element sends the current code for feedback.
document["ask_ai"].bind("click", ask_ai_about_code)

########################
#   URL PARAM PARSER   #
########################

def parse_params():
    """Collect URL parameters from the query string and the hash fragment.

    Query-string values are decoded with ``parse_qs`` and only the first
    value of each key is kept. Fragment values are taken verbatim from
    ``key=value`` pairs separated by ``&``; pairs without ``=`` are skipped.
    Keys are lower-cased, and a fragment entry overrides a query entry with
    the same key.

    Returns:
        A dict mapping lower-cased parameter names to string values.
    """
    collected = {}

    query = window.location.search
    if len(query) > 1 and query.startswith("?"):
        for name, values in parse_qs(query[1:]).items():
            if values:
                collected[name.lower()] = values[0]

    fragment = window.location.hash
    if len(fragment) > 1 and fragment.startswith("#"):
        for piece in fragment[1:].split("&"):
            name, separator, value = piece.partition("=")
            if separator:
                collected[name.lower()] = value

    return collected

def is_paste_block_enabled():
    """Return True when the ``paste`` URL parameter asks to block pasting.

    Blocking is enabled only by one of the "off" words; an "on" word, a
    missing parameter or any other value leaves pasting allowed.
    """
    allow_words = ("on", "true", "1", "yes", "allow")
    block_words = ("off", "false", "0", "no", "block")

    setting = parse_params().get("paste", "").lower()
    if setting in allow_words:
        return False
    return setting in block_words

# Evaluated once at load time; later URL changes do not update it.
PASTE_BLOCK_ENABLED = is_paste_block_enabled()


################################
#   LIGHT OR DARK THEME        #
################################

def apply_theme(theme: str):
    """Apply the named theme to the page and label the toggle button.

    ``"dark"`` (case-insensitive, surrounding whitespace ignored) adds the
    ``dark-mode`` class to ``<body>``; any other value, including empty or
    ``None``, removes it. The button text names the theme a click would
    switch to.

    Args:
        theme: Requested theme name.
    """
    class_list = document.select("body")[0].classList
    wanted = (theme or "").lower().strip()

    if wanted != "dark":
        # Anything other than "dark" means the default light theme.
        class_list.remove("dark-mode")
        theme_button.text = "Dark mode"
        return

    class_list.add("dark-mode")
    theme_button.text = "Light mode"     # clicking will go back to light



# Button whose label apply_theme() and toggle_theme() update.
theme_button = document["toggle_theme"]
# Read the initial theme from the query string only (?theme=dark or
# ?theme=light). The key is matched case-sensitively here, and the hash
# fragment is not consulted.
query = window.location.search[1:]
params = parse_qs(query) if query else {}
theme_param = params.get("theme", [""])[0]

# Apply the requested theme once at load; a missing parameter means light.
apply_theme(theme_param)


def toggle_theme(ev):
    """Switch between light and dark theme and relabel the toggle button.

    Args:
        ev: The DOM event that triggered the handler (unused).
    """
    class_list = document.select("body")[0].classList

    if "dark-mode" not in class_list:
        class_list.add("dark-mode")
        theme_button.text = "Light mode"
        return

    class_list.remove("dark-mode")
    theme_button.text = "Dark mode"

# Clicking the theme button flips the current theme.
theme_button.bind("click", toggle_theme)

########################
#   LINE NUMBERS       #
########################

def update_gutter(ev=None):
    """Rewrite the gutter with one number per line of the editor text.

    The line count is the number of newline characters plus one, so an
    empty editor still shows line 1.

    Args:
        ev: Optional DOM event; ignored, so the function can also be
            called directly.
    """
    newline_total = code.value.count("\n")
    numbers = [str(n) for n in range(1, newline_total + 2)]
    gutter.text = "\n".join(numbers)

def sync_scroll(ev):
    """Keep the gutter's vertical scroll position equal to the editor's."""
    gutter.scrollTop = code.scrollTop

# Renumber the gutter on every edit and follow the editor when it scrolls.
code.bind("input", update_gutter)
code.bind("scroll", sync_scroll)

###############################
#   HIDE DIRTY SHARE URL      #
###############################

# Set to True by invalidate_share_url() once the editor content changes.
share_dirty = False

def invalidate_share_url(ev=None):
    """Mark the displayed share URL as stale and hide it.

    Sets ``share_dirty`` and, when the share elements exist, hides the
    container and empties the URL field. Missing elements are ignored.

    Args:
        ev: Optional DOM event; ignored.
    """
    global share_dirty
    share_dirty = True

    try:
        container = document["share_url_container"]
        container.attrs["hidden"] = True
        url_field = document["share_url_debug"]
        url_field.value = ""
    except KeyError:
        # The share widgets are optional on the page.
        pass

# Any edit makes a previously generated share URL outdated.
code.bind("input", invalidate_share_url)

###############################
#   LOAD CODE FROM URL HASH   #
###############################

def decode_b64_urlsafe(s: str) -> str:
    """Decode percent-encoded, unpadded URL-safe base64 into UTF-8 text.

    Missing ``=`` padding is restored before decoding. This is a
    best-effort decoder: if any step fails, the input is returned instead
    (percent-decoded when that step had already succeeded).

    Args:
        s: Encoded text, typically taken from a URL parameter.

    Returns:
        The decoded text, or the undecoded input on failure.
    """
    text = s
    try:
        text = unquote(s)
        padded = text + "=" * (-len(text) % 4)
        raw = base64.urlsafe_b64decode(padded.encode("ascii"))
        return raw.decode("utf-8")
    except Exception:
        return text

def to_urlsafe_b64(s: str) -> str:
    """Encode *s* (as UTF-8) to URL-safe base64 without ``=`` padding."""
    return base64.urlsafe_b64encode(s.encode("utf-8")).decode("ascii").rstrip("=")

def _update_param_string(param_str: str, key: str, new_value: str) -> str:
    if not param_str:
        return f"{key}={new_value}"

    parts = param_str.split("&")
    updated = []
    found = False
    for part in parts:
        if not part:
            continue
        if "=" in part:
            k, _ = part.split("=", 1)
        else:
            k = part
        if k == key:
            updated.append(f"{key}={new_value}")
            found = True
        else:
            updated.append(part)
    if not found:
        updated.append(f"{key}={new_value}")
    return "&".join(updated)

def _has_key_in_param_string(param_str: str, key: str) -> bool:
    if not param_str:
        return False
    for part in param_str.split("&"):
        if not part:
            continue
        k = part.split("=", 1)[0]
        if k == key:
            return True
    return False

def strip_header_and_extract_tag(src: str):
    """Remove a leading share header from *src* and return its tag.

    A header is recognised when the first line starts with ``# tag:`` and
    the second with ``# shared:`` (leading whitespace allowed). Lines are
    split with ``splitlines`` and re-joined with ``"\\n"``, so line endings
    are normalised and a trailing newline is not preserved.

    Args:
        src: Editor source text.

    Returns:
        A ``(body, tag)`` tuple: the source without the header, and the
        stripped tag text, or ``None`` when no header was found.
    """
    tag_marker = "# tag:"
    rows = src.splitlines()

    if len(rows) < 2:
        return "\n".join(rows), None

    first = rows[0].lstrip()
    second = rows[1].lstrip()
    if not (first.startswith(tag_marker) and second.startswith("# shared:")):
        return "\n".join(rows), None

    return "\n".join(rows[2:]), first[len(tag_marker):].strip()

def build_share_url_for_code(code_b64: str) -> str:
    """Return the current page URL with its ``code`` parameter set.

    The parameter is written where it already lives: in the hash when the
    hash has a ``code`` entry, otherwise in the query string when that has
    one. When neither has it, a non-empty hash receives it and an empty
    hash sends it to the query string. All other parameters are kept.

    Args:
        code_b64: Already-encoded code, inserted verbatim.

    Returns:
        The full URL (origin, path, and the updated query/hash).
    """
    loc = window.location

    query = loc.search or ""
    if query.startswith("?"):
        query = query[1:]

    fragment = loc.hash or ""
    if fragment.startswith("#"):
        fragment = fragment[1:]

    in_query = _has_key_in_param_string(query, "code")
    in_fragment = _has_key_in_param_string(fragment, "code")

    # The hash wins when it already carries the code, or when nothing
    # carries it yet and the hash is in use for other parameters.
    use_fragment = in_fragment or (not in_query and bool(fragment))
    if use_fragment:
        fragment = _update_param_string(fragment, "code", code_b64)
    else:
        query = _update_param_string(query, "code", code_b64)

    pieces = [loc.origin + loc.pathname]
    if query:
        pieces.append("?" + query)
    if fragment:
        pieces.append("#" + fragment)
    return "".join(pieces)

def load_code_from_url_hash():
    """Load the editor from the ``code`` URL parameter, if there is one.

    The value is decoded with ``window.decodeCodeFromUrl``; when that call
    raises, it is decoded as legacy URL-safe base64 instead. The gutter
    and editor height are refreshed afterwards. Without a ``code``
    parameter the editor is left untouched.
    """
    encoded = parse_params().get("code")
    if encoded is None:
        return

    try:
        # Current (compressed) format, decoded by the page's JS helper.
        decoded = window.decodeCodeFromUrl(encoded)
    except Exception:
        # Legacy format: plain URL-safe base64 of UTF-8 text.
        decoded = decode_b64_urlsafe(encoded)

    code.value = decoded
    update_gutter()
    adjust_editor_height()


##################
#   RUN BUTTON   #
##################

def run_code_sync(src):
    """Execute *src* with stdout captured and show the result in the output pane.

    On success the pane shows ``STDOUT:`` followed by the captured text, or
    ``(no output)`` when nothing but whitespace was printed. When the code
    raises an ``Exception`` the pane shows the captured stdout, a trimmed
    traceback, and a numbered listing of *src* up to the reported line,
    which is marked with ``>>``. ``sys.stdout`` is always restored.

    Args:
        src: Source text to execute. Line numbers in the listing are
            positions within this exact string.
    """
    old_stdout = sys.stdout
    buffer = StringIO()
    sys.stdout = buffer

    try:
        # The same dict is used as globals and locals, so the code runs
        # like a module body.
        env = {
            "__name__": "__main__",
            "__builtins__": __builtins__,
            "open": patched_open,   # <-- inject patched open here
        }
        exec(src, env, env)

        stdout_text = buffer.getvalue()
        if not stdout_text.strip():
            stdout_text = "(no output)"

        result = "STDOUT:\n" + stdout_text
        output_pre.text = result

    except Exception as e:
        # Keep whatever was printed before the failure.
        stdout_text = buffer.getvalue()
        if not stdout_text.strip():
            stdout_text = "(no output)"

        tb_full = ""
        try:
            tb_full = traceback.format_exc()
        except Exception:
            tb_full = ""

        # Fallback message, used when no traceback text is available.
        tb = f"{type(e).__name__}: {e}"
        user_tb_lines = []

        if isinstance(tb_full, str) and tb_full.strip():
            tb_lines = tb_full.split("\n")

            # Keep the traceback from the first 'File "<string>"' line on;
            # presumably that is the first frame of the exec'd code —
            # confirm against the runtime's traceback format.
            started = False
            for line in tb_lines:
                if 'File "<string>"' in line:
                    started = True
                if started:
                    user_tb_lines.append(line)

            if user_tb_lines:
                cleaned = []
                skip_next_source_line = False
                for line in user_tb_lines:
                    # A frame reported at "line -1" has no usable line
                    # number: keep only the text before "line" and drop
                    # the source echo that may follow it.
                    if 'File "<string>"' in line and "line -1" in line:
                        prefix = line.split("line")[0].rstrip().rstrip(",")
                        cleaned.append(prefix)
                        skip_next_source_line = True
                        continue

                    if skip_next_source_line:
                        skip_next_source_line = False
                        stripped = line.lstrip()
                        # Skip only an indented line that is neither a new
                        # frame header nor the final error message.
                        if (line.startswith(" ")
                            and not stripped.startswith("File ")
                            and "Error:" not in line
                            and "Exception:" not in line):
                            continue

                    cleaned.append(line)

                tb = "\n".join(cleaned)
            else:
                # No '<string>' frame was found; show the full traceback.
                tb = tb_full

        line_num = None
        lines = src.split("\n")

        # Use the first positive integer that follows "line " in the kept
        # traceback lines as the line to highlight.
        for line in user_tb_lines:
            if "line " in line:
                try:
                    part = line.split("line ")[-1].strip().split()[0]
                    if part.isdigit() and int(part) > 0:
                        line_num = int(part)
                        break
                except Exception:
                    pass

        # List the source up to the highlighted line, or all of it when no
        # valid line number was found.
        if line_num is not None and 1 <= line_num <= len(lines):
            max_line = line_num
        else:
            max_line = len(lines)

        formatted = []
        for i in range(1, max_line + 1):
            line_text = lines[i - 1]
            if line_num is not None and i == line_num:
                formatted.append(f">> {i:>3} | {line_text}")
            else:
                formatted.append(f"   {i:>3} | {line_text}")

        result = (
            "STDOUT:\n" + stdout_text +
            "\n\nERROR:\n" + tb +
            "\n\nCODE:\n" + "\n".join(formatted)
        )
        output_pre.text = result

    finally:
        sys.stdout = old_stdout


def run_code_async(src):
    """Execute *src* with stdout streamed live into the output pane.

    The code runs in a fresh namespace that exposes ``input_async``,
    ``aio`` and the patched ``open``. An ``Exception`` raised while
    ``exec`` is running has its traceback printed to the current stdout.
    ``sys.stdout`` is restored as soon as ``exec`` returns.

    Args:
        src: Source text to execute.
    """
    # NOTE(review): stdout is restored when exec() returns. If aio.run()
    # only schedules the coroutine and returns immediately, text printed
    # after the first await would no longer reach LIVE_STDOUT, and errors
    # raised there would bypass the except below — confirm.
    namespace = {
        "__name__": "__main__",
        "__builtins__": __builtins__,
        "input_async": input_async,
        "aio": aio,
        "open": patched_open,   # <-- inject patched open here too
    }

    output_pre.text = ""
    previous_stdout = sys.stdout
    sys.stdout = LIVE_STDOUT

    try:
        exec(src, namespace, namespace)
    except Exception:
        traceback.print_exc(file=sys.stdout)
    finally:
        sys.stdout = previous_stdout




def run_code(ev):
    """Run the editor contents, choosing the async or sync runner.

    Async mode is selected when the source text contains
    ``await input_async`` or ``aio.run(``; otherwise the code runs
    synchronously with captured output.

    Args:
        ev: The DOM event that triggered the run (unused).
    """
    output_pre.text = ""  # clear on new run
    src = code.value

    # Detect async mode on the *raw* student code
    is_async_mode = ("await input_async" in src) or ("aio.run(" in src)

    # The source is executed exactly as typed. Both runners already inject
    # the patched ``open`` into the exec namespace, so no hidden import line
    # is prepended: an extra first line would shift every traceback line
    # number by one relative to the gutter and would show up in the CODE
    # listing of error reports.
    if is_async_mode:
        run_code_async(src)
    else:
        run_code_sync(src)


def clear_output(ev):
    """Empty the output pane."""
    output_pre.text = ""

########################
#   SHARE CURRENT URL  #
########################

def share_current_url(ev):
    """Build a share URL for the current code, display it and copy it.

    The code is prefixed with a two-line header, ``# tag: <tag>`` and
    ``# shared: <timestamp>``; a header already present in the editor is
    replaced. The tag comes from the ``tag`` URL parameter when it is
    non-empty, otherwise from the existing header, with one pair of
    surrounding double quotes removed. The text is encoded with
    ``window.encodeCodeForUrl`` and merged into the current URL.

    The URL is shown in the share box when those elements exist, and is
    copied to the clipboard on a best-effort basis: clipboard failures are
    only logged to the browser console.

    Args:
        ev: The DOM event that triggered the handler (unused).

    Returns:
        The share URL.
    """
    params = parse_params()
    url_tag = params.get("tag", "")

    body, header_tag = strip_header_and_extract_tag(code.value)

    tag_value = url_tag if url_tag else (header_tag or "")

    if len(tag_value) >= 2 and tag_value[0] == tag_value[-1] == '"':
        tag_value = tag_value[1:-1]

    tag_comment = f"# tag: {tag_value}" if tag_value else "# tag:"
    ts = datetime.datetime.now().isoformat(timespec="seconds")
    shared_comment = f"# shared: {ts}"

    # An empty body leaves just the header followed by a newline, so no
    # separate branch is needed for it.
    shared_src = f"{tag_comment}\n{shared_comment}\n{body}"

    code_b64 = window.encodeCodeForUrl(shared_src)
    url = build_share_url_for_code(code_b64)

    try:
        container = document["share_url_container"]
        if "hidden" in container.attrs:
            del container.attrs["hidden"]
        document["share_url_debug"].value = url
    except KeyError:
        # The share widgets are optional on the page.
        pass

    def log_clipboard_failure(err):
        try:
            window.console.log(f"Clipboard write failed: {err}")
        except Exception:
            pass

    try:
        # writeText() returns a Promise, so a refusal (permission denied,
        # document not focused) arrives as a rejection rather than as an
        # exception; handle both paths.
        window.navigator.clipboard.writeText(url).catch(log_clipboard_failure)
    except Exception as e:
        log_clipboard_failure(e)

    return url

def refresh_with_share(ev):
    """Generate the share URL and navigate the page to it.

    Args:
        ev: The DOM event, passed through to ``share_current_url``.
    """
    window.location.href = share_current_url(ev)

# Wire up the toolbar buttons; "run1" and "run2" both run the code.
document["run1"].bind("click", run_code)
document["run2"].bind("click", run_code)
document["clear_output"].bind("click", clear_output)
document["share"].bind("click", share_current_url)
document["refresh_share"].bind("click", refresh_with_share)

########################
#   WRAP OUTPUT ONLY   #
########################

# Current wrap state of the output pane; flipped by toggle_output_wrap().
output_wrap_enabled = False

def toggle_output_wrap(ev):
    """Flip line wrapping for the output pane and relabel the toggle.

    With wrapping on, long lines break anywhere and horizontal scrolling
    is hidden; with wrapping off, the pane keeps lines intact and scrolls
    horizontally when needed.

    Args:
        ev: The DOM event that triggered the handler (unused).
    """
    global output_wrap_enabled
    output_wrap_enabled = not output_wrap_enabled

    style = output_pre.style
    if not output_wrap_enabled:
        style.whiteSpace = "pre"
        style.overflowX = "auto"
        style.removeProperty("overflow-wrap")
        style.removeProperty("word-break")
        label = "Wrap off"
    else:
        style.whiteSpace = "pre-wrap"
        style.overflowX = "hidden"
        style.setProperty("overflow-wrap", "anywhere")
        style.setProperty("word-break", "break-word")
        label = "Wrap on"

    document["wrap_toggle"].text = label

# Clicking the wrap button toggles wrapping of the output pane.
document["wrap_toggle"].bind("click", toggle_output_wrap)

##################################
#   BLOCK PASTE / DROP / KEYS    #
##################################

def block(ev):
    """Cancel the event's default action and stop it from propagating."""
    ev.preventDefault()
    ev.stopPropagation()

def on_paste(ev):
    """Block a paste event and tell the user that pasting is disabled."""
    block(ev)
    window.alert("Pasting is disabled in this editor.")

def on_drop(ev):
    """Block a drop event so dropped content is not inserted."""
    block(ev)

def on_dragover(ev):
    """Block a dragover event."""
    block(ev)

def on_dragenter(ev):
    """Block a dragenter event."""
    block(ev)

def on_context_menu(ev):
    """Block the context menu event."""
    block(ev)

# Initial tab width; on_tabwidth_change() overwrites it with a value in 2..8.
TAB_WIDTH = 2

def on_keydown(ev):
    """Handle editor shortcuts: run on Ctrl/Cmd+Enter, block paste keys.

    Ctrl+Enter or Cmd+Enter runs the code. When paste blocking is enabled,
    Ctrl/Cmd+V and Shift+Insert are cancelled and an alert is shown.

    Args:
        ev: The keydown event.
    """
    key = ev.key
    modifier = ev.ctrlKey or ev.metaKey
    shift = ev.shiftKey

    if key == "Enter" and modifier:
        ev.preventDefault()
        run_code(ev)
        return

    if not PASTE_BLOCK_ENABLED:
        return

    paste_chord = key.lower() == "v" and modifier
    shift_insert = key == "Insert" and shift
    if paste_chord or shift_insert:
        block(ev)
        window.alert("Paste via keyboard is disabled.")

# Paste, drag-and-drop and context-menu handlers are attached only when
# paste blocking was requested through the URL.
if PASTE_BLOCK_ENABLED:
    code.bind("paste", on_paste)
    code.bind("drop", on_drop)
    code.bind("dragover", on_dragover)
    code.bind("dragenter", on_dragenter)
    code.bind("contextmenu", on_context_menu)

# The keydown handler is always attached: it also provides the run shortcut.
code.bind("keydown", on_keydown)

########################
#   INITIAL SETUP      #
########################

# Load any code carried in the URL, then number the gutter. The second call
# covers the case where the URL has no code and the loader returns early.
load_code_from_url_hash()
update_gutter()