import time
from sqlalchemy import func
from tqdm import tqdm


from disk_objectstore.database import Obj
from disk_objectstore import Container
from disk_objectstore.utils import PackedObjectReader, StreamDecompresser, get_hash
# Open the container stored at 'test-newrepo-sdb/' and fetch its cached session,
# which the queries below use.
c = Container('test-newrepo-sdb/')
session = c._get_cached_session()

# Time a full listing of all objects; the list is materialised only so that
# its length can be printed.
t = time.time()
tmp = list(c.list_all_objects())
print(len(tmp))
print(time.time() - t)
#6714808
#22.632724285125732

# 2nd time:
#6714808
#25.884494304656982

# Time the retrieval of the distinct pack ids stored in the Obj table.
t = time.time()
distinct_pack_rows = session.query(Obj.pack_id).distinct()
# Each row is a 1-tuple; unwrap it and sort the ids for a stable printout.
pack_ids = sorted({row[0] for row in distinct_pack_rows})
print(time.time() - t)
print(pack_ids)

#1.2832145690917969
#[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]

#1.256399154663086
#[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]


# Time a GROUP BY query that counts how many objects each pack holds.
session = c._get_cached_session()
t = time.time()
count_query = session.query(Obj.pack_id, func.count(Obj.id)).group_by(Obj.pack_id)
# Map pack id -> number of objects in that pack.
pack_obj_count = {pid: n_objects for pid, n_objects in count_query.all()}
print(time.time() - t)
print(pack_obj_count)

#2.5562264919281006
#{0: 205454, 1: 217076, 2: 210356, 3: 205630, 4: 195733, 5: 211017, 6: 199799, 7: 204072, 8: 193464, 9: 205595, 10: 187756, 11: 190113, 12: 184654, 13: 189648, 14: 181734, 15: 181360, 16: 191239, 17: 171055, 18: 174349, 19: 166539, 20: 165742, 21: 169346, 22: 168770, 23: 160722, 24: 162834, 25: 158756, 26: 157223, 27: 151478, 28: 152049, 29: 145876, 30: 141134, 31: 144465, 32: 142484, 33: 138356, 34: 134978, 35: 131797, 36: 131755, 37: 129656, 38: 126942, 39: 33802}

#2.6068077087402344
#{0: 205454, 1: 217076, 2: 210356, 3: 205630, 4: 195733, 5: 211017, 6: 199799, 7: 204072, 8: 193464, 9: 205595, 10: 187756, 11: 190113, 12: 184654, 13: 189648, 14: 181734, 15: 181360, 16: 191239, 17: 171055, 18: 174349, 19: 166539, 20: 165742, 21: 169346, 22: 168770, 23: 160722, 24: 162834, 25: 158756, 26: 157223, 27: 151478, 28: 152049, 29: 145876, 30: 141134, 31: 144465, 32: 142484, 33: 138356, 34: 134978, 35: 131797, 36: 131755, 37: 129656, 38: 126942, 39: 33802}


# Time a plain iteration over the (hashkey, size) rows of pack 0, ordered by
# their offset, counting the rows as they arrive.
t = time.time()
cnt = 0
rows_pack_0 = (
    session.query(Obj.hashkey, Obj.size)
    .filter(Obj.pack_id == 0)
    .order_by(Obj.offset)
)
for hashkey, size in rows_pack_0:
    cnt += 1
print(cnt, time.time() - t)

#205454 1.962942123413086

#205454 1.8577516078948975


# Same iteration as above, wrapped in a tqdm progress bar whose length is the
# per-pack object count stored in `pack_obj_count`.
pack_id = 0
total = pack_obj_count[pack_id]
cnt = 0
t = time.time()
# Filter on `pack_id` rather than a literal 0, so that the rows iterated always
# belong to the same pack that `total` was looked up for.
for hashkey, size in tqdm(
    session.query(Obj).filter(Obj.pack_id == pack_id).order_by(Obj.offset).with_entities(Obj.hashkey, Obj.size),
    total=total):
    cnt += 1
print(total, cnt, time.time() - t)
# Note that the first ~1.5 s is spent fetching the data from the DB.
# Still, it is a good idea to query pack by pack: the amount of data per pack
# is capped, so we never fetch an excessive amount of data at once.

#100%|████████████████████████████████| 205454/205454 [00:02<00:00, 84400.61it/s]
#205454 205454 2.435641050338745

#100%|████████████████████████████████| 205454/205454 [00:02<00:00, 76979.75it/s]
#205454 205454 2.6701900959014893


# Equivalent syntax
# Same benchmark, selecting the two columns directly in `query(...)` instead of
# going through `with_entities`.
pack_id = 0
total = pack_obj_count[pack_id]
cnt = 0
t = time.time()
# Filter on `pack_id` rather than a literal 0, so that the rows iterated always
# belong to the same pack that `total` was looked up for.
for hashkey, size in tqdm(
    session.query(Obj.hashkey, Obj.size).filter(Obj.pack_id == pack_id).order_by(Obj.offset),
    total=total):
    cnt += 1
