databased.dbshell

  1import argparse
  2from datetime import datetime
  3
  4import argshell
  5from griddle import griddy
  6from noiftimer import time_it
  7from pathier import Pathier, Pathish
  8
  9from databased import Databased, __version__, dbparsers
 10from databased.create_shell import create_shell
 11
 12
 13class DBShell(argshell.ArgShell):
 14    _dbpath: Pathier = None  # type: ignore
 15    connection_timeout: float = 10
 16    detect_types: bool = True
 17    enforce_foreign_keys: bool = True
 18    commit_on_close: bool = True
 19    log_dir: Pathier = Pathier.cwd()
 20    intro = f"Starting dbshell v{__version__} (enter help or ? for arg info)...\n"
 21    prompt = f"based>"
 22
 23    @property
 24    def dbpath(self) -> Pathier:
 25        return self._dbpath
 26
 27    @dbpath.setter
 28    def dbpath(self, path: Pathish):
 29        self._dbpath = Pathier(path)
 30        self.prompt = f"{self._dbpath.name}>"
 31
 32    def _DB(self) -> Databased:
 33        return Databased(
 34            self.dbpath,
 35            self.connection_timeout,
 36            self.detect_types,
 37            self.enforce_foreign_keys,
 38            self.commit_on_close,
 39            self.log_dir,
 40        )
 41
 42    @time_it()
 43    def default(self, line: str):
 44        line = line.strip("_")
 45        with self._DB() as db:
 46            self.display(db.query(line))
 47
 48    def display(self, data: list[dict]):
 49        """Print row data to terminal in a grid."""
 50        try:
 51            print(griddy(data, "keys"))
 52        except Exception as e:
 53            print("Could not fit data into grid :(")
 54            print(e)
 55
 56    # Seat
 57
 58    def _show_tables(self, args: argshell.Namespace):
 59        with self._DB() as db:
 60            if args.tables:
 61                tables = [table for table in args.tables if table in db.tables]
 62            else:
 63                tables = db.tables
 64            if tables:
 65                print("Getting database tables...")
 66                info = [
 67                    {
 68                        "Table Name": table,
 69                        "Columns": ", ".join(db.get_columns(table)),
 70                        "Number of Rows": db.count(table) if args.rowcount else "n/a",
 71                    }
 72                    for table in tables
 73                ]
 74                self.display(info)
 75
 76    def _show_views(self, args: argshell.Namespace):
 77        with self._DB() as db:
 78            if args.tables:
 79                views = [view for view in args.tables if view in db.views]
 80            else:
 81                views = db.views
 82            if views:
 83                print("Getting database views...")
 84                info = [
 85                    {
 86                        "View Name": view,
 87                        "Columns": ", ".join(db.get_columns(view)),
 88                        "Number of Rows": db.count(view) if args.rowcount else "n/a",
 89                    }
 90                    for view in views
 91                ]
 92                self.display(info)
 93
 94    @argshell.with_parser(dbparsers.get_add_column_parser)
 95    def do_add_column(self, args: argshell.Namespace):
 96        """Add a new column to the specified tables."""
 97        with self._DB() as db:
 98            db.add_column(args.table, args.column_def)
 99
