databased.dbshell

  1import argparse
  2from datetime import datetime
  3
  4import argshell
  5from griddle import griddy
  6from noiftimer import time_it
  7from pathier import Pathier, Pathish
  8
  9from databased import Databased, __version__, dbparsers
 10from databased.create_shell import create_shell
 11
 12
 13class DBShell(argshell.ArgShell):
 14    _dbpath: Pathier = None  # type: ignore
 15    connection_timeout: float = 10
 16    detect_types: bool = True
 17    enforce_foreign_keys: bool = True
 18    intro = f"Starting dbshell v{__version__} (enter help or ? for arg info)...\n"
 19    prompt = f"based>"
 20
 21    @property
 22    def dbpath(self) -> Pathier:
 23        return self._dbpath
 24
 25    @dbpath.setter
 26    def dbpath(self, path: Pathish):
 27        self._dbpath = Pathier(path)
 28        self.prompt = f"{self._dbpath.name}>"
 29
 30    def _DB(self) -> Databased:
 31        return Databased(
 32            self.dbpath,
 33            self.connection_timeout,
 34            self.detect_types,
 35            self.enforce_foreign_keys,
 36        )
 37
 38    @time_it()
 39    def default(self, line: str):
 40        line = line.strip("_")
 41        with self._DB() as db:
 42            self.display(db.query(line))
 43
 44    def display(self, data: list[dict]):
 45        """Print row data to terminal in a grid."""
 46        try:
 47            print(griddy(data, "keys"))
 48        except Exception as e:
 49            print("Could not fit data into grid :(")
 50            print(e)
 51
 52    # Seat
 53
 54    def _show_tables(self, args: argshell.Namespace):
 55        with self._DB() as db:
 56            if args.tables:
 57                tables = [table for table in args.tables if table in db.tables]
 58            else:
 59                tables = db.tables
 60            if tables:
 61                print("Getting database tables...")
 62                info = [
 63                    {
 64                        "Table Name": table,
 65                        "Columns": ", ".join(db.get_columns(table)),
 66                        "Number of Rows": db.count(table) if args.rowcount else "n/a",
 67                    }
 68                    for table in tables
 69                ]
 70                self.display(info)
 71
 72    def _show_views(self, args: argshell.Namespace):
 73        with self._DB() as db:
 74            if args.tables:
 75                views = [view for view in args.tables if view in db.views]
 76            else:
 77                views = db.views
 78            if views:
 79                print("Getting database views...")
 80                info = [
 81                    {
 82                        "View Name": view,
 83                        "Columns": ", ".join(db.get_columns(view)),
 84                        "Number of Rows": db.count(view) if args.rowcount else "n/a",
 85                    }
 86                    for view in views
 87                ]
 88                self.display(info)
 89
 90    @argshell.with_parser(dbparsers.get_add_column_parser)
 91    def do_add_column(self, args: argshell.Namespace):
 92        """Add a new column to the specified tables."""
 93        with self._DB() as db:
 94            db.add_column(args.table, args.column_def)
 95
 96    @argshell.with_parser(dbparsers.get_add_table_parser)
 97    def do_add_table(self, args: argshell.Namespace):
 98        """Add a new table to the database."""
 99        with self._DB() as db:
100            db.create_table(args.table, *args.columns)
101
102    @argshell.with_parser(dbparsers.get_backup_parser)
103    @time_it()
104    def do_backup(self, args: argshell.Namespace):
105        """Create a backup of the current db file."""
106        print(f"Creating a back up for {self.dbpath}...")
107        backup_path = self.dbpath.backup(args.timestamp)
108        print("Creating backup is complete.")
109        print(f"Backup path: {backup_path}")
110
111    @argshell.with_parser(dbparsers.get_count_parser)
112    @time_it()
113    def do_count(self, args: argshell.Namespace):
114        """Count the number of matching records."""
115        with self._DB() as db:
116            count = db.count(args.table, args.column, args.where, args.distinct)
117            self.display(
118                [
119                    {
120                        "Table": args.table,
121                        "Column": args.column,
122                        "Distinct": args.distinct,
123                        "Where": args.where,
124                        "Count": count,
125                    }
126                ]
127            )
128
129    def do_customize(self, name: str):
130        """Generate a template file in the current working directory for creating a custom DBShell class.
131        Expects one argument: the name of the custom dbshell.
132        This will be used to name the generated file as well as several components in the file content.
133        """
134        try:
135            create_shell(name)
136        except Exception as e:
137            print(f"{type(e).__name__}: {e}")
138
139    def do_dbpath(self, _: str):
140        """Print the .db file in use."""
141        print(self.dbpath)
142
143    @argshell.with_parser(dbparsers.get_delete_parser)
144    @time_it()
145    def do_delete(self, args: argshell.Namespace):
146        """Delete rows from the database.
147
148        Syntax:
149        >>> delete {table} {where}
150        >>> based>delete users "username LIKE '%chungus%"
151
152        ^will delete all rows in the 'users' table whose username contains 'chungus'^"""
153        print("Deleting records...")
154        with self._DB() as db:
155            num_rows = db.delete(args.table, args.where)
156            print(f"Deleted {num_rows} rows from {args.table} table.")
157
158    def do_describe(self, tables: str):
159        """Describe each given table or view. If no list is given, all tables and views will be described."""
160        with self._DB() as db:
161            table_list = tables.split() or (db.tables + db.views)
162            for table in table_list:
163                print(f"<{table}>")
164                print(db.to_grid(db.describe(table)))
165
166    @argshell.with_parser(dbparsers.get_drop_column_parser)
167    def do_drop_column(self, args: argshell.Namespace):
168        """Drop the specified column from the specified table."""
169        with self._DB() as db:
170            db.drop_column(args.table, args.column)
171
172    def do_drop_table(self, table: str):
173        """Drop the specified table."""
174        with self._DB() as db:
175            db.drop_table(table)
176
177    @argshell.with_parser(dbparsers.get_dump_parser)
178    @time_it()
179    def do_dump(self, args: argshell.Namespace):
180        """Create `.sql` dump files for the current database."""
181        date = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
182        if not args.data_only:
183            print("Dumping schema...")
184            with self._DB() as db:
185                db.dump_schema(
186                    Pathier.cwd() / f"{db.name}_schema_{date}.sql", args.tables
187                )
188        if not args.schema_only:
189            print("Dumping data...")
190            with self._DB() as db:
191                db.dump_data(Pathier.cwd() / f"{db.name}_data_{date}.sql", args.tables)
192
193    def do_flush_log(self, _: str):
194        """Clear the log file for this database."""
195        log_path = self.dbpath.with_name(self.dbpath.name.replace(".", "") + ".log")
196        if not log_path.exists():
197            print(f"No log file at path {log_path}")
198        else:
199            print(f"Flushing log...")
200            log_path.write_text("")
201
202    def do_help(self, args: str):
203        """Display help messages."""
204        super().do_help(args)
205        if args == "":
206            print("Unrecognized commands will be executed as queries.")
207            print(
208                "Use the `query` command explicitly if you don't want to capitalize your key words."
209            )
210            print("All transactions initiated by commands are committed immediately.")
211        print()
212
213    def do_new_db(self, dbname: str):
214        """Create a new, empty database with the given name."""
215        dbpath = Pathier(dbname)
216        self.dbpath = dbpath
217        self.prompt = f"{self.dbpath.name}>"
218
219    def do_properties(self, _: str):
220        """See current database property settings."""
221        for property_ in ["connection_timeout", "detect_types", "enforce_foreign_keys"]:
222            print(f"{property_}: {getattr(self, property_)}")
223
224    @time_it()
225    def do_query(self, query: str):
226        """Execute a query against the current database."""
227        print(f"Executing {query}")
228        with self._DB() as db:
229            results = db.query(query)
230        self.display(results)
231        print(f"{db.cursor.rowcount} affected rows")
232
233    @argshell.with_parser(dbparsers.get_rename_column_parser)
234    def do_rename_column(self, args: argshell.Namespace):
235        """Rename a column."""
236        with self._DB() as db:
237            db.rename_column(args.table, args.column, args.new_name)
238
239    @argshell.with_parser(dbparsers.get_rename_table_parser)
240    def do_rename_table(self, args: argshell.Namespace):
241        """Rename a table."""
242        with self._DB() as db:
243            db.rename_table(args.table, args.new_name)
244
245    def do_restore(self, file: str):
246        """Replace the current db file with the given db backup file."""
247        backup = Pathier(file.strip('"'))
248        if not backup.exists():
249            print(f"{backup} does not exist.")
250        else:
251            print(f"Restoring from {file}...")
252            self.dbpath.write_bytes(backup.read_bytes())
253            print("Restore complete.")
254
255    @argshell.with_parser(dbparsers.get_scan_dbs_parser)
256    def do_scan(self, args: argshell.Namespace):
257        """Scan the current working directory for database files."""
258        dbs = self._scan(args.extensions, args.recursive)
259        for db in dbs:
260            print(db.separate(Pathier.cwd().stem))
261
262    @argshell.with_parser(dbparsers.get_schema_parser)
263    @time_it()
264    def do_schema(self, args: argshell.Namespace):
265        """Print out the names of the database tables and views, their columns, and, optionally, the number of rows."""
266        self._show_tables(args)
267        self._show_views(args)
268
269    @time_it()
270    def do_script(self, path: str):
271        """Execute the given SQL script."""
272        with self._DB() as db:
273            self.display(db.execute_script(path))
274
275    @argshell.with_parser(dbparsers.get_select_parser, [dbparsers.select_post_parser])
276    @time_it()
277    def do_select(self, args: argshell.Namespace):
278        """Execute a SELECT query with the given args."""
279        print(f"Querying {args.table}... ")
280        with self._DB() as db:
281            rows = db.select(
282                table=args.table,
283                columns=args.columns,
284                joins=args.joins,
285                where=args.where,
286                group_by=args.group_by,
287                having=args.Having,
288                order_by=args.order_by,
289                limit=args.limit,
290                exclude_columns=args.exclude_columns,
291            )
292            print(f"Found {len(rows)} rows:")
293            self.display(rows)
294            print(f"{len(rows)} rows from {args.table}")
295
296    def do_set_connection_timeout(self, seconds: str):
297        """Set database connection timeout to this number of seconds."""
298        self.connection_timeout = float(seconds)
299
300    def do_set_detect_types(self, should_detect: str):
301        """Pass a `1` to turn on and a `0` to turn off."""
302        self.detect_types = bool(int(should_detect))
303
304    def do_set_enforce_foreign_keys(self, should_enforce: str):
305        """Pass a `1` to turn on and a `0` to turn off."""
306        self.enforce_foreign_keys = bool(int(should_enforce))
307
308    def do_size(self, _: str):
309        """Display the size of the the current db file."""
310        print(f"{self.dbpath.name} is {self.dbpath.formatted_size}.")
311
312    @argshell.with_parser(dbparsers.get_schema_parser)
313    @time_it()
314    def do_tables(self, args: argshell.Namespace):
315        """Print out the names of the database tables, their columns, and, optionally, the number of rows."""
316        self._show_tables(args)
317
318    @argshell.with_parser(dbparsers.get_update_parser)
319    @time_it()
320    def do_update(self, args: argshell.Namespace):
321        """Update a column to a new value.
322
323        Syntax:
324        >>> update {table} {column} {value} {where}
325        >>> based>update users username big_chungus "username = lil_chungus"
326
327        ^will update the username in the users 'table' to 'big_chungus' where the username is currently 'lil_chungus'^
328        """
329        print("Updating rows...")
330        with self._DB() as db:
331            num_updates = db.update(args.table, args.column, args.new_value, args.where)
332            print(f"Updated {num_updates} rows in table {args.table}.")
333
334    def do_use(self, dbname: str):
335        """Set which database file to use."""
336        dbpath = Pathier(dbname)
337        if not dbpath.exists():
338            print(f"{dbpath} does not exist.")
339            print(f"Still using {self.dbpath}")
340        elif not dbpath.is_file():
341            print(f"{dbpath} is not a file.")
342            print(f"Still using {self.dbpath}")
343        else:
344            self.dbpath = dbpath
345            self.prompt = f"{self.dbpath.name}>"
346
347    @time_it()
348    def do_vacuum(self, _: str):
349        """Reduce database disk memory."""
350        print(f"Database size before vacuuming: {self.dbpath.formatted_size}")
351        print("Vacuuming database...")
352        with self._DB() as db:
353            freedspace = db.vacuum()
354        print(f"Database size after vacuuming: {self.dbpath.formatted_size}")
355        print(f"Freed up {Pathier.format_bytes(freedspace)} of disk space.")
356
357    @argshell.with_parser(dbparsers.get_schema_parser)
358    @time_it()
359    def do_views(self, args: argshell.Namespace):
360        """Print out the names of the database views, their columns, and, optionally, the number of rows."""
361        self._show_views(args)
362
363    # Seat
364
365    def _choose_db(self, options: list[Pathier]) -> Pathier:
366        """Prompt the user to select from a list of files."""
367        cwd = Pathier.cwd()
368        paths = [path.separate(cwd.stem) for path in options]
369        while True:
370            print(
371                f"DB options:\n{' '.join([f'({i}) {path}' for i, path in enumerate(paths, 1)])}"
372            )
373            choice = input("Enter the number of the option to use: ")
374            try:
375                index = int(choice)
376                if not 1 <= index <= len(options):
377                    print("Choice out of range.")
378                    continue
379                return options[index - 1]
380            except Exception as e:
381                print(f"{choice} is not a valid option.")
382
383    def _scan(
384        self, extensions: list[str] = [".sqlite3", ".db"], recursive: bool = False
385    ) -> list[Pathier]:
386        cwd = Pathier.cwd()
387        dbs = []
388        globber = cwd.glob
389        if recursive:
390            globber = cwd.rglob
391        for extension in extensions:
392            dbs.extend(list(globber(f"*{extension}")))
393        return dbs
394
395    def preloop(self):
396        """Scan the current directory for a .db file to use.
397        If not found, prompt the user for one or to try again recursively."""
398        if self.dbpath:
399            self.dbpath = Pathier(self.dbpath)
400            print(f"Defaulting to database {self.dbpath}")
401        else:
402            print("Searching for database...")
403            cwd = Pathier.cwd()
404            dbs = self._scan()
405            if len(dbs) == 1:
406                self.dbpath = dbs[0]
407                print(f"Using database {self.dbpath}.")
408            elif dbs:
409                self.dbpath = self._choose_db(dbs)
410            else:
411                print(f"Could not find a .db file in {cwd}.")
412                path = input(
413                    "Enter path to .db file to use or press enter to search again recursively: "
414                )
415                if path:
416                    self.dbpath = Pathier(path)
417                elif not path:
418                    print("Searching recursively...")
419                    dbs = self._scan(recursive=True)
420                    if len(dbs) == 1:
421                        self.dbpath = dbs[0]
422                        print(f"Using database {self.dbpath}.")
423                    elif dbs:
424                        self.dbpath = self._choose_db(dbs)
425                    else:
426                        print("Could not find a .db file.")
427                        self.dbpath = Pathier(input("Enter path to a .db file: "))
428        if not self.dbpath.exists():
429            raise FileNotFoundError(f"{self.dbpath} does not exist.")
430        if not self.dbpath.is_file():
431            raise ValueError(f"{self.dbpath} is not a file.")
432
433
def get_args() -> argparse.Namespace:
    """Parse the dbshell command line.

    Returns:
        A namespace with one attribute, `dbpath`: the positional database path
        string, or None when it was not supplied.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "dbpath",
        nargs="?",
        type=str,
        help=""" The database file to use. If not provided the current working directory will be scanned for database files. """,
    )
    return arg_parser.parse_args()
446
447
def main(args: argparse.Namespace | None = None):
    """Start an interactive DBShell session.

    Args:
        args: Parsed command-line namespace. When None, it is parsed from the
            command line with `get_args()`. A truthy `args.dbpath` selects that
            database file up front instead of letting the shell scan for one.
    """
    # Compare against None explicitly; only a missing namespace should trigger parsing.
    if args is None:
        args = get_args()
    dbshell = DBShell()
    if args.dbpath:
        # The `dbpath` setter converts to Pathier and updates the prompt.
        dbshell.dbpath = args.dbpath
    dbshell.cmdloop()
455
456
if __name__ == "__main__":
    # Parse the command line first, then hand the namespace to main().
    cli_args = get_args()
    main(cli_args)
class DBShell(argshell.ArgShell):
    """Interactive shell for inspecting and editing an SQLite database file.

    Commands are the `do_*` methods below. Any input line that does not match a
    command is executed as an SQL query (see `default`). Commands that touch the
    database open a fresh `Databased` connection through `_DB`.
    """

    _dbpath: Pathier = None  # type: ignore
    # Connection settings forwarded to every `Databased` built by `_DB`;
    # adjustable at runtime with the `set_*` commands.
    connection_timeout: float = 10
    detect_types: bool = True
    enforce_foreign_keys: bool = True
    intro = f"Starting dbshell v{__version__} (enter help or ? for arg info)...\n"
    prompt = "based>"

    @property
    def dbpath(self) -> Pathier:
        """Path of the database file currently in use."""
        return self._dbpath

    @dbpath.setter
    def dbpath(self, path: Pathish):
        # Normalize to Pathier and keep the prompt showing the active file name.
        self._dbpath = Pathier(path)
        self.prompt = f"{self._dbpath.name}>"

    def _DB(self) -> Databased:
        """Return a new `Databased` for `dbpath` built from this shell's connection settings."""
        return Databased(
            self.dbpath,
            self.connection_timeout,
            self.detect_types,
            self.enforce_foreign_keys,
        )

    @time_it()
    def default(self, line: str):
        """Execute an unrecognized input line as an SQL query and display the result.

        Leading and trailing underscores are stripped from `line` before it is run.
        """
        line = line.strip("_")
        with self._DB() as db:
            self.display(db.query(line))

    def display(self, data: list[dict]):
        """Print row data to terminal in a grid."""
        # Best-effort rendering: report the failure instead of aborting the command.
        try:
            print(griddy(data, "keys"))
        except Exception as e:
            print("Could not fit data into grid :(")
            print(e)

    # Seat

    def _show_schema_objects(self, args: argshell.Namespace, views: bool) -> None:
        """Display name, columns, and (optionally) row count for tables or views.

        Args:
            args: Parsed namespace with `tables` (names to restrict the listing
                to; empty means all) and `rowcount` (whether to count rows).
            views: Show views when True, tables when False.
        """
        label = "View" if views else "Table"
        with self._DB() as db:
            # Look the available names up once rather than once per requested name.
            available = db.views if views else db.tables
            if args.tables:
                names = [name for name in args.tables if name in available]
            else:
                names = available
            if not names:
                return
            print(f"Getting database {label.lower()}s...")
            info = [
                {
                    f"{label} Name": name,
                    "Columns": ", ".join(db.get_columns(name)),
                    "Number of Rows": db.count(name) if args.rowcount else "n/a",
                }
                for name in names
            ]
            self.display(info)

    def _show_tables(self, args: argshell.Namespace):
        """Display name, columns, and optional row count for each requested table."""
        self._show_schema_objects(args, views=False)

    def _show_views(self, args: argshell.Namespace):
        """Display name, columns, and optional row count for each requested view."""
        self._show_schema_objects(args, views=True)

    @argshell.with_parser(dbparsers.get_add_column_parser)
    def do_add_column(self, args: argshell.Namespace):
        """Add a new column to the specified table."""
        with self._DB() as db:
            db.add_column(args.table, args.column_def)

    @argshell.with_parser(dbparsers.get_add_table_parser)
    def do_add_table(self, args: argshell.Namespace):
        """Add a new table to the database."""
        with self._DB() as db:
            db.create_table(args.table, *args.columns)

    @argshell.with_parser(dbparsers.get_backup_parser)
    @time_it()
    def do_backup(self, args: argshell.Namespace):
        """Create a backup of the current db file."""
        print(f"Creating a back up for {self.dbpath}...")
        backup_path = self.dbpath.backup(args.timestamp)
        print("Creating backup is complete.")
        print(f"Backup path: {backup_path}")

    @argshell.with_parser(dbparsers.get_count_parser)
    @time_it()
    def do_count(self, args: argshell.Namespace):
        """Count the number of matching records."""
        with self._DB() as db:
            count = db.count(args.table, args.column, args.where, args.distinct)
            self.display(
                [
                    {
                        "Table": args.table,
                        "Column": args.column,
                        "Distinct": args.distinct,
                        "Where": args.where,
                        "Count": count,
                    }
                ]
            )

    def do_customize(self, name: str):
        """Generate a template file in the current working directory for creating a custom DBShell class.
        Expects one argument: the name of the custom dbshell.
        This will be used to name the generated file as well as several components in the file content.
        """
        # Report generation errors instead of letting them end the shell session.
        try:
            create_shell(name)
        except Exception as e:
            print(f"{type(e).__name__}: {e}")

    def do_dbpath(self, _: str):
        """Print the .db file in use."""
        print(self.dbpath)

    @argshell.with_parser(dbparsers.get_delete_parser)
    @time_it()
    def do_delete(self, args: argshell.Namespace):
        """Delete rows from the database.

        Syntax:
        >>> delete {table} {where}
        >>> based>delete users "username LIKE '%chungus%'"

        ^will delete all rows in the 'users' table whose username contains 'chungus'^"""
        print("Deleting records...")
        with self._DB() as db:
            num_rows = db.delete(args.table, args.where)
            print(f"Deleted {num_rows} rows from {args.table} table.")

    def do_describe(self, tables: str):
        """Describe each given table or view. If no list is given, all tables and views will be described."""
        with self._DB() as db:
            table_list = tables.split() or (db.tables + db.views)
            for table in table_list:
                print(f"<{table}>")
                print(db.to_grid(db.describe(table)))

    @argshell.with_parser(dbparsers.get_drop_column_parser)
    def do_drop_column(self, args: argshell.Namespace):
        """Drop the specified column from the specified table."""
        with self._DB() as db:
            db.drop_column(args.table, args.column)

    def do_drop_table(self, table: str):
        """Drop the specified table."""
        with self._DB() as db:
            db.drop_table(table)

    @argshell.with_parser(dbparsers.get_dump_parser)
    @time_it()
    def do_dump(self, args: argshell.Namespace):
        """Create `.sql` dump files for the current database."""
        # One timestamp for both files so a schema/data pair share a suffix.
        date = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
        if not args.data_only:
            print("Dumping schema...")
            with self._DB() as db:
                db.dump_schema(
                    Pathier.cwd() / f"{db.name}_schema_{date}.sql", args.tables
                )
        if not args.schema_only:
            print("Dumping data...")
            with self._DB() as db:
                db.dump_data(Pathier.cwd() / f"{db.name}_data_{date}.sql", args.tables)

    def do_flush_log(self, _: str):
        """Clear the log file for this database."""
        # Log file sits next to the db, named after it with dots removed,
        # e.g. `app.db` -> `appdb.log`.
        log_path = self.dbpath.with_name(self.dbpath.name.replace(".", "") + ".log")
        if not log_path.exists():
            print(f"No log file at path {log_path}")
        else:
            print("Flushing log...")
            log_path.write_text("")

    def do_help(self, args: str):
        """Display help messages."""
        super().do_help(args)
        # Extra notes only for the general (no-argument) help listing.
        if args == "":
            print("Unrecognized commands will be executed as queries.")
            print(
                "Use the `query` command explicitly if you don't want to capitalize your key words."
            )
            print("All transactions initiated by commands are committed immediately.")
        print()

    def do_new_db(self, dbname: str):
        """Create a new, empty database with the given name."""
        # Only switches `dbpath` (the setter also updates the prompt); the file
        # itself is presumably created on first connection — confirm in Databased.
        self.dbpath = Pathier(dbname)

    def do_properties(self, _: str):
        """See current database property settings."""
        for property_ in ["connection_timeout", "detect_types", "enforce_foreign_keys"]:
            print(f"{property_}: {getattr(self, property_)}")

    @time_it()
    def do_query(self, query: str):
        """Execute a query against the current database."""
        print(f"Executing {query}")
        with self._DB() as db:
            results = db.query(query)
            # Read the row count while the connection's context is still active;
            # once the `with` block exits, the cursor should not be relied on.
            affected_rows = db.cursor.rowcount
        self.display(results)
        print(f"{affected_rows} affected rows")

    @argshell.with_parser(dbparsers.get_rename_column_parser)
    def do_rename_column(self, args: argshell.Namespace):
        """Rename a column."""
        with self._DB() as db:
            db.rename_column(args.table, args.column, args.new_name)

    @argshell.with_parser(dbparsers.get_rename_table_parser)
    def do_rename_table(self, args: argshell.Namespace):
        """Rename a table."""
        with self._DB() as db:
            db.rename_table(args.table, args.new_name)

    def do_restore(self, file: str):
        """Replace the current db file with the given db backup file."""
        # Allow the path to be passed wrapped in double quotes.
        backup = Pathier(file.strip('"'))
        if not backup.exists():
            print(f"{backup} does not exist.")
        else:
            print(f"Restoring from {file}...")
            self.dbpath.write_bytes(backup.read_bytes())
            print("Restore complete.")

    @argshell.with_parser(dbparsers.get_scan_dbs_parser)
    def do_scan(self, args: argshell.Namespace):
        """Scan the current working directory for database files."""
        cwd_stem = Pathier.cwd().stem
        for db in self._scan(args.extensions, args.recursive):
            print(db.separate(cwd_stem))

    @argshell.with_parser(dbparsers.get_schema_parser)
    @time_it()
    def do_schema(self, args: argshell.Namespace):
        """Print out the names of the database tables and views, their columns, and, optionally, the number of rows."""
        self._show_tables(args)
        self._show_views(args)

    @time_it()
    def do_script(self, path: str):
        """Execute the given SQL script."""
        with self._DB() as db:
            self.display(db.execute_script(path))

    @argshell.with_parser(dbparsers.get_select_parser, [dbparsers.select_post_parser])
    @time_it()
    def do_select(self, args: argshell.Namespace):
        """Execute a SELECT query with the given args."""
        print(f"Querying {args.table}... ")
        with self._DB() as db:
            rows = db.select(
                table=args.table,
                columns=args.columns,
                joins=args.joins,
                where=args.where,
                group_by=args.group_by,
                # NOTE(review): capitalized `Having` presumably mirrors the
                # select parser's flag name (`-h` is reserved for help) — confirm.
                having=args.Having,
                order_by=args.order_by,
                limit=args.limit,
                exclude_columns=args.exclude_columns,
            )
            print(f"Found {len(rows)} rows:")
            self.display(rows)
            print(f"{len(rows)} rows from {args.table}")

    def do_set_connection_timeout(self, seconds: str):
        """Set database connection timeout to this number of seconds."""
        self.connection_timeout = float(seconds)

    def do_set_detect_types(self, should_detect: str):
        """Pass a `1` to turn on and a `0` to turn off."""
        self.detect_types = bool(int(should_detect))

    def do_set_enforce_foreign_keys(self, should_enforce: str):
        """Pass a `1` to turn on and a `0` to turn off."""
        self.enforce_foreign_keys = bool(int(should_enforce))

    def do_size(self, _: str):
        """Display the size of the current db file."""
        print(f"{self.dbpath.name} is {self.dbpath.formatted_size}.")

    @argshell.with_parser(dbparsers.get_schema_parser)
    @time_it()
    def do_tables(self, args: argshell.Namespace):
        """Print out the names of the database tables, their columns, and, optionally, the number of rows."""
        self._show_tables(args)

    @argshell.with_parser(dbparsers.get_update_parser)
    @time_it()
    def do_update(self, args: argshell.Namespace):
        """Update a column to a new value.

        Syntax:
        >>> update {table} {column} {value} {where}
        >>> based>update users username big_chungus "username = 'lil_chungus'"

        ^will update the username in the 'users' table to 'big_chungus' where the username is currently 'lil_chungus'^
        """
        print("Updating rows...")
        with self._DB() as db:
            num_updates = db.update(args.table, args.column, args.new_value, args.where)
            print(f"Updated {num_updates} rows in table {args.table}.")

    def do_use(self, dbname: str):
        """Set which database file to use."""
        dbpath = Pathier(dbname)
        if not dbpath.exists():
            print(f"{dbpath} does not exist.")
            print(f"Still using {self.dbpath}")
        elif not dbpath.is_file():
            print(f"{dbpath} is not a file.")
            print(f"Still using {self.dbpath}")
        else:
            # The `dbpath` setter also updates the prompt.
            self.dbpath = dbpath

    @time_it()
    def do_vacuum(self, _: str):
        """Reduce database disk memory."""
        print(f"Database size before vacuuming: {self.dbpath.formatted_size}")
        print("Vacuuming database...")
        with self._DB() as db:
            freedspace = db.vacuum()
        print(f"Database size after vacuuming: {self.dbpath.formatted_size}")
        print(f"Freed up {Pathier.format_bytes(freedspace)} of disk space.")

    @argshell.with_parser(dbparsers.get_schema_parser)
    @time_it()
    def do_views(self, args: argshell.Namespace):
        """Print out the names of the database views, their columns, and, optionally, the number of rows."""
        self._show_views(args)

    # Seat

    def _choose_db(self, options: list[Pathier]) -> Pathier:
        """Prompt the user to select one of `options`, re-prompting until the choice is valid.

        Options are numbered from 1 and displayed via `Pathier.separate` on the
        current working directory's name.
        """
        cwd = Pathier.cwd()
        paths = [path.separate(cwd.stem) for path in options]
        listing = " ".join(f"({i}) {path}" for i, path in enumerate(paths, 1))
        while True:
            print(f"DB options:\n{listing}")
            choice = input("Enter the number of the option to use: ")
            try:
                index = int(choice)
            except ValueError:
                print(f"{choice} is not a valid option.")
                continue
            if 1 <= index <= len(options):
                return options[index - 1]
            print("Choice out of range.")

    def _scan(
        self, extensions: list[str] | None = None, recursive: bool = False
    ) -> list[Pathier]:
        """Return files in the current working directory matching `extensions`.

        Args:
            extensions: File suffixes to match; defaults to `[".sqlite3", ".db"]`.
            recursive: Also search subdirectories when True.
        """
        # None sentinel instead of a shared mutable default list.
        if extensions is None:
            extensions = [".sqlite3", ".db"]
        cwd = Pathier.cwd()
        globber = cwd.rglob if recursive else cwd.glob
        return [path for extension in extensions for path in globber(f"*{extension}")]

    def preloop(self):
        """Scan the current directory for a .db file to use.
        If not found, prompt the user for one or to try again recursively.

        Raises:
            FileNotFoundError: The selected path does not exist.
            ValueError: The selected path is not a file.
        """
        if self.dbpath:
            self.dbpath = Pathier(self.dbpath)
            print(f"Defaulting to database {self.dbpath}")
        else:
            print("Searching for database...")
            cwd = Pathier.cwd()
            dbs = self._scan()
            if len(dbs) == 1:
                self.dbpath = dbs[0]
                print(f"Using database {self.dbpath}.")
            elif dbs:
                self.dbpath = self._choose_db(dbs)
            else:
                print(f"Could not find a .db file in {cwd}.")
                path = input(
                    "Enter path to .db file to use or press enter to search again recursively: "
                )
                if path:
                    self.dbpath = Pathier(path)
                else:
                    print("Searching recursively...")
                    dbs = self._scan(recursive=True)
                    if len(dbs) == 1:
                        self.dbpath = dbs[0]
                        print(f"Using database {self.dbpath}.")
                    elif dbs:
                        self.dbpath = self._choose_db(dbs)
                    else:
                        print("Could not find a .db file.")
                        self.dbpath = Pathier(input("Enter path to a .db file: "))
        if not self.dbpath.exists():
            raise FileNotFoundError(f"{self.dbpath} does not exist.")
        if not self.dbpath.is_file():
            raise ValueError(f"{self.dbpath} is not a file.")