print(total, cnt, time.time() - t)



## Read data from disk - slow, has to reopen the file at every object
# Fetch the content of every object of one pack through the container API,
# one `get_object_content` call per hashkey.
pack_id = 0
total = pack_obj_count[pack_id]
cnt = 0
t = time.time()
# Filter on `pack_id` rather than a literal 0, so that the objects read always
# belong to the same pack that `total` was looked up for.
for hashkey, size in tqdm(
    session.query(Obj.hashkey, Obj.size).filter(Obj.pack_id == pack_id).order_by(Obj.offset),
    total=total):
    c.get_object_content(hashkey)
    cnt += 1
print(total, cnt, time.time() - t)

#100%|█████████████████████████████████| 205454/205454 [02:11<00:00, 1564.56it/s]
#205454 205454 131.31867337226868


# Way faster - read from an open file, in order
from disk_objectstore.utils import PackedObjectReader, StreamDecompresser
# Read every object of one pack from a single, already-open pack file handle,
# visiting the objects in increasing offset order.
pack_id = 0
total = pack_obj_count[pack_id]
cnt = 0
t = time.time()
pack_path = c._get_pack_path_from_pack_id(str(pack_id))
with open(pack_path, mode='rb') as pack_handle:
    # Filter on `pack_id` rather than a literal 0: the rows must describe the
    # same pack file that was just opened (and that `total` refers to).
    pack_rows = session.query(
        Obj.hashkey, Obj.size, Obj.offset, Obj.length, Obj.compressed
    ).filter(Obj.pack_id == pack_id).order_by(Obj.offset)
    for hashkey, size, offset, length, compressed in tqdm(pack_rows, total=total):
        obj_reader = PackedObjectReader(
            fhandle=pack_handle, offset=offset, length=length
        )
        if compressed:
            # Wrap the raw reader so that `read()` goes through the decompresser.
            obj_reader = StreamDecompresser(obj_reader)
        # Read all content
        obj_reader.read()
        cnt += 1
print(total, cnt, time.time() - t)

#100%|████████████████████████████████| 205454/205454 [00:03<00:00, 63311.74it/s]
#205454 205454 3.246506690979004

#100%|████████████████████████████████| 205454/205454 [00:03<00:00, 62998.29it/s]
#205454 205454 3.2626254558563232


# Let us also compute the hash
def _hash_and_measure(stream, hash_class, chunk_size):
    """Read ``stream`` to EOF in chunks; return ``(hexdigest, bytes_read)``."""
    hasher = hash_class()
    bytes_read = 0
    while True:
        chunk = stream.read(chunk_size)
        if not chunk:
            # Empty returned value: EOF
            break
        hasher.update(chunk)
        bytes_read += len(chunk)
    return hasher.hexdigest(), bytes_read


def _validate_hashkeys_pack(self, pack_id, callback=None):
    """Re-hash every object of one pack and compare with the values in the DB.

    The objects of pack ``pack_id`` are read from the pack file in increasing
    offset order. For each object, the hash and the size of its (decompressed,
    if flagged as compressed) content are recomputed and compared with the
    ``hashkey`` and ``size`` columns of the corresponding ``Obj`` row.

    :param self: the container to validate. It must provide ``hash_type``,
        ``_get_cached_session()`` and ``_get_pack_path_from_pack_id()``.
    :param pack_id: the ID of the pack to validate.
    :param callback: optional progress callback. It is invoked as
        ``callback(action='init', value=<number of objects>)`` once at the
        start, ``callback(action='update', value=1)`` after each validated
        object and ``callback(action='close')`` at the end (also when the
        validation fails).
    :raises AssertionError: if a recomputed hash or size does not match the
        value stored in the DB. It is raised explicitly (rather than via an
        ``assert`` statement) so the check is not stripped when running
        Python with ``-O``.
    """
    _CHUNKSIZE = 524288
    # `get_hash` is the helper this file imports from `disk_objectstore.utils`
    # (the `get_hash_cls` name used previously was never imported here).
    # NOTE(review): assumed to return a class exposing `update()`/`hexdigest()`
    # -- confirm against the installed disk_objectstore version.
    hash_class = get_hash(self.hash_type)

    # Use the session of the container being validated, not a module global.
    session = self._get_cached_session()

    total = session.query(Obj).filter(Obj.pack_id == pack_id).count()
    if callback:
        callback(action='init', value=total)

    try:
        pack_path = self._get_pack_path_from_pack_id(str(pack_id))
        with open(pack_path, mode='rb') as pack_handle:
            # Only the rows of the requested pack, in on-disk order.
            pack_rows = session.query(
                Obj.hashkey, Obj.size, Obj.offset, Obj.length, Obj.compressed
            ).filter(Obj.pack_id == pack_id).order_by(Obj.offset)

            for hashkey, size, offset, length, compressed in pack_rows:
                obj_reader = PackedObjectReader(
                    fhandle=pack_handle, offset=offset, length=length
                )
                if compressed:
                    obj_reader = StreamDecompresser(obj_reader)

                # Read and hash all content
                computed_hash, computed_size = _hash_and_measure(
                    obj_reader, hash_class, _CHUNKSIZE
                )

                if computed_hash != hashkey:
                    raise AssertionError(
                        "Hash mismatch in pack {} at offset {}: expected '{}', computed '{}'".format(
                            pack_id, offset, hashkey, computed_hash
                        )
                    )
                if computed_size != size:
                    raise AssertionError(
                        "Size mismatch for object '{}' in pack {}: expected {}, computed {}".format(
                            hashkey, pack_id, size, computed_size
                        )
                    )

                if callback:
                    callback(action='update', value=1)
    finally:
        # Always release the progress reporter, even if validation failed.
        if callback:
            callback(action='close')