100    @argshell.with_parser(dbparsers.get_add_table_parser)
101    def do_add_table(self, args: argshell.Namespace):
102        """Add a new table to the database."""
103        with self._DB() as db:
104            db.create_table(args.table, *args.columns)
105
106    @argshell.with_parser(dbparsers.get_backup_parser)
107    @time_it()
108    def do_backup(self, args: argshell.Namespace):
109        """Create a backup of the current db file."""
110        print(f"Creating a back up for {self.dbpath}...")
111        backup_path = self.dbpath.backup(args.timestamp)
112        print("Creating backup is complete.")
113        print(f"Backup path: {backup_path}")
114
115    @argshell.with_parser(dbparsers.get_count_parser)
116    @time_it()
117    def do_count(self, args: argshell.Namespace):
118        """Count the number of matching records."""
119        with self._DB() as db:
120            count = db.count(args.table, args.column, args.where, args.distinct)
121            self.display(
122                [
123                    {
124                        "Table": args.table,
125                        "Column": args.column,
126                        "Distinct": args.distinct,
127                        "Where": args.where,
128                        "Count": count,
129                    }
130                ]
131            )
132
133    def do_customize(self, name: str):
134        """Generate a template file in the current working directory for creating a custom DBShell class.
135        Expects one argument: the name of the custom dbshell.
136        This will be used to name the generated file as well as several components in the file content.
137        """
138        try:
139            create_shell(name)
140        except Exception as e:
141            print(f"{type(e).__name__}: {e}")
142
143    def do_dbpath(self, _: str):
144        """Print the .db file in use."""
145        print(self.dbpath)
146
147    @argshell.with_parser(dbparsers.get_delete_parser)
148    @time_it()
149    def do_delete(self, args: argshell.Namespace):
150        """Delete rows from the database.
151
152        Syntax:
153        >>> delete {table} {where}
154        >>> based>delete users "username LIKE '%chungus%"
155
156        ^will delete all rows in the 'users' table whose username contains 'chungus'^"""
157        print("Deleting records...")
158        with self._DB() as db:
159            num_rows = db.delete(args.table, args.where)
160            print(f"Deleted {num_rows} rows from {args.table} table.")
161
162    def do_describe(self, tables: str):
163        """Describe each given table or view. If no list is given, all tables and views will be described."""
164        with self._DB() as db:
165            table_list = tables.split() or (db.tables + db.views)
166            for table in table_list:
167                print(f"<{table}>")
168                print(db.to_grid(db.describe(table)))
169
170    @argshell.with_parser(dbparsers.get_drop_column_parser)
171    def do_drop_column(self, args: argshell.Namespace):
172        """Drop the specified column from the specified table."""
173        with self._DB() as db:
174            db.drop_column(args.table, args.column)
175
176    def do_drop_table(self, table: str):
177        """Drop the specified table."""
178        with self._DB() as db:
179            db.drop_table(table)
180
181    @argshell.with_parser(dbparsers.get_dump_parser)
182    @time_it()
183    def do_dump(self, args: argshell.Namespace):
184        """Create `.sql` dump files for the current database."""
185        date = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
186        if not args.data_only:
187            print("Dumping schema...")
188            with self._DB() as db:
189                db.dump_schema(
190                    Pathier.cwd() / f"{db.name}_schema_{date}.sql", args.tables
191                )
192        if not args.schema_only:
193            print("Dumping data...")
194            with self._DB() as db:
195                db.dump_data(Pathier.cwd() / f"{db.name}_data_{date}.sql", args.tables)
196
197    def do_flush_log(self, _: str):
198        """Clear the log file for this database."""
199        log_path = self.dbpath.with_name(self.dbpath.name.replace(".", "") + ".log")
200        if not log_path.exists():
201            print(f"No log file at path {log_path}")
202        else:
203            print(f"Flushing log...")
204            log_path.write_text("")
205
206    def do_help(self, args: str):
207        """Display help messages."""
208        super().do_help(args)
209        if args == "":
210            print("Unrecognized commands will be executed as queries.")
211            print(
212                "Use the `query` command explicitly if you don't want to capitalize your key words."
213            )
214            print("All transactions initiated by commands are committed immediately.")
215        print()
216
217    def do_new_db(self, dbname: str):
218        """Create a new, empty database with the given name."""
219        dbpath = Pathier(dbname)
220        self.dbpath = dbpath
221        self.prompt = f"{self.dbpath.name}>"
222
223    def do_properties(self, _: str):
224        """See current database property settings."""
225        for property_ in [
226            "connection_timeout",
227            "detect_types",
228            "enforce_foreign_keys",
229            "commit_on_close",
230            "log_dir",
231        ]:
232            print(f"{property_}: {getattr(self, property_)}")
233
234    @time_it()
235    def do_query(self, query: str):
236        """Execute a query against the current database."""
237        print(f"Executing {query}")
238        with self._DB() as db:
239            results = db.query(query)
240        self.display(results)
241        print(f"{db.cursor.rowcount} affected rows")
242
243    @argshell.with_parser(dbparsers.get_rename_column_parser)
244    def do_rename_column(self, args: argshell.Namespace):
245        """Rename a column."""
246        with self._DB() as db:
247            db.rename_column(args.table, args.column, args.new_name)
248
249    @argshell.with_parser(dbparsers.get_rename_table_parser)
250    def do_rename_table(self, args: argshell.Namespace):
251        """Rename a table."""
252        with self._DB() as db:
253            db.rename_table(args.table, args.new_name)
254
255    def do_restore(self, file: str):
256        """Replace the current db file with the given db backup file."""
257        backup = Pathier(file.strip('"'))
258        if not backup.exists():
259            print(f"{backup} does not exist.")
260        else:
261            print(f"Restoring from {file}...")
262            self.dbpath.write_bytes(backup.read_bytes())
263            print("Restore complete.")
264
265    @argshell.with_parser(dbparsers.get_scan_dbs_parser)
266    def do_scan(self, args: argshell.Namespace):
267        """Scan the current working directory for database files."""
268        dbs = self._scan(args.extensions, args.recursive)
269        for db in dbs:
270            print(db.separate(Pathier.cwd().stem))
271
272    @argshell.with_parser(dbparsers.get_schema_parser)
273    @time_it()
274    def do_schema(self, args: argshell.Namespace):
275        """Print out the names of the database tables and views, their columns, and, optionally, the number of rows."""
276        self._show_tables(args)
277        self._show_views(args)
278
279    @time_it()
280    def do_script(self, path: str):
281        """Execute the given SQL script."""
282        with self._DB() as db:
283            self.display(db.execute_script(path))
284
285    @argshell.with_parser(dbparsers.get_select_parser, [dbparsers.select_post_parser])
286    @time_it()
287    def do_select(self, args: argshell.Namespace):
288        """Execute a SELECT query with the given args."""
289        print(f"Querying {args.table}... ")
290        with self._DB() as db:
291            rows = db.select(
292                table=args.table,
293                columns=args.columns,
294                joins=args.joins,
295                where=args.where,
296                group_by=args.group_by,
297                having=args.Having,
298                order_by=args.order_by,
299                limit=args.limit,
300                exclude_columns=args.exclude_columns,
301            )
302            print(f"Found {len(rows)} rows:")
303            self.display(rows)
304            print(f"{len(rows)} rows from {args.table}")
305
306    def do_set_connection_timeout(self, seconds: str):
307        """Set database connection timeout to this number of seconds."""
308        self.connection_timeout = float(seconds)
309
310    def do_set_detect_types(self, should_detect: str):
311        """Pass a `1` to turn on and a `0` to turn off."""
312        self.detect_types = bool(int(should_detect))
313
314    def do_set_enforce_foreign_keys(self, should_enforce: str):
315        """Pass a `1` to turn on and a `0` to turn off."""
316        self.enforce_foreign_keys = bool(int(should_enforce))
317
318    def do_set_commit_on_close(self, should_commit: str):
319        """Pass a `1` to turn on and a `0` to turn off."""
320        self.commit_on_close = bool(int(should_commit))
321
322    def do_size(self, _: str):
323        """Display the size of the the current db file."""
324        print(f"{self.dbpath.name} is {self.dbpath.formatted_size}.")
325
326    @argshell.with_parser(dbparsers.get_schema_parser)
327    @time_it()
328    def do_tables(self, args: argshell.Namespace):
329        """Print out the names of the database tables, their columns, and, optionally, the number of rows."""
330        self._show_tables(args)
331
332    @argshell.with_parser(dbparsers.get_update_parser)
333    @time_it()
334    def do_update(self, args: argshell.Namespace):
335        """Update a column to a new value.
336
337        Syntax:
338        >>> update {table} {column} {value} {where}
339        >>> based>update users username big_chungus "username = lil_chungus"
340
341        ^will update the username in the users 'table' to 'big_chungus' where the username is currently 'lil_chungus'^
342        """
343        print("Updating rows...")
344        with self._DB() as db:
345            num_updates = db.update(args.table, args.column, args.new_value, args.where)
346            print(f"Updated {num_updates} rows in table {args.table}.")
347
348    def do_use(self, dbname: str):
349        """Set which database file to use."""
350        dbpath = Pathier(dbname)
351        if not dbpath.exists():
352            print(f"{dbpath} does not exist.")
353            print(f"Still using {self.dbpath}")
354        elif not dbpath.is_file():
355            print(f"{dbpath} is not a file.")
356            print(f"Still using {self.dbpath}")
357        else:
358            self.dbpath = dbpath
359            self.prompt = f"{self.dbpath.name}>"
360
361    @time_it()
362    def do_vacuum(self, _: str):
363        """Reduce database disk memory."""
364        print(f"Database size before vacuuming: {self.dbpath.formatted_size}")
365        print("Vacuuming database...")
366        with self._DB() as db:
367            freedspace = db.vacuum()
368        print(f"Database size after vacuuming: {self.dbpath.formatted_size}")
369        print(f"Freed up {Pathier.format_bytes(freedspace)} of disk space.")
370
371    @argshell.with_parser(dbparsers.get_schema_parser)
372    @time_it()
373    def do_views(self, args: argshell.Namespace):
374        """Print out the names of the database views, their columns, and, optionally, the number of rows."""
375        self._show_views(args)
376
377    # Seat
378
379    def _choose_db(self, options: list[Pathier]) -> Pathier:
380        """Prompt the user to select from a list of files."""
381        cwd = Pathier.cwd()
382        paths = [path.separate(cwd.stem) for path in options]
383        while True:
384            print(
385                f"DB options:\n{' '.join([f'({i}) {path}' for i, path in enumerate(paths, 1)])}"
386            )
387            choice = input("Enter the number of the option to use: ")
388            try:
389                index = int(choice)
390                if not 1 <= index <= len(options):
391                    print("Choice out of range.")
392                    continue
393                return options[index - 1]
394            except Exception as e:
395                print(f"{choice} is not a valid option.")
396
397    def _scan(
398        self, extensions: list[str] = [".sqlite3", ".db"], recursive: bool = False
399    ) -> list[Pathier]:
400        cwd = Pathier.cwd()
401        dbs = []
402        globber = cwd.glob
403        if recursive:
404            globber = cwd.rglob
405        for extension in extensions:
406            dbs.extend(list(globber(f"*{extension}")))
407        return dbs
408
409    def preloop(self):
410        """Scan the current directory for a .db file to use.
411        If not found, prompt the user for one or to try again recursively."""
412        if self.dbpath:
413            self.dbpath = Pathier(self.dbpath)
414            print(f"Defaulting to database {self.dbpath}")
415        else:
416            print("Searching for database...")
417            cwd = Pathier.cwd()
418            dbs = self._scan()
419            if len(dbs) == 1:
420                self.dbpath = dbs[0]
421                print(f"Using database {self.dbpath}.")
422            elif dbs:
423                self.dbpath = self._choose_db(dbs)
424            else:
425                print(f"Could not find a .db file in {cwd}.")
426                path = input(
427                    "Enter path to .db file to use or press enter to search again recursively: "
428                )
429                if path:
430                    self.dbpath = Pathier(path)
431                elif not path:
432                    print("Searching recursively...")
433                    dbs = self._scan(recursive=True)
434                    if len(dbs) == 1:
435                        self.dbpath = dbs[0]
436                        print(f"Using database {self.dbpath}.")
437                    elif dbs:
438                        self.dbpath = self._choose_db(dbs)
439                    else:
440                        print("Could not find a .db file.")
441                        self.dbpath = Pathier(input("Enter path to a .db file: "))
442        if not self.dbpath.exists():
443            raise FileNotFoundError(f"{self.dbpath} does not exist.")
444        if not self.dbpath.is_file():
445            raise ValueError(f"{self.dbpath} is not a file.")
446
447
448def get_args() -> argparse.Namespace:
449    parser = argparse.ArgumentParser()
450
451    parser.add_argument(
452        "dbpath",
453        nargs="?",
454        type=str,
455        help=""" The database file to use. If not provided the current working directory will be scanned for database files. """,
456    )
457    args = parser.parse_args()
458
459    return args
460
461
def main(args: argparse.Namespace | None = None):
    """Start the interactive database shell.

    Args:
        args: Already-parsed command line arguments. If not given, they are
            read from the command line with ``get_args``.
    """
    parsed = args if args else get_args()
    shell = DBShell()
    if parsed.dbpath:
        shell.dbpath = Pathier(parsed.dbpath)
    shell.cmdloop()