Subclass this to create custom ArgShells.

@time_it()
def default(self, line: str):
    """Run an input line that matched no `do_*` command as a SQL query
    against the current database and display the returned rows.

    Leading underscores are removed before execution. Presumably this lets
    a query whose first word collides with a shell command (e.g.
    `_select ...`) still reach the database.

    Only *leading* underscores are removed. A trailing underscore can be
    part of an identifier at the end of the query (e.g. `... FROM logs_`)
    and must be preserved.
    """
    query = line.lstrip("_")
    with self._DB() as db:
        self.display(db.query(query))

Called on an input line when the command prefix is not recognized.

If this method is not overridden, it prints an error message and returns.

def display(self, data: list[dict]):
    """Print row data to terminal in a grid.

    `data` is rendered with `griddy(data, "keys")`. If rendering or printing
    fails for any reason, a short notice and the exception message are
    printed instead of letting the error propagate.
    """
    try:
        rendered = griddy(data, "keys")
        print(rendered)
    except Exception as error:
        for message in ("Could not fit data into grid :(", error):
            print(message)

Print row data to terminal in a grid.

@argshell.with_parser(dbparsers.get_add_column_parser)
def do_add_column(self, args: argshell.Namespace):
    """Add a new column to the specified tables.

    The parsed `args.table` and `args.column_def` are forwarded unchanged to
    `Databased.add_column`.
    """
    with self._DB() as database:
        database.add_column(args.table, args.column_def)

