Source code for ase2sprkkr.asr.database.duplicates

from asr.core import command, argument, option, ASRResult
from datetime import datetime


@command(module='asr.database.duplicates',
         resources='1:20m',
         save_results_file=False)
@argument('databaseout', type=str, required=False)
@argument('database', type=str)
@option('-f', '--filterstring',
        help='Comma separated list of filters denoting the priority'
        ' of picking candidates among possible duplicates.',
        type=str)
@option('-c', '--comparison-keys',
        help='Keys that have to be identical for materials to be identical.',
        type=str)
@option('-r', '--rmsd-tol', help='RMSD tolerance.', type=float)
@option('--skip-distance-calc', default=False, is_flag=True,
        help="Skip distance calculation. Only match structures "
        "based on their reduced formula and comparison_keys.")
def main(database: str,
         databaseout: str = None,
         filterstring: str = '<=natoms,<energy',
         comparison_keys: str = '',
         rmsd_tol: float = 0.3,
         skip_distance_calc: bool = False) -> ASRResult:
    """Filter out duplicates of a database.

    Parameters
    ----------
    database : str
        Database to be analyzed for duplicates.
    databaseout : str
        Filename of new database with duplicates removed.
    filterstring : str
        Comma separated string of filters. A simple filter could be
        '<energy', which only picks a material if no other material with
        lower energy exists (in other words: choose the lowest-energy
        materials). '<' means 'smallest'. Other accepted operators are
        {'<=', '>=', '>', '<', '=='}. Additional filters can be added to
        construct more complex filters, e.g., '<energy,<=natoms' means
        that a material is only picked if no other material with lower
        energy AND an equal or smaller number of atoms exists.
    comparison_keys : str
        Comma separated string of keys that should be identical
        between rows to be compared, e.g. 'magstate,natoms'.
    rmsd_tol : float
        Tolerance on RMSD between materials for them to be considered
        to be duplicates.
    skip_distance_calc : bool
        If True, match structures using only the reduced formula and
        comparison_keys, and skip calculating distances between
        structures. The output RMSDs will be 0 for matching structures.

    Returns
    -------
    dict
        Keys:
            - ``duplicate_groups``: List of all duplicate groups. Each
              group is a dict with an ``include`` and an ``exclude`` list
              of uids.
            - ``duplicate_uids``: List of all uids that were excluded as
              duplicates.

    """
    from ase.db import connect
    from asr.core import read_json
    from asr.database.rmsd import main as rmsd
    from asr.utils import timed_print
    assert database != databaseout, \
        'You cannot read and write from the same database.'

    ops_and_keys = parse_filter_string(filterstring)

    if not rmsd.done:
        rmsd(database, comparison_keys=comparison_keys,
             skip_distance_calc=skip_distance_calc)
    rmsd_results = read_json('results-asr.database.rmsd.json')
    rmsd_by_id = rmsd_results['rmsd_by_id']
    uid_key = rmsd_results['uid_key']
    duplicate_groups = []
    db = connect(database)
    exclude_uids = set()
    already_checked_uids = set()
    nrmsd = len(rmsd_by_id)
    rows = {}

    for row in db.select(include_data=False):
        rows[row.get(uid_key)] = row

    print('Filtering materials...')
    for irmsd, (uid, rmsd_dict) in enumerate(rmsd_by_id.items()):
        if uid in already_checked_uids:
            continue
        now = datetime.now()
        timed_print(f'{now:%H:%M:%S}: {irmsd}/{nrmsd}', wait=30)
        duplicate_uids = find_duplicate_group(uid, rmsd_by_id, rmsd_tol)

        # Pick the preferred row according to filterstring
        include = filter_uids(rows, duplicate_uids,
                              ops_and_keys, uid_key)
        # Book keeping
        already_checked_uids.update(duplicate_uids)

        exclude = duplicate_uids - include
        if exclude:
            exclude_uids.update(exclude)
            duplicate_groups.append({'exclude': list(exclude),
                                     'include': list(include)})

    if databaseout is not None:
        nmat = len(rows)
        with connect(databaseout) as filtereddb:
            for uid, row in rows.items():
                now = datetime.now()
                timed_print(f'{now:%H:%M:%S}: {row.id}/{nmat}', wait=30)

                if uid in exclude_uids:
                    continue
                filtereddb.write(atoms=row.toatoms(),
                                 data=row.data,
                                 **row.key_value_pairs)

        filtereddb.metadata = db.metadata

    filterkeys = [key for _, key in ops_and_keys]
    for ig, group in enumerate(duplicate_groups):
        include = group['include']
        exclude = group['exclude']
        max_rmsd = 0
        for uid in include + exclude:
            max_rmsd = max([max_rmsd,
                            max(value for value in rmsd_by_id[uid].values()
                                if value is not None and value < rmsd_tol)])
        print(f'Group #{ig} max_rmsd={max_rmsd}')
        print('    Excluding:')
        for uid in exclude:
            row = rows[uid]
            print(f'        {uid} '
                  + ' '.join(f'{key}=' + str(row.get(key)) for key in filterkeys))
        print('    Including:')
        for uid in include:
            row = rows[uid]
            print(f'        {uid} '
                  + ' '.join(f'{key}=' + str(row.get(key)) for key in filterkeys))

    print(f'Excluded {len(exclude_uids)} materials.')
    return {'duplicate_groups': duplicate_groups,
            'duplicate_uids': list(exclude_uids)}
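

# A hedged usage sketch (the file names are hypothetical, and this
# assumes the ASR command can be called directly as a function; it also
# requires the asr.database.rmsd step to be runnable on the database):
#
#     >>> results = main('materials.db', 'filtered.db',
#     ...                filterstring='<energy,<=natoms',
#     ...                comparison_keys='magstate')  # doctest: +SKIP
#     >>> results['duplicate_groups']  # doctest: +SKIP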


def compare(value1, value2, comparator):
    """Return value1 {comparator} value2."""
    if comparator == '<=':
        return value1 <= value2
    elif comparator == '>=':
        return value1 >= value2
    elif comparator == '<':
        return value1 < value2
    elif comparator == '>':
        return value1 > value2
    elif comparator == '==':
        return value1 == value2
    # Fail loudly instead of silently returning None.
    raise ValueError(f'Unknown comparator: {comparator}')
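

# Illustrative doctest-style examples of compare (the behavior follows
# directly from the definition above):
#
#     >>> compare(1, 2, '<')
#     True
#     >>> compare(3, 3, '<=')
#     True
#     >>> compare(3, 3, '=='), compare(3, 3, '>')
#     (True, False)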


def filter_uids(all_rows, duplicate_ids, ops_and_keys, uid_key):
    """Get the preferred rows according to the filter string.

    Parameters
    ----------
    all_rows : dict
        Dictionary with key=uid and value=row.
    duplicate_ids : iterable
        Set of possible duplicate materials.
    ops_and_keys : List[Tuple[str, str]]
        List of filters where the first element of each tuple is the
        comparison operator and the second is the key to compare,
        e.g. [('<', 'energy')]. Accepted operators are
        {'<=', '>=', '>', '<', '=='}. Additional filters can be added to
        construct more complex filters, e.g. [('<', 'energy'),
        ('<=', 'natoms')] means that a material is only picked if no
        other material with lower energy AND an equal or smaller number
        of atoms exists.
    uid_key : str
        The UID key of the database connection that the duplicate_ids
        parameter is referring to.

    Returns
    -------
    filtered_uids : set
        Set of filtered uids.

    """
    rows = [all_rows[uid] for uid in duplicate_ids]
    filtered_uids = set()
    for candidaterow in rows:
        # A candidate is dropped if any other row beats it on all filters.
        better_candidates = {
            row for row in rows
            if all(compare(row[key], candidaterow[key], op)
                   for op, key in ops_and_keys)}
        if not better_candidates:
            filtered_uids.add(candidaterow.get(uid_key))
    return filtered_uids
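

# A minimal sketch of filter_uids using hashable dict stand-ins for
# ase.db rows (the function only needs row[key] and row.get(key)
# access); with the filter [('<', 'energy')], only the lowest-energy
# duplicate survives:
#
#     >>> class Row(dict):
#     ...     __hash__ = object.__hash__  # rows must be hashable
#     >>> rows = {'a': Row(uid='a', energy=-1.0),
#     ...         'b': Row(uid='b', energy=-2.0)}
#     >>> filter_uids(rows, {'a', 'b'}, [('<', 'energy')], 'uid')
#     {'b'}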


def parse_filter_string(filterstring):
    """Parse a comma separated filter string.

    Parameters
    ----------
    filterstring : str
        Comma separated filter string, e.g. '<energy,<=natoms'.

    Returns
    -------
    ops_and_keys : List[Tuple[str, str]]
        For the above example this would return
        [('<', 'energy'), ('<=', 'natoms')].

    """
    filters = filterstring.split(',')
    sorts = ['<=', '>=', '==', '>', '<']
    ops_and_keys = []
    for filt in filters:
        for op in sorts:
            if filt.startswith(op):
                break
        else:
            raise ValueError(
                f'Unknown sorting operator in filterstring={filt}.')
        key = filt[len(op):]
        ops_and_keys.append((op, key))
    return ops_and_keys
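

# Doctest-style examples of parse_filter_string:
#
#     >>> parse_filter_string('<energy,<=natoms')
#     [('<', 'energy'), ('<=', 'natoms')]
#     >>> parse_filter_string('~energy')
#     Traceback (most recent call last):
#         ...
#     ValueError: Unknown sorting operator in filterstring=~energy.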


def find_duplicate_group(uid, rmsd_by_id, rmsd_tol,
                         already_considered_uids=None):
    """Recursively collect all uids connected to ``uid`` through
    RMSD values below ``rmsd_tol``, i.e. build a duplicate group by
    transitivity."""
    if already_considered_uids is None:
        already_considered_uids = {uid}
    else:
        already_considered_uids.add(uid)
    duplicate_uids = set(key for key, value in rmsd_by_id[uid].items()
                         if value is not None and value < rmsd_tol)
    new_uids = duplicate_uids - already_considered_uids
    if new_uids:
        for new_uid in new_uids:
            find_duplicate_group(
                new_uid, rmsd_by_id, rmsd_tol,
                already_considered_uids=already_considered_uids)
    return already_considered_uids
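

# Sketch of the transitive grouping: A~B and B~C puts A, B and C in one
# group even though A and C are not within rmsd_tol of each other:
#
#     >>> rmsd_by_id = {'A': {'B': 0.1, 'C': 0.9},
#     ...               'B': {'A': 0.1, 'C': 0.2},
#     ...               'C': {'A': 0.9, 'B': 0.2}}
#     >>> sorted(find_duplicate_group('A', rmsd_by_id, rmsd_tol=0.3))
#     ['A', 'B', 'C']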


if __name__ == '__main__':
    main.cli()