469
470
if __name__ == "__main__":
    # Script entry point: parse the command line, then launch the shell.
    cli_args = get_args()
    main(cli_args)
class DBShell(argshell.argshell.ArgShell):
 14class DBShell(argshell.ArgShell):
 15    _dbpath: Pathier = None  # type: ignore
 16    connection_timeout: float = 10
 17    detect_types: bool = True
 18    enforce_foreign_keys: bool = True
 19    commit_on_close: bool = True
 20    log_dir: Pathier = Pathier.cwd()
 21    intro = f"Starting dbshell v{__version__} (enter help or ? for arg info)...\n"
 22    prompt = f"based>"
 23
 24    @property
 25    def dbpath(self) -> Pathier:
 26        return self._dbpath
 27
 28    @dbpath.setter
 29    def dbpath(self, path: Pathish):
 30        self._dbpath = Pathier(path)
 31        self.prompt = f"{self._dbpath.name}>"
 32
 33    def _DB(self) -> Databased:
 34        return Databased(
 35            self.dbpath,
 36            self.connection_timeout,
 37            self.detect_types,
 38            self.enforce_foreign_keys,
 39            self.commit_on_close,
 40            self.log_dir,
 41        )
 42
 43    @time_it()
 44    def default(self, line: str):
 45        line = line.strip("_")
 46        with self._DB() as db:
 47            self.display(db.query(line))
 48
 49    def display(self, data: list[dict]):
 50        """Print row data to terminal in a grid."""
 51        try:
 52            print(griddy(data, "keys"))
 53        except Exception as e:
 54            print("Could not fit data into grid :(")
 55            print(e)
 56
 57    # Seat
 58
 59    def _show_tables(self, args: argshell.Namespace):
 60        with self._DB() as db:
 61            if args.tables:
 62                tables = [table for table in args.tables if table in db.tables]
 63            else:
 64                tables = db.tables
 65            if tables:
 66                print("Getting database tables...")
 67                info = [
 68                    {
 69                        "Table Name": table,
 70                        "Columns": ", ".join(db.get_columns(table)),
 71                        "Number of Rows": db.count(table) if args.rowcount else "n/a",
 72                    }
 73                    for table in tables
 74                ]
 75                self.display(info)
 76
 77    def _show_views(self, args: argshell.Namespace):
 78        with self._DB() as db:
 79            if args.tables:
 80                views = [view for view in args.tables if view in db.views]
 81            else:
 82                views = db.views
 83            if views:
 84                print("Getting database views...")
 85                info = [
 86                    {
 87                        "View Name": view,
 88                        "Columns": ", ".join(db.get_columns(view)),
 89                        "Number of Rows": db.count(view) if args.rowcount else "n/a",
 90                    }
 91                    for view in views
 92                ]
 93                self.display(info)
 94
 95    @argshell.with_parser(dbparsers.get_add_column_parser)
 96    def do_add_column(self, args: argshell.Namespace):
 97        """Add a new column to the specified tables."""
 98        with self._DB() as db:
 99            db.add_column(args.table, args.column_def)