Add a new column to the specified tables.

@argshell.with_parser(dbparsers.get_add_table_parser)
def do_add_table(self, args: argshell.Namespace):
    """Add a new table to the database.

    `args.table` names the table. Each entry of `args.columns` is passed as
    a separate positional argument to `Databased.create_table`.
    """
    with self._DB() as database:
        database.create_table(args.table, *args.columns)

Add a new table to the database.

@argshell.with_parser(dbparsers.get_backup_parser)
@time_it()
def do_backup(self, args: argshell.Namespace):
    """Create a backup of the current db file.

    Calls `backup(args.timestamp)` on the current db path and prints the
    location of the resulting copy.
    """
    source = self.dbpath
    print(f"Creating a back up for {source}...")
    destination = source.backup(args.timestamp)
    print("Creating backup is complete.")
    print(f"Backup path: {destination}")

Create a backup of the current db file.

@argshell.with_parser(dbparsers.get_count_parser)
@time_it()
def do_count(self, args: argshell.Namespace):
    """Count the number of matching records.

    Calls `Databased.count` with the parsed table, column, where clause and
    distinct flag. It then displays a one-row summary of the request
    together with the resulting count.
    """
    with self._DB() as database:
        total = database.count(args.table, args.column, args.where, args.distinct)
        summary = {
            "Table": args.table,
            "Column": args.column,
            "Distinct": args.distinct,
            "Where": args.where,
            "Count": total,
        }
        self.display([summary])