# Registry of the currently open progress bars, keyed by name.
PROGRESS_BARS = {}


def progress_bar_callback(name, action, value=None):
    """Drive the progress bar registered under ``name``.

    :param name: key identifying the progress bar in ``PROGRESS_BARS``.
    :param action: one of ``'init'`` (create a bar with ``total=value``,
        closing any bar already registered under ``name``), ``'update'``
        (advance the bar by ``value``, or by 1 if ``value`` is None) or
        ``'close'`` (close the bar and drop it from the registry).
        Any other action is ignored.
    :param value: the total for ``'init'``, the increment for ``'update'``.
    :raises KeyError: for ``'update'``/``'close'`` if no bar is registered
        under ``name``.
    """
    if action == 'init':
        # Never leave a stale bar open under the same name.
        if name in PROGRESS_BARS:
            PROGRESS_BARS[name].close()
        PROGRESS_BARS[name] = tqdm(total=value)
        return

    if action == 'update':
        increment = 1 if value is None else value
        PROGRESS_BARS[name].update(n=increment)
        return

    if action == 'close':
        finished_bar = PROGRESS_BARS[name]
        finished_bar.close()
        del PROGRESS_BARS[name]

# Time the validation of pack `pack_id` without any progress reporting.
t = time.time()
_validate_hashkeys_pack(c, pack_id, callback=None)
elapsed = time.time() - t
print(elapsed)

# 12.779397249221802
# 12.95686936378479


import functools
# Pre-bind the bar name, so that the validator can invoke the callback with
# only the `action` and `value` keywords.
callback = functools.partial(progress_bar_callback, name='validate')

# Time the same validation, this time reporting progress via the callback.
t = time.time()
_validate_hashkeys_pack(c, pack_id, callback=callback)
elapsed = time.time() - t
print(elapsed)

# 100%|████████████████████████████████| 205454/205454 [00:12<00:00, 16978.17it/s]
# 13.186200380325317

# 100%|████████████████████████████████| 205454/205454 [00:12<00:00, 16092.63it/s]
# 13.858833074569702



# NOTE: there is deliberately no module-level `import tqdm` here. It would
# rebind the name `tqdm` from the class (imported with `from tqdm import tqdm`
# at the top of the file) to the module, breaking every earlier `tqdm(...)`
# call if it is executed again. The class is imported inside the callback.

# The single progress bar currently driven by `tqdm_callback` (None if no bar
# is open).
PROGRESS_BAR = None


def tqdm_callback(action, value=None):
    """Drive a single module-level tqdm progress bar.

    :param action: one of ``'init'`` (create a bar with ``total=value``,
        closing the currently open bar, if any), ``'update'`` (advance the bar
        by ``value``, or by 1 if ``value`` is None) or ``'close'`` (close the
        bar and reset ``PROGRESS_BAR`` to None). Any other action is ignored.
    :param value: the total for ``'init'``, the increment for ``'update'``.
    :raises AttributeError: for ``'update'``/``'close'`` if no bar is open
        (``PROGRESS_BAR`` is None).
    """
    global PROGRESS_BAR

    if action == 'init':
        # Function-scope import under a private alias: the callback does not
        # depend on (nor alter) what the module-level name `tqdm` refers to.
        from tqdm import tqdm as _tqdm_cls

        if PROGRESS_BAR is not None:
            PROGRESS_BAR.close()
        PROGRESS_BAR = _tqdm_cls(total=value)
    elif action == 'update':
        if value is None:
            value = 1
        PROGRESS_BAR.update(n=value)
    elif action == 'close':
        PROGRESS_BAR.close()
        PROGRESS_BAR = None

# Time the validation once more, reporting progress through `tqdm_callback`.
t = time.time()
_validate_hashkeys_pack(c, pack_id, callback=tqdm_callback)
elapsed = time.time() - t
print(elapsed)