100
101    @argshell.with_parser(dbparsers.get_add_table_parser)
102    def do_add_table(self, args: argshell.Namespace):
103        """Add a new table to the database."""
104        with self._DB() as db:
105            db.create_table(args.table, *args.columns)
106
107    @argshell.with_parser(dbparsers.get_backup_parser)
108    @time_it()
109    def do_backup(self, args: argshell.Namespace):
110        """Create a backup of the current db file."""
111        print(f"Creating a back up for {self.dbpath}...")
112        backup_path = self.dbpath.backup(args.timestamp)
113        print("Creating backup is complete.")
114        print(f"Backup path: {backup_path}")
115
116    @argshell.with_parser(dbparsers.get_count_parser)
117    @time_it()
118    def do_count(self, args: argshell.Namespace):
119        """Count the number of matching records."""
120        with self._DB() as db:
121            count = db.count(args.table, args.column, args.where, args.distinct)
122            self.display(
123                [
124                    {
125                        "Table": args.table,
126                        "Column": args.column,
127                        "Distinct": args.distinct,
128                        "Where": args.where,
129                        "Count": count,
130                    }
131                ]
132            )
133
134    def do_customize(self, name: str):
135        """Generate a template file in the current working directory for creating a custom DBShell class.
136        Expects one argument: the name of the custom dbshell.
137        This will be used to name the generated file as well as several components in the file content.
138        """
139        try:
140            create_shell(name)
141        except Exception as e:
142            print(f"{type(e).__name__}: {e}")
143
144    def do_dbpath(self, _: str):
145        """Print the .db file in use."""
146        print(self.dbpath)
147
148    @argshell.with_parser(dbparsers.get_delete_parser)
149    @time_it()
150    def do_delete(self, args: argshell.Namespace):
151        """Delete rows from the database.
152
153        Syntax:
154        >>> delete {table} {where}
155        >>> based>delete users "username LIKE '%chungus%"
156
157        ^will delete all rows in the 'users' table whose username contains 'chungus'^"""
158        print("Deleting records...")
159        with self._DB() as db:
160            num_rows = db.delete(args.table, args.where)
161            print(f"Deleted {num_rows} rows from {args.table} table.")
162
163    def do_describe(self, tables: str):
164        """Describe each given table or view. If no list is given, all tables and views will be described."""
165        with self._DB() as db:
166            table_list = tables.split() or (db.tables + db.views)
167            for table in table_list:
168                print(f"<{table}>")
169                print(db.to_grid(db.describe(table)))
170
171    @argshell.with_parser(dbparsers.get_drop_column_parser)
172    def do_drop_column(self, args: argshell.Namespace):
173        """Drop the specified column from the specified table."""
174        with self._DB() as db:
175            db.drop_column(args.table, args.column)
176
177    def do_drop_table(self, table: str):
178        """Drop the specified table."""
179        with self._DB() as db:
180            db.drop_table(table)
181
182    @argshell.with_parser(dbparsers.get_dump_parser)
183    @time_it()
184    def do_dump(self, args: argshell.Namespace):
185        """Create `.sql` dump files for the current database."""
186        date = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
187        if not args.data_only:
188            print("Dumping schema...")
189            with self._DB() as db:
190                db.dump_schema(
191                    Pathier.cwd() / f"{db.name}_schema_{date}.sql", args.tables
192                )
193        if not args.schema_only:
194            print("Dumping data...")
195            with self._DB() as db:
196                db.dump_data(Pathier.cwd() / f"{db.name}_data_{date}.sql", args.tables)
197
198    def do_flush_log(self, _: str):
199        """Clear the log file for this database."""
200        log_path = self.dbpath.with_name(self.dbpath.name.replace(".", "") + ".log")
201        if not log_path.exists():
202            print(f"No log file at path {log_path}")
203        else:
204            print(f"Flushing log...")
205            log_path.write_text("")
206
207    def do_help(self, args: str):
208        """Display help messages."""
209        super().do_help(args)
210        if args == "":
211            print("Unrecognized commands will be executed as queries.")
212            print(
213                "Use the `query` command explicitly if you don't want to capitalize your key words."
214            )
215            print("All transactions initiated by commands are committed immediately.")
216        print()
217
218    def do_new_db(self, dbname: str):
219        """Create a new, empty database with the given name."""
220        dbpath = Pathier(dbname)
221        self.dbpath = dbpath
222        self.prompt = f"{self.dbpath.name}>"
223
224    def do_properties(self, _: str):
225        """See current database property settings."""
226        for property_ in [
227            "connection_timeout",
228            "detect_types",
229            "enforce_foreign_keys",
230            "commit_on_close",
231            "log_dir",
232        ]:
233            print(f"{property_}: {getattr(self, property_)}")
234
235    @time_it()
236    def do_query(self, query: str):
237        """Execute a query against the current database."""
238        print(f"Executing {query}")
239        with self._DB() as db:
240            results = db.query(query)
241        self.display(results)
242        print(f"{db.cursor.rowcount} affected rows")
243
244    @argshell.with_parser(dbparsers.get_rename_column_parser)
245    def do_rename_column(self, args: argshell.Namespace):
246        """Rename a column."""
247        with self._DB() as db:
248            db.rename_column(args.table, args.column, args.new_name)
249
250    @argshell.with_parser(dbparsers.get_rename_table_parser)
251    def do_rename_table(self, args: argshell.Namespace):
252        """Rename a table."""
253        with self._DB() as db:
254            db.rename_table(args.table, args.new_name)
255
256    def do_restore(self, file: str):
257        """Replace the current db file with the given db backup file."""
258        backup = Pathier(file.strip('"'))
259        if not backup.exists():
260            print(f"{backup} does not exist.")
261        else:
262            print(f"Restoring from {file}...")
263            self.dbpath.write_bytes(backup.read_bytes())
264            print("Restore complete.")
265
266    @argshell.with_parser(dbparsers.get_scan_dbs_parser)
267    def do_scan(self, args: argshell.Namespace):
268        """Scan the current working directory for database files."""
269        dbs = self._scan(args.extensions, args.recursive)
270        for db in dbs:
271            print(db.separate(Pathier.cwd().stem))
272
273    @argshell.with_parser(dbparsers.get_schema_parser)
274    @time_it()
275    def do_schema(self, args: argshell.Namespace):
276        """Print out the names of the database tables and views, their columns, and, optionally, the number of rows."""
277        self._show_tables(args)
278        self._show_views(args)
279
280    @time_it()
281    def do_script(self, path: str):
282        """Execute the given SQL script."""
283        with self._DB() as db:
284            self.display(db.execute_script(path))
285
286    @argshell.with_parser(dbparsers.get_select_parser, [dbparsers.select_post_parser])
287    @time_it()
288    def do_select(self, args: argshell.Namespace):
289        """Execute a SELECT query with the given args."""
290        print(f"Querying {args.table}... ")
291        with self._DB() as db:
292            rows = db.select(
293                table=args.table,
294                columns=args.columns,
295                joins=args.joins,
296                where=args.where,
297                group_by=args.group_by,
298                having=args.Having,
299                order_by=args.order_by,
300                limit=args.limit,
301                exclude_columns=args.exclude_columns,
302            )
303            print(f"Found {len(rows)} rows:")
304            self.display(rows)
305            print(f"{len(rows)} rows from {args.table}")
306
307    def do_set_connection_timeout(self, seconds: str):
308        """Set database connection timeout to this number of seconds."""
309        self.connection_timeout = float(seconds)
310
311    def do_set_detect_types(self, should_detect: str):
312        """Pass a `1` to turn on and a `0` to turn off."""
313        self.detect_types = bool(int(should_detect))
314
315    def do_set_enforce_foreign_keys(self, should_enforce: str):
316        """Pass a `1` to turn on and a `0` to turn off."""
317        self.enforce_foreign_keys = bool(int(should_enforce))
318
319    def do_set_commit_on_close(self, should_commit: str):
320        """Pass a `1` to turn on and a `0` to turn off."""
321        self.commit_on_close = bool(int(should_commit))
322
323    def do_size(self, _: str):
324        """Display the size of the the current db file."""
325        print(f"{self.dbpath.name} is {self.dbpath.formatted_size}.")
326
327    @argshell.with_parser(dbparsers.get_schema_parser)
328    @time_it()
329    def do_tables(self, args: argshell.Namespace):
330        """Print out the names of the database tables, their columns, and, optionally, the number of rows."""
331        self._show_tables(args)
332
333    @argshell.with_parser(dbparsers.get_update_parser)
334    @time_it()
335    def do_update(self, args: argshell.Namespace):
336        """Update a column to a new value.
337
338        Syntax:
339        >>> update {table} {column} {value} {where}
340        >>> based>update users username big_chungus "username = lil_chungus"
341
342        ^will update the username in the users 'table' to 'big_chungus' where the username is currently 'lil_chungus'^
343        """
344        print("Updating rows...")
345        with self._DB() as db:
346            num_updates = db.update(args.table, args.column, args.new_value, args.where)
347            print(f"Updated {num_updates} rows in table {args.table}.")
348
349    def do_use(self, dbname: str):
350        """Set which database file to use."""
351        dbpath = Pathier(dbname)
352        if not dbpath.exists():
353            print(f"{dbpath} does not exist.")
354            print(f"Still using {self.dbpath}")
355        elif not dbpath.is_file():
356            print(f"{dbpath} is not a file.")
357            print(f"Still using {self.dbpath}")
358        else:
359            self.dbpath = dbpath
360            self.prompt = f"{self.dbpath.name}>"
361
362    @time_it()
363    def do_vacuum(self, _: str):
364        """Reduce database disk memory."""
365        print(f"Database size before vacuuming: {self.dbpath.formatted_size}")
366        print("Vacuuming database...")
367        with self._DB() as db:
368            freedspace = db.vacuum()
369        print(f"Database size after vacuuming: {self.dbpath.formatted_size}")
370        print(f"Freed up {Pathier.format_bytes(freedspace)} of disk space.")
371
372    @argshell.with_parser(dbparsers.get_schema_parser)
373    @time_it()
374    def do_views(self, args: argshell.Namespace):
375        """Print out the names of the database views, their columns, and, optionally, the number of rows."""
376        self._show_views(args)
377
378    # Seat
379
380    def _choose_db(self, options: list[Pathier]) -> Pathier:
381        """Prompt the user to select from a list of files."""
382        cwd = Pathier.cwd()
383        paths = [path.separate(cwd.stem) for path in options]
384        while True:
385            print(
386                f"DB options:\n{' '.join([f'({i}) {path}' for i, path in enumerate(paths, 1)])}"
387            )
388            choice = input("Enter the number of the option to use: ")
389            try:
390                index = int(choice)
391                if not 1 <= index <= len(options):
392                    print("Choice out of range.")
393                    continue
394                return options[index - 1]
395            except Exception as e:
396                print(f"{choice} is not a valid option.")
397
398    def _scan(
399        self, extensions: list[str] = [".sqlite3", ".db"], recursive: bool = False
400    ) -> list[Pathier]:
401        cwd = Pathier.cwd()
402        dbs = []
403        globber = cwd.glob
404        if recursive:
405            globber = cwd.rglob
406        for extension in extensions:
407            dbs.extend(list(globber(f"*{extension}")))
408        return dbs
409
410    def preloop(self):
411        """Scan the current directory for a .db file to use.
412        If not found, prompt the user for one or to try again recursively."""
413        if self.dbpath:
414            self.dbpath = Pathier(self.dbpath)
415            print(f"Defaulting to database {self.dbpath}")
416        else:
417            print("Searching for database...")
418            cwd = Pathier.cwd()
419            dbs = self._scan()
420            if len(dbs) == 1:
421                self.dbpath = dbs[0]
422                print(f"Using database {self.dbpath}.")
423            elif dbs:
424                self.dbpath = self._choose_db(dbs)
425            else:
426                print(f"Could not find a .db file in {cwd}.")
427                path = input(
428                    "Enter path to .db file to use or press enter to search again recursively: "
429                )
430                if path:
431                    self.dbpath = Pathier(path)
432                elif not path:
433                    print("Searching recursively...")
434                    dbs = self._scan(recursive=True)
435                    if len(dbs) == 1:
436                        self.dbpath = dbs[0]
437                        print(f"Using database {self.dbpath}.")
438                    elif dbs:
439                        self.dbpath = self._choose_db(dbs)
440                    else:
441                        print("Could not find a .db file.")
442                        self.dbpath = Pathier(input("Enter path to a .db file: "))
443        if not self.dbpath.exists():
444            raise FileNotFoundError(f"{self.dbpath} does not exist.")
445        if not self.dbpath.is_file():
446            raise ValueError(f"{self.dbpath} is not a file.")