Count the number of matching records.

def do_customize(self, name: str):
    """Generate a template file in the current working directory for creating a custom DBShell class.
    Expects one argument: the name of the custom dbshell.
    This will be used to name the generated file as well as several components in the file content.

    Any exception raised by `create_shell` is reported as
    `ExceptionType: message` instead of propagating.
    """
    try:
        create_shell(name)
    except Exception as error:
        error_type = type(error).__name__
        print(f"{error_type}: {error}")

Generate a template file in the current working directory for creating a custom DBShell class. Expects one argument: the name of the custom dbshell. This will be used to name the generated file as well as several components in the file content.

def do_dbpath(self, _: str):
    """Print the .db file in use (the command argument is ignored)."""
    current = self.dbpath
    print(current)

Print the .db file in use.

@argshell.with_parser(dbparsers.get_delete_parser)
@time_it()
def do_delete(self, args: argshell.Namespace):
    """Delete rows from the database.

    Syntax:
    >>> delete {table} {where}
    >>> based>delete users "username LIKE '%chungus%'"

    ^will delete all rows in the 'users' table whose username contains 'chungus'^"""
    print("Deleting records...")
    with self._DB() as db:
        # `num_rows` is reported as the number of deleted rows.
        num_rows = db.delete(args.table, args.where)
        print(f"Deleted {num_rows} rows from {args.table} table.")

