#!/usr/bin/python3
#
# Example of how to render a json-seq progress stream with tqdm
#

import json
import sys

import tqdm


class TqdmProgressRenderer:
    """Render a json-seq progress stream as (up to) two tqdm progress bars.

    Records are read line by line from ``inf``.  Every record is expected to
    be a JSON object; the keys consumed here are:

    * ``progress``: ``total``/``done`` of the main bar plus an optional
      nested ``progress`` object (``total``/``done``) for the sub bar
    * ``context``: ``pipeline.name`` and ``pipeline.stage.name`` are used as
      bar descriptions
    * ``message``: free text printed above the bars
    """

    BAR_FMT = "{desc} ({n_fmt}/{total_fmt}): {percentage:3.0f}%|{bar}|{elapsed}"

    def __init__(self, inf, outf):
        """Set up a renderer.

        inf  -- text stream to read json-seq records from (``readline()``)
        outf -- text stream for warnings emitted while no bar exists yet
        """
        self._pbar = None
        self._sub_pbar = None
        self._inf = inf
        self._outf = outf
        self._last_done = 0
        self._last_sub_done = 0

    def _read_json_seq_rec(self):
        """Return the next JSON object from the input, or None at EOF.

        Lines that are not valid JSON and records that are not JSON objects
        are reported via warn() and skipped.  Empty records (blank lines or
        a lone record separator) are skipped silently.
        """
        # *sigh* we really should be using a proper json-seq reader here
        while True:
            line = self._inf.readline()
            if not line:
                return None
            # drop the RS (0x1e) record separator and surrounding whitespace
            text = line.strip("\x1e \t\r\n")
            if not text:
                continue
            try:
                payload = json.loads(text)
            except json.JSONDecodeError:
                self.warn(f"WARN: invalid json: {line}")
                continue
            if not isinstance(payload, dict):
                # render() relies on dict access, anything else is useless
                self.warn(f"WARN: ignoring non-object json record: {text}")
                continue
            return payload

    def warn(self, warn):
        """Print a warning without corrupting an active progress bar."""
        # Compare against None explicitly: the truth value of a tqdm bar
        # depends on its total, so a bar created with total=0 is falsy.
        if self._pbar is not None:
            self._pbar.write(warn)
        else:
            print(warn, file=self._outf)

    def _init_pbar(self, pbar_name, total, pos):
        """Create the bar stored in attribute *pbar_name* unless it exists."""
        pbar = getattr(self, pbar_name, None)
        if pbar is not None:
            return
        pbar = tqdm.tqdm(total=total, position=pos, bar_format=self.BAR_FMT)
        setattr(self, pbar_name, pbar)

    def _reset_sub_pbar(self):
        """Close and forget the sub bar so the next record starts a new one."""
        if self._sub_pbar is not None:
            self._sub_pbar.close()
        self._sub_pbar = None
        self._last_sub_done = 0

    def _close_pbars(self):
        """Close all bars (inner one first) and reset the progress counters."""
        for name in ("_sub_pbar", "_pbar"):
            pbar = getattr(self, name)
            if pbar is not None:
                pbar.close()
                setattr(self, name, None)
        self._last_done = 0
        self._last_sub_done = 0

    def _render_record(self, js):
        """Apply a single progress record (a dict) to the bars."""
        # "or {}" also covers keys that are present but null
        main_progress = js.get("progress") or {}
        ctx = js.get("context") or {}
        pipeline = ctx.get("pipeline") or {}

        # main progress
        self._init_pbar("_pbar", main_progress.get("total", 0), pos=0)
        pipeline_name = pipeline.get("name")
        if pipeline_name:
            self._pbar.set_description(f"Pipeline {pipeline_name}")
        done = main_progress.get("done") or 0
        if self._last_done < done:
            # advance by the real delta, "done" may jump by more than one
            self._pbar.update(done - self._last_done)
            self._last_done = done
            # main progress moved on, the sub-progress starts over
            self._reset_sub_pbar()

        # sub progress
        sub_progress = main_progress.get("progress")
        if sub_progress:
            self._init_pbar("_sub_pbar", sub_progress.get("total"), pos=1)
            stage_name = (pipeline.get("stage") or {}).get("name")
            if stage_name:
                self._sub_pbar.set_description(f"Stage {stage_name}")
            sub_done = sub_progress.get("done") or 0
            if self._last_sub_done < sub_done:
                self._sub_pbar.update(sub_done - self._last_sub_done)
                self._last_sub_done = sub_done

        # (naively) handle messages (could decorate with origin); records
        # without a message must not push an empty line above the bars
        message = (js.get("message") or "").strip()
        if message:
            self._pbar.write(message)

    def render(self):
        """Consume the input stream until EOF and keep the bars up to date.

        The bars are closed when the stream ends or an error propagates.
        """
        try:
            while True:
                js = self._read_json_seq_rec()
                if js is None:
                    return
                self._render_record(js)
        finally:
            self._close_pbars()


if __name__ == "__main__":
    # Script entry point: render the json-seq records arriving on stdin;
    # stdout is handed over as the renderer's output stream.
    TqdmProgressRenderer(sys.stdin, sys.stdout).render()