Subclass this to create custom ArgShells.

@time_it()
def default(self, line: str):
    """Execute an unrecognized command line as a query against the current database.

    Leading underscores are removed first. Presumably this lets a `_` prefix
    route a query here rather than to a same-named command
    (NOTE(review): confirm). Only *leading* underscores are removed, so SQL
    whose last token ends in `_` (e.g. a table named `cache_`) is left intact.
    """
    query = line.lstrip("_")
    with self._DB() as db:
        self.display(db.query(query))

Called on an input line when the command prefix is not recognized.

If this method is not overridden, it prints an error message and returns.

def display(self, data: list[dict]):
    """Print row data to terminal in a grid.

    `data` is handed to `griddy` with the `"keys"` option. Any exception
    raised while building or printing the grid is reported rather than
    propagated to the caller.
    """
    try:
        print(griddy(data, "keys"))
    except Exception as error:
        for message in ("Could not fit data into grid :(", error):
            print(message)

Print row data to terminal in a grid.

@argshell.with_parser(dbparsers.get_add_column_parser)
def do_add_column(self, args: argshell.Namespace):
    """Add a new column to the specified table."""
    table, column_def = args.table, args.column_def
    with self._DB() as db:
        db.add_column(table, column_def)

Add a new column to the specified table.

@argshell.with_parser(dbparsers.get_add_table_parser)
def do_add_table(self, args: argshell.Namespace):
    """Add a new table to the database using the given column definitions."""
    name = args.table
    column_defs = args.columns
    with self._DB() as db:
        db.create_table(name, *column_defs)