Delete rows from the database.

Syntax:

>>> delete {table} {where}
>>> based>delete users "username LIKE '%chungus%'"

^will delete all rows in the 'users' table whose username contains 'chungus'^

def do_describe(self, tables: str):
    """Describe each given table or view. If no list is given, all tables and views will be described.

    `tables` is a whitespace-separated list of names. For each name, a
    `<name>` header is printed, followed by the grid rendering of
    `Databased.describe(name)`.
    """
    with self._DB() as database:
        requested = tables.split()
        if not requested:
            requested = database.tables + database.views
        for name in requested:
            print(f"<{name}>")
            print(database.to_grid(database.describe(name)))

Describe each given table or view. If no list is given, all tables and views will be described.

@argshell.with_parser(dbparsers.get_drop_column_parser)
def do_drop_column(self, args: argshell.Namespace):
    """Drop the specified column from the specified table.

    Forwards `args.table` and `args.column` to `Databased.drop_column`.
    """
    with self._DB() as database:
        database.drop_column(args.table, args.column)

Drop the specified column from the specified table.

def do_drop_table(self, table: str):
    """Drop the specified table.

    The raw command argument is passed unchanged to `Databased.drop_table`.
    """
    with self._DB() as database:
        database.drop_table(table)

Drop the specified table.

@argshell.with_parser(dbparsers.get_dump_parser)
@time_it()
def do_dump(self, args: argshell.Namespace):
    """Create `.sql` dump files for the current database.

    Files are written to the current working directory as
    `{db name}_schema_{stamp}.sql` and/or `{db name}_data_{stamp}.sql`. The
    timestamp is taken once, so both files share it.

    `args.data_only` skips the schema dump and `args.schema_only` skips the
    data dump. `args.tables` is forwarded to both dump methods.
    """
    stamp = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
    if not args.data_only:
        print("Dumping schema...")
        with self._DB() as database:
            schema_file = Pathier.cwd() / f"{database.name}_schema_{stamp}.sql"
            database.dump_schema(schema_file, args.tables)
    if not args.schema_only:
        print("Dumping data...")
        with self._DB() as database:
            data_file = Pathier.cwd() / f"{database.name}_data_{stamp}.sql"
            database.dump_data(data_file, args.tables)

Create .sql dump files for the current database.