Add a new table to the database.

@argshell.with_parser(dbparsers.get_backup_parser)
@time_it()
def do_backup(self, args: argshell.Namespace):
    """Create a backup of the current db file and print where it was written.

    `args.timestamp` is passed straight through to `Pathier.backup`.
    """
    source = self.dbpath
    print(f"Creating a back up for {source}...")
    destination = source.backup(args.timestamp)
    print("Creating backup is complete.")
    print(f"Backup path: {destination}")

Create a backup of the current db file.

@argshell.with_parser(dbparsers.get_count_parser)
@time_it()
def do_count(self, args: argshell.Namespace):
    """Count the number of matching records and show the query parameters with the result."""
    with self._DB() as db:
        total = db.count(args.table, args.column, args.where, args.distinct)
        summary = {
            "Table": args.table,
            "Column": args.column,
            "Distinct": args.distinct,
            "Where": args.where,
            "Count": total,
        }
        self.display([summary])

Count the number of matching records.

def do_customize(self, name: str):
    """Generate a template file in the current working directory for creating a custom DBShell class.
    Expects one argument: the name of the custom dbshell.
    This will be used to name the generated file as well as several components in the file content.
    """
    try:
        create_shell(name)
    except Exception as error:
        # Report the failure instead of letting it escape the command loop.
        error_kind = type(error).__name__
        print(f"{error_kind}: {error}")

Generate a template file in the current working directory for creating a custom DBShell class. Expects one argument: the name of the custom dbshell. This will be used to name the generated file as well as several components in the file content.

def do_dbpath(self, _: str):
    """Print the path of the .db file currently in use."""
    current = self.dbpath
    print(current)

Print the .db file in use.

@argshell.with_parser(dbparsers.get_delete_parser)
@time_it()
def do_delete(self, args: argshell.Namespace):
    """Delete rows from the database.

    Syntax:
    >>> delete {table} {where}
    >>> based>delete users "username LIKE '%chungus%'"

    ^will delete all rows in the 'users' table whose username contains 'chungus'^"""
    print("Deleting records...")
    with self._DB() as db:
        target = args.table
        removed = db.delete(target, args.where)
        print(f"Deleted {removed} rows from {target} table.")

Delete rows from the database.

Syntax:

>>> delete {table} {where}
>>> based>delete users "username LIKE '%chungus%'"

^will delete all rows in the 'users' table whose username contains 'chungus'^

def do_describe(self, tables: str):
    """Describe each given table or view. If no list is given, all tables and views will be described."""
    with self._DB() as db:
        names = tables.split()
        if not names:
            # No names given: fall back to every table, then every view.
            names = db.tables + db.views
        for name in names:
            print(f"<{name}>")
            print(db.to_grid(db.describe(name)))

Describe each given table or view. If no list is given, all tables and views will be described.

@argshell.with_parser(dbparsers.get_drop_column_parser)
def do_drop_column(self, args: argshell.Namespace):
    """Drop the specified column from the specified table."""
    with self._DB() as database:
        target_table = args.table
        database.drop_column(target_table, args.column)

Drop the specified column from the specified table.

def do_drop_table(self, table: str):
    """Drop the table named by the command argument."""
    with self._DB() as database:
        database.drop_table(table)

Drop the specified table.

@argshell.with_parser(dbparsers.get_dump_parser)
@time_it()
def do_dump(self, args: argshell.Namespace):
    """Create `.sql` dump files for the current database.

    Writes `{db name}_schema_{timestamp}.sql` unless `args.data_only` is set,
    and `{db name}_data_{timestamp}.sql` unless `args.schema_only` is set.
    Both go in the current working directory. `args.tables` is passed
    through to the dump calls (presumably to limit which tables are dumped).
    """
    stamp = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
    if not args.data_only:
        print("Dumping schema...")
        with self._DB() as db:
            schema_file = Pathier.cwd() / f"{db.name}_schema_{stamp}.sql"
            db.dump_schema(schema_file, args.tables)
    if not args.schema_only:
        print("Dumping data...")
        with self._DB() as db:
            data_file = Pathier.cwd() / f"{db.name}_data_{stamp}.sql"
            db.dump_data(data_file, args.tables)

Create .sql dump files for the current database.