def do_flush_log(self, _: str):
    """Clear the log file for this database.

    The log file sits next to the db file. Its name is the db file's name
    with every '.' removed, plus a '.log' suffix (e.g. `app.db` ->
    `appdb.log`). An existing log is truncated, not deleted.
    """
    log_name = self.dbpath.name.replace(".", "") + ".log"
    log_path = self.dbpath.with_name(log_name)
    if log_path.exists():
        print("Flushing log...")
        log_path.write_text("")
    else:
        print(f"No log file at path {log_path}")

Clear the log file for this database.

def do_help(self, args: str):
    """Display help messages.

    Delegates to the parent class's `do_help`. When no topic is given, it
    also prints notes on how unrecognized input is handled. A blank line is
    always printed at the end.
    """
    super().do_help(args)
    if args == "":
        notes = (
            "Unrecognized commands will be executed as queries.",
            "Use the `query` command explicitly if you don't want to capitalize your key words.",
            "All transactions initiated by commands are committed immediately.",
        )
        for note in notes:
            print(note)
    print()

Display help messages.

def do_new_db(self, dbname: str):
    """Create a new, empty database with the given name.

    This only points the shell at `dbname` and updates the prompt; no file
    is written here. NOTE(review): presumably the file is created when the
    first connection is opened. Confirm.
    """
    new_path = Pathier(dbname)
    self.dbpath = new_path
    self.prompt = f"{self.dbpath.name}>"

Create a new, empty database with the given name.

def do_properties(self, _: str):
    """See current database property settings.

    Prints `name: value` on separate lines for connection_timeout,
    detect_types and enforce_foreign_keys.
    """
    for setting in ("connection_timeout", "detect_types", "enforce_foreign_keys"):
        value = getattr(self, setting)
        print(f"{setting}: {value}")

See current database property settings.

@time_it()
def do_query(self, query: str):
    """Execute a query against the current database.

    Prints the query and displays any returned rows. It then reports the
    cursor's `rowcount` as the number of affected rows.
    """
    print(f"Executing {query}")
    with self._DB() as db:
        results = db.query(query)
        # Read the row count while the connection is still open, rather than
        # touching the cursor after the context manager has exited (which
        # presumably closes the connection).
        affected_rows = db.cursor.rowcount
    self.display(results)
    print(f"{affected_rows} affected rows")

Execute a query against the current database.

@argshell.with_parser(dbparsers.get_rename_column_parser)
def do_rename_column(self, args: argshell.Namespace):
    """Rename a column.

    Renames `args.column` in `args.table` to `args.new_name`.
    """
    with self._DB() as database:
        database.rename_column(args.table, args.column, args.new_name)

Rename a column.

@argshell.with_parser(dbparsers.get_rename_table_parser)
def do_rename_table(self, args: argshell.Namespace):
    """Rename a table.

    Renames `args.table` to `args.new_name`.
    """
    with self._DB() as database:
        database.rename_table(args.table, args.new_name)

Rename a table.

def do_restore(self, file: str):
    """Replace the current db file with the given db backup file.

    Surrounding double quotes are stripped from `file`, so a pasted quoted
    path works. The current db file is overwritten with the backup's raw
    bytes. Nothing is written if the backup path is missing or is not a
    file.
    """
    backup = Pathier(file.strip('"'))
    if not backup.exists():
        print(f"{backup} does not exist.")
    elif not backup.is_file():
        # A directory would otherwise pass `exists()` and crash in `read_bytes`.
        print(f"{backup} is not a file.")
    else:
        # Report the cleaned-up path rather than the raw, possibly quoted, argument.
        print(f"Restoring from {backup}...")
        self.dbpath.write_bytes(backup.read_bytes())
        print("Restore complete.")

Replace the current db file with the given db backup file.

@argshell.with_parser(dbparsers.get_scan_dbs_parser)
def do_scan(self, args: argshell.Namespace):
    """Scan the current working directory for database files.

    Uses `self._scan` with the parsed extensions and recursive flag. Each
    match is printed via `separate(<cwd stem>)`, which presumably shows the
    path relative to the cwd. Confirm.
    """
    for found in self._scan(args.extensions, args.recursive):
        print(found.separate(Pathier.cwd().stem))

Scan the current working directory for database files.

@argshell.with_parser(dbparsers.get_schema_parser)
@time_it()
def do_schema(self, args: argshell.Namespace):
    """Print out the names of the database tables and views, their columns, and, optionally, the number of rows.

    Tables are reported first, then views, using the same parsed args for
    both.
    """
    for show in (self._show_tables, self._show_views):
        show(args)

Print out the names of the database tables and views, their columns, and, optionally, the number of rows.

@time_it()
def do_script(self, path: str):
    """Execute the given SQL script.

    `path` is forwarded to `Databased.execute_script`, and whatever that
    returns is displayed as rows.
    """
    with self._DB() as database:
        output = database.execute_script(path)
        self.display(output)

Execute the given SQL script.

@argshell.with_parser(dbparsers.get_select_parser, [dbparsers.select_post_parser])
@time_it()
def do_select(self, args: argshell.Namespace):
    """Execute a SELECT query with the given args.

    Every parsed clause is forwarded to `Databased.select` as a keyword
    argument. The HAVING clause is read from `args.Having` (capital H),
    presumably matching the parser's option name. The matching rows are
    displayed between two row-count messages.
    """
    print(f"Querying {args.table}... ")
    with self._DB() as database:
        rows = database.select(
            table=args.table,
            columns=args.columns,
            joins=args.joins,
            where=args.where,
            group_by=args.group_by,
            having=args.Having,
            order_by=args.order_by,
            limit=args.limit,
            exclude_columns=args.exclude_columns,
        )
        row_total = len(rows)
        print(f"Found {row_total} rows:")
        self.display(rows)
        print(f"{row_total} rows from {args.table}")

Execute a SELECT query with the given args.

def do_set_connection_timeout(self, seconds: str):
    """Set database connection timeout to this number of seconds.

    `seconds` is converted with `float()`, so non-numeric input raises
    ValueError.
    """
    timeout = float(seconds)
    self.connection_timeout = timeout

Set database connection timeout to this number of seconds.

def do_set_detect_types(self, should_detect: str):
    """Pass a `1` to turn on and a `0` to turn off.

    The argument is parsed with `int()`. Any non-zero integer enables type
    detection; non-integer text raises ValueError.
    """
    self.detect_types = int(should_detect) != 0

Pass a 1 to turn on and a 0 to turn off.