def do_flush_log(self, _: str):
    """Clear the log file for this database.

    The log is looked up next to the db file, named after the db file's name
    with its dots removed plus `.log` (e.g. `app.db` -> `appdb.log`).
    """
    # NOTE(review): `Databased` is also given `self.log_dir`; confirm the log
    # really lives beside the db file rather than in `log_dir`.
    base = self.dbpath.name.replace(".", "")
    log_path = self.dbpath.with_name(f"{base}.log")
    if log_path.exists():
        print("Flushing log...")
        log_path.write_text("")
    else:
        print(f"No log file at path {log_path}")

Clear the log file for this database.

def do_help(self, args: str):
    """Display help messages."""
    super().do_help(args)
    if args == "":
        # Extra notes only for the general (topic-less) help listing.
        notes = (
            "Unrecognized commands will be executed as queries.",
            "Use the `query` command explicitly if you don't want to capitalize your key words.",
            "All transactions initiated by commands are committed immediately.",
        )
        for note in notes:
            print(note)
    print()

Display help messages.

def do_new_db(self, dbname: str):
    """Create a new, empty database with the given name."""
    # NOTE(review): this only points the shell at the path; the file itself is
    # presumably created on the first connection — confirm.
    # The `dbpath` property setter also sets the prompt to "<file name>>".
    self.dbpath = Pathier(dbname)

Create a new, empty database with the given name.

def do_properties(self, _: str):
    """See current database property settings."""
    setting_names = (
        "connection_timeout",
        "detect_types",
        "enforce_foreign_keys",
        "commit_on_close",
        "log_dir",
    )
    for setting in setting_names:
        value = getattr(self, setting)
        print(f"{setting}: {value}")

See current database property settings.

@time_it()
def do_query(self, query: str):
    """Execute a query against the current database.

    Displays the returned rows and the cursor's affected-row count.
    """
    print(f"Executing {query}")
    with self._DB() as db:
        results = db.query(query)
        # Read the row count while the connection context is still active;
        # the context manager's exit may close or reset the cursor.
        affected = db.cursor.rowcount
    self.display(results)
    print(f"{affected} affected rows")

Execute a query against the current database.

@argshell.with_parser(dbparsers.get_rename_column_parser)
def do_rename_column(self, args: argshell.Namespace):
    """Rename column `args.column` of table `args.table` to `args.new_name`."""
    with self._DB() as database:
        table, old_name = args.table, args.column
        database.rename_column(table, old_name, args.new_name)

Rename a column.

@argshell.with_parser(dbparsers.get_rename_table_parser)
def do_rename_table(self, args: argshell.Namespace):
    """Rename table `args.table` to `args.new_name`."""
    with self._DB() as database:
        old_name = args.table
        database.rename_table(old_name, args.new_name)

Rename a table.

def do_restore(self, file: str):
    """Replace the current db file with the given db backup file.

    Surrounding whitespace and single/double quotes are stripped from the
    argument. Nothing is written unless the backup path is an existing file.
    """
    backup = Pathier(file.strip().strip("\"'"))
    if not backup.exists():
        print(f"{backup} does not exist.")
    elif not backup.is_file():
        # Reading a directory would raise; report it instead.
        print(f"{backup} is not a file.")
    else:
        print(f"Restoring from {file}...")
        self.dbpath.write_bytes(backup.read_bytes())
        print("Restore complete.")

Replace the current db file with the given db backup file.

@argshell.with_parser(dbparsers.get_scan_dbs_parser)
def do_scan(self, args: argshell.Namespace):
    """Scan the current working directory for database files and print each one found."""
    found = self._scan(args.extensions, args.recursive)
    anchor = Pathier.cwd().stem
    for path in found:
        print(path.separate(anchor))

Scan the current working directory for database files.

@argshell.with_parser(dbparsers.get_schema_parser)
@time_it()
def do_schema(self, args: argshell.Namespace):
    """Print out the names of the database tables and views, their columns, and, optionally, the number of rows."""
    # Tables first, then views.
    for render in (self._show_tables, self._show_views):
        render(args)

Print out the names of the database tables and views, their columns, and, optionally, the number of rows.

@time_it()
def do_script(self, path: str):
    """Execute the SQL script at `path` and display what it returns."""
    with self._DB() as db:
        outcome = db.execute_script(path)
        self.display(outcome)

Execute the given SQL script.

@argshell.with_parser(dbparsers.get_select_parser, [dbparsers.select_post_parser])
@time_it()
def do_select(self, args: argshell.Namespace):
    """Execute a SELECT query with the given args."""
    print(f"Querying {args.table}... ")
    with self._DB() as db:
        # NOTE(review): `having` is read from `args.Having` (capital H); confirm
        # this matches the dest defined in `dbparsers.get_select_parser`.
        select_kwargs = dict(
            table=args.table,
            columns=args.columns,
            joins=args.joins,
            where=args.where,
            group_by=args.group_by,
            having=args.Having,
            order_by=args.order_by,
            limit=args.limit,
            exclude_columns=args.exclude_columns,
        )
        rows = db.select(**select_kwargs)
        row_count = len(rows)
        print(f"Found {row_count} rows:")
        self.display(rows)
        print(f"{row_count} rows from {args.table}")

Execute a SELECT query with the given args.

def do_set_connection_timeout(self, seconds: str):
    """Set database connection timeout to this number of seconds.

    The value must be a finite, non-negative number. Invalid input is
    reported and the current timeout is left unchanged, instead of raising
    out of the command.
    """
    import math

    try:
        timeout = float(seconds)
    except ValueError:
        print(f"Invalid timeout {seconds!r}: expected a number of seconds.")
        return
    # Reject NaN/inf as well as negatives; none of them is a usable timeout.
    if not (math.isfinite(timeout) and timeout >= 0):
        print(f"Invalid timeout {seconds!r}: must be a non-negative number.")
        return
    self.connection_timeout = timeout

Set database connection timeout to this number of seconds.

def do_set_detect_types(self, should_detect: str):
    """Pass a `1` to turn on and a `0` to turn off.

    Any integer is accepted (non-zero means on). Anything else is reported
    and the current setting is kept.
    """
    try:
        self.detect_types = bool(int(should_detect))
    except ValueError:
        print(
            f"Invalid value {should_detect!r}: pass `1` to turn on or `0` to turn off."
        )

Pass a 1 to turn on and a 0 to turn off.

def do_set_enforce_foreign_keys(self, should_enforce: str):
    """Pass a `1` to turn on and a `0` to turn off.

    Any integer is accepted (non-zero means on). Anything else is reported
    and the current setting is kept.
    """
    try:
        self.enforce_foreign_keys = bool(int(should_enforce))
    except ValueError:
        print(
            f"Invalid value {should_enforce!r}: pass `1` to turn on or `0` to turn off."
        )

Pass a 1 to turn on and a 0 to turn off.

def do_set_commit_on_close(self, should_commit: str):
    """Pass a `1` to turn on and a `0` to turn off.

    Any integer is accepted (non-zero means on). Anything else is reported
    and the current setting is kept.
    """
    try:
        self.commit_on_close = bool(int(should_commit))
    except ValueError:
        print(
            f"Invalid value {should_commit!r}: pass `1` to turn on or `0` to turn off."
        )

Pass a 1 to turn on and a 0 to turn off.

def do_size(self, _: str):
    """Display the size of the current db file."""
    current = self.dbpath
    print(f"{current.name} is {current.formatted_size}.")

Display the size of the current db file.

@argshell.with_parser(dbparsers.get_schema_parser)
@time_it()
def do_tables(self, args: argshell.Namespace):
    """Print out the names of the database tables, their columns, and, optionally, the number of rows."""
    # Delegate to the shared table renderer, which also backs `schema`.
    render_tables = self._show_tables
    render_tables(args)

Print out the names of the database tables, their columns, and, optionally, the number of rows.

@argshell.with_parser(dbparsers.get_update_parser)
@time_it()
def do_update(self, args: argshell.Namespace):
    """Update a column to a new value.

    Syntax:
    >>> update {table} {column} {value} {where}
    >>> based>update users username big_chungus "username = 'lil_chungus'"

    ^will update the username in the 'users' table to 'big_chungus' where the username is currently 'lil_chungus'^
    """
    print("Updating rows...")
    with self._DB() as db:
        target = args.table
        changed = db.update(target, args.column, args.new_value, args.where)
        print(f"Updated {changed} rows in table {target}.")

Update a column to a new value.

Syntax:

>>> update {table} {column} {value} {where}
>>> based>update users username big_chungus "username = 'lil_chungus'"

^will update the username in the 'users' table to 'big_chungus' where the username is currently 'lil_chungus'^

def do_use(self, dbname: str):
    """Set which database file to use.

    Surrounding whitespace and single/double quotes are stripped from the
    argument, so paths containing spaces can be given quoted. If no path is
    given, or the path does not exist or is not a file, the current database
    is kept.
    """
    name = dbname.strip().strip("\"'")
    if not name:
        print("Please provide a path to a database file.")
        print(f"Still using {self.dbpath}")
        return
    dbpath = Pathier(name)
    if not dbpath.exists():
        print(f"{dbpath} does not exist.")
        print(f"Still using {self.dbpath}")
    elif not dbpath.is_file():
        print(f"{dbpath} is not a file.")
        print(f"Still using {self.dbpath}")
    else:
        # The `dbpath` property setter also refreshes the prompt.
        self.dbpath = dbpath

Set which database file to use.

@time_it()
def do_vacuum(self, _: str):
    """Vacuum the database to reduce its size on disk.

    Prints the file size before and after, and the amount of space freed
    as reported by `Databased.vacuum`.
    """
    target = self.dbpath
    print(f"Database size before vacuuming: {target.formatted_size}")
    print("Vacuuming database...")
    with self._DB() as db:
        reclaimed = db.vacuum()
    print(f"Database size after vacuuming: {target.formatted_size}")
    print(f"Freed up {Pathier.format_bytes(reclaimed)} of disk space.")

Reduce database disk memory.

@argshell.with_parser(dbparsers.get_schema_parser)
@time_it()
def do_views(self, args: argshell.Namespace):
    """Print out the names of the database views, their columns, and, optionally, the number of rows."""
    # Delegate to the shared view renderer, which also backs `schema`.
    render_views = self._show_views
    render_views(args)

Print out the names of the database views, their columns, and, optionally, the number of rows.

def preloop(self):
    """Scan the current directory for a .db file to use.
    If not found, prompt the user for one or to try again recursively.

    Raises:
        FileNotFoundError: If the selected path does not exist.
        ValueError: If the selected path is not a file.
    """

    def adopt(candidates):
        # Use a single match directly, let the user pick among several,
        # and report whether anything was found.
        if len(candidates) == 1:
            self.dbpath = candidates[0]
            print(f"Using database {self.dbpath}.")
            return True
        if candidates:
            self.dbpath = self._choose_db(candidates)
            return True
        return False

    if not self.dbpath:
        print("Searching for database...")
        cwd = Pathier.cwd()
        if not adopt(self._scan()):
            print(f"Could not find a .db file in {cwd}.")
            path = input(
                "Enter path to .db file to use or press enter to search again recursively: "
            )
            if path:
                self.dbpath = Pathier(path)
            else:
                print("Searching recursively...")
                if not adopt(self._scan(recursive=True)):
                    print("Could not find a .db file.")
                    self.dbpath = Pathier(input("Enter path to a .db file: "))
    else:
        self.dbpath = Pathier(self.dbpath)
        print(f"Defaulting to database {self.dbpath}")
    if not self.dbpath.exists():
        raise FileNotFoundError(f"{self.dbpath} does not exist.")
    if not self.dbpath.is_file():
        raise ValueError(f"{self.dbpath} is not a file.")

Scan the current directory for a .db file to use. If not found, prompt the user for one or to try again recursively.

Inherited Members
cmd.Cmd
Cmd
precmd
postcmd
postloop
parseline
onecmd
completedefault
completenames
complete
get_names
complete_help
print_topics
columnize
argshell.argshell.ArgShell
do_quit
do_sys
cmdloop
emptyline
def get_args() -> argparse.Namespace:
    """Parse the command line: an optional path to the database file to open."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "dbpath",
        nargs="?",
        type=str,
        help=""" The database file to use. If not provided the current working directory will be scanned for database files. """,
    )
    return parser.parse_args()
def main(args: argparse.Namespace | None = None):
    """Start the dbshell, opening `args.dbpath` when one is given.

    Command line arguments are parsed when `args` is not supplied.
    """
    args = args or get_args()
    shell = DBShell()
    if args.dbpath:
        shell.dbpath = Pathier(args.dbpath)
    shell.cmdloop()