def do_set_enforce_foreign_keys(self, should_enforce: str):
    """Pass a `1` to turn on and a `0` to turn off.

    The argument is parsed with `int()`. Any non-zero integer enables
    foreign key enforcement; non-integer text raises ValueError.
    """
    self.enforce_foreign_keys = int(should_enforce) != 0

Pass a 1 to turn on and a 0 to turn off.

def do_size(self, _: str):
    """Display the size of the current db file.

    Prints `<file name> is <formatted size>.` using the db path's
    `formatted_size`.
    """
    db_file = self.dbpath
    print(f"{db_file.name} is {db_file.formatted_size}.")

Display the size of the current db file.

@argshell.with_parser(dbparsers.get_schema_parser)
@time_it()
def do_tables(self, args: argshell.Namespace):
    """Print out the names of the database tables, their columns, and, optionally, the number of rows.

    Uses the same parser as the schema command but reports tables only.
    """
    show_tables = self._show_tables
    show_tables(args)

Print out the names of the database tables, their columns, and, optionally, the number of rows.

@argshell.with_parser(dbparsers.get_update_parser)
@time_it()
def do_update(self, args: argshell.Namespace):
    """Update a column to a new value.

    Syntax:
    >>> update {table} {column} {value} {where}
    >>> based>update users username big_chungus "username = 'lil_chungus'"

    ^will update the username in the 'users' table to 'big_chungus' where the username is currently 'lil_chungus'^
    """
    print("Updating rows...")
    with self._DB() as db:
        # `num_updates` is reported as the number of updated rows.
        num_updates = db.update(args.table, args.column, args.new_value, args.where)
        print(f"Updated {num_updates} rows in table {args.table}.")

Update a column to a new value.

Syntax:

>>> update {table} {column} {value} {where}
>>> based>update users username big_chungus "username = 'lil_chungus'"

^will update the username in the 'users' table to 'big_chungus' where the username is currently 'lil_chungus'^

def do_use(self, dbname: str):
    """Set which database file to use.

    If `dbname` does not exist or is not a file, the current database is
    kept and a message explaining why is printed.
    """
    candidate = Pathier(dbname)
    problem = None
    if not candidate.exists():
        problem = f"{candidate} does not exist."
    elif not candidate.is_file():
        problem = f"{candidate} is not a file."
    if problem is None:
        self.dbpath = candidate
        self.prompt = f"{self.dbpath.name}>"
        return
    print(problem)
    print(f"Still using {self.dbpath}")

Set which database file to use.

@time_it()
def do_vacuum(self, _: str):
    """Reduce database disk memory.

    Prints the file size before and after `Databased.vacuum`. It also prints
    the amount that call reports as freed, formatted with
    `Pathier.format_bytes`.
    """
    print(f"Database size before vacuuming: {self.dbpath.formatted_size}")
    print("Vacuuming database...")
    with self._DB() as database:
        bytes_freed = database.vacuum()
    print(f"Database size after vacuuming: {self.dbpath.formatted_size}")
    freed_text = Pathier.format_bytes(bytes_freed)
    print(f"Freed up {freed_text} of disk space.")

Reduce database disk memory.

@argshell.with_parser(dbparsers.get_schema_parser)
@time_it()
def do_views(self, args: argshell.Namespace):
    """Print out the names of the database views, their columns, and, optionally, the number of rows.

    Uses the same parser as the schema command but reports views only.
    """
    show_views = self._show_views
    show_views(args)

Print out the names of the database views, their columns, and, optionally, the number of rows.

def preloop(self):
    """Scan the current directory for a .db file to use.
    If not found, prompt the user for one or to try again recursively.

    Resolution order:
    1. An already-set `dbpath` is used as-is.
    2. `self._scan()` of the current working directory: a single match is
       used directly; several matches are passed to `self._choose_db`.
    3. No matches: the user is asked for a path. Empty input triggers a
       recursive scan, which resolves the same way; if that also finds
       nothing, the user is asked for a path one final time.

    Typed paths have surrounding whitespace and double quotes removed, so a
    pasted, quoted path is accepted.

    Raises:
        FileNotFoundError: If the resolved path does not exist.
        ValueError: If the resolved path is not a file.
    """
    if self.dbpath:
        self.dbpath = Pathier(self.dbpath)
        print(f"Defaulting to database {self.dbpath}")
    else:
        print("Searching for database...")
        cwd = Pathier.cwd()
        dbs = self._scan()
        if len(dbs) == 1:
            self.dbpath = dbs[0]
            print(f"Using database {self.dbpath}.")
        elif dbs:
            self.dbpath = self._choose_db(dbs)
        else:
            print(f"Could not find a .db file in {cwd}.")
            path = (
                input(
                    "Enter path to .db file to use or press enter to search again recursively: "
                )
                .strip()
                .strip('"')
            )
            if path:
                self.dbpath = Pathier(path)
            else:
                # Empty input: retry the scan recursively.
                print("Searching recursively...")
                dbs = self._scan(recursive=True)
                if len(dbs) == 1:
                    self.dbpath = dbs[0]
                    print(f"Using database {self.dbpath}.")
                elif dbs:
                    self.dbpath = self._choose_db(dbs)
                else:
                    print("Could not find a .db file.")
                    path = input("Enter path to a .db file: ").strip().strip('"')
                    self.dbpath = Pathier(path)
    if not self.dbpath.exists():
        raise FileNotFoundError(f"{self.dbpath} does not exist.")
    if not self.dbpath.is_file():
        raise ValueError(f"{self.dbpath} is not a file.")

Scan the current directory for a .db file to use. If not found, prompt the user for one or to try again recursively.

Inherited Members
cmd.Cmd
Cmd
precmd
postcmd
postloop
parseline
onecmd
completedefault
completenames
complete
get_names
complete_help
print_topics
columnize
argshell.argshell.ArgShell
do_quit
do_sys
cmdloop
emptyline
def get_args() -> argparse.Namespace:
435def get_args() -> argparse.Namespace:
436    parser = argparse.ArgumentParser()
437
438    parser.add_argument(
439        "dbpath",
440        nargs="?",
441        type=str,
442        help=""" The database file to use. If not provided the current working directory will be scanned for database files. """,
443    )
444    args = parser.parse_args()
445
446    return args
def main(args: argparse.Namespace | None = None):
    """Build a DBShell, optionally point it at a db file, and run its command loop.

    When `args` is not supplied, the command line is parsed with `get_args`.
    A non-empty `args.dbpath` becomes the shell's database before the loop
    starts.
    """
    parsed = args if args else get_args()
    shell = DBShell()
    if parsed.dbpath:
        shell.dbpath = Pathier(parsed.dbpath)
    shell.cmdloop()