#!/usr/bin/env python3
"""MS SQL client tool for ScadaLink test infrastructure."""

import argparse
import sys
from contextlib import closing

import pymssql


# Default connection settings; each one can be overridden on the command line
# with --host / --port / --user / --password (see main()).
DEFAULT_HOST = "localhost"
DEFAULT_PORT = 1433
DEFAULT_USER = "sa"
# NOTE(review): development-only credential. It is committed to source and
# echoed in the --help text, so never reuse it for a shared server.
DEFAULT_PASSWORD = "ScadaLink_Dev1#"

# Databases the `check` command requires. Presumably created by
# infra/mssql/setup.sql (the script `check` tells the user to run) -- verify.
EXPECTED_DBS = ["ScadaLinkConfig", "ScadaLinkMachineData"]


def get_connection(args, database=None):
    """Open a pymssql connection from the parsed CLI connection options.

    Args:
        args: Parsed argparse namespace providing ``host``, ``port``,
            ``user`` and ``password``.
        database: Database to connect to. Any falsy value (``None``, ``""``)
            selects ``"master"``.

    Returns:
        A new pymssql connection; the caller is responsible for closing it.
    """
    target_db = database if database else "master"
    connect_kwargs = dict(
        server=args.host,
        port=args.port,
        user=args.user,
        password=args.password,
        database=target_db,
    )
    return pymssql.connect(**connect_kwargs)


def cmd_check(args):
    """Connect, list the server's databases and verify EXPECTED_DBS exist.

    Every database on the server is printed, with the expected ones marked.
    Any expected database that is absent is reported, together with the
    setup command that should create it.

    Args:
        args: Parsed CLI namespace with the connection options
            (``host``, ``port``, ``user``, ``password``).

    Exits:
        With status 1 if the connection/query fails or if any expected
        database is missing.
    """
    try:
        # closing() releases the cursor and connection even if the query
        # raises. Previously they leaked on every error path.
        with closing(get_connection(args)) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute("SELECT name FROM sys.databases ORDER BY name")
                databases = [row[0] for row in cursor.fetchall()]
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"Connected to: {args.host}:{args.port}")
    print(f"Databases ({len(databases)}):")
    for db in databases:
        marker = " <-- expected" if db in EXPECTED_DBS else ""
        print(f" {db}{marker}")

    present = set(databases)
    missing = [db for db in EXPECTED_DBS if db not in present]
    if missing:
        print(f"\nMissing expected databases: {', '.join(missing)}")
        print("Run: python infra/tools/mssql_tool.py setup --script infra/mssql/setup.sql")
        # The connection is already closed here, so exiting leaks nothing.
        sys.exit(1)

    print("\nAll expected databases present.")


def cmd_setup(args):
    """Execute the SQL script at ``args.script`` against the ``master`` database.

    The script is split into batches on ``GO`` separator lines. ``GO`` is a
    client-side batch separator, not T-SQL, so it must not be sent to the
    server. Each non-empty batch is executed in autocommit mode. Batches are
    numbered consecutively, and the final count reflects only the batches
    actually executed.

    Args:
        args: Parsed CLI namespace with the connection options and
            ``script`` (path to the SQL file).

    Exits:
        With status 1 if the script cannot be read or any batch fails.
    """
    try:
        # Decode explicitly rather than with the locale default. utf-8-sig
        # also strips a leading byte-order mark, which some editors write and
        # which would otherwise end up at the start of the first batch.
        with open(args.script, "r", encoding="utf-8-sig") as f:
            sql = f.read()
    except FileNotFoundError:
        print(f"Error: script not found: {args.script}", file=sys.stderr)
        sys.exit(1)
    except (OSError, UnicodeDecodeError) as e:
        # e.g. a directory, a permission problem or a non-UTF-8 file.
        print(f"Error: cannot read script {args.script}: {e}", file=sys.stderr)
        sys.exit(1)

    batches = _split_batches(sql)

    try:
        with closing(get_connection(args)) as conn:
            # Autocommit: statements such as CREATE DATABASE (which setup
            # scripts presumably contain) are not allowed inside a user
            # transaction.
            conn.autocommit(True)
            with closing(conn.cursor()) as cursor:
                for i, batch in enumerate(batches, 1):
                    cursor.execute(batch)
                    print(f" Batch {i} executed.")
    except Exception as e:
        print(f"Error executing script: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"\nScript completed: {args.script} ({len(batches)} batches)")


def _split_batches(sql):
    """Split a T-SQL script into stripped, non-empty batches on ``GO`` lines.

    A separator is a line whose only content, after trimming surrounding
    whitespace, is ``GO`` in any letter case. Batches that are empty or
    whitespace-only are dropped.
    """
    batches = []
    current = []
    for line in sql.splitlines():
        if line.strip().upper() == "GO":
            batches.append("\n".join(current))
            current = []
        else:
            current.append(line)
    batches.append("\n".join(current))
    return [batch.strip() for batch in batches if batch.strip()]


def cmd_query(args):
    """Run the ad-hoc SQL in ``args.sql`` against ``args.database`` and print it.

    Row-returning statements are printed as a left-aligned text table,
    followed by a row count. Other statements report ``cursor.rowcount``.
    The transaction is committed in both cases.

    Args:
        args: Parsed CLI namespace with the connection options,
            ``database`` and ``sql``.

    Exits:
        With status 1 if the connection, statement or commit fails.
    """
    try:
        with closing(get_connection(args, database=args.database)) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute(args.sql)
                if cursor.description:
                    columns = [desc[0] for desc in cursor.description]
                    rows = cursor.fetchall()
                    # Commit here as well: statements such as
                    # INSERT ... OUTPUT return rows *and* modify data, and
                    # closing without a commit would roll those changes back.
                    conn.commit()
                    _print_table(columns, rows)
                else:
                    conn.commit()
                    print(f"Query executed. Rows affected: {cursor.rowcount}")
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


def _print_table(columns, rows):
    """Print *rows* under *columns* as a left-aligned, space-separated table."""
    # Each column is as wide as its longest header or rendered value.
    widths = [len(c) for c in columns]
    for row in rows:
        for i, val in enumerate(row):
            widths[i] = max(widths[i], len(str(val)))

    print(" ".join(c.ljust(w) for c, w in zip(columns, widths)))
    print(" ".join("-" * w for w in widths))
    for row in rows:
        print(" ".join(str(v).ljust(w) for v, w in zip(row, widths)))
    print(f"\n({len(rows)} rows)")


def cmd_tables(args):
    """List every table in ``args.database`` with its row count.

    Row counts come from ``sys.partitions`` (heap or clustered index only,
    i.e. ``index_id`` 0 or 1), summed over all partitions of each table.
    That catalog view reports approximate counts.

    Args:
        args: Parsed CLI namespace with the connection options and
            ``database``.

    Exits:
        With status 1 if the connection or catalog query fails.
    """
    query = """
        SELECT s.name AS [schema], t.name AS [table],
               SUM(p.rows) AS [rows]
        FROM sys.tables t
        JOIN sys.schemas s ON t.schema_id = s.schema_id
        JOIN sys.partitions p ON t.object_id = p.object_id AND p.index_id IN (0, 1)
        GROUP BY s.name, t.name
        ORDER BY s.name, t.name
    """
    try:
        # closing() releases the cursor and connection even when the query
        # raises. Previously they leaked on the error path.
        with closing(get_connection(args, database=args.database)) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute(query)
                rows = cursor.fetchall()
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

    if not rows:
        print(f"No tables in {args.database}.")
        return

    print(f"Tables in {args.database}:")
    print(f"{'Schema':<15} {'Table':<40} {'Rows':<10}")
    print("-" * 65)
    for schema, table, count in rows:
        print(f"{schema:<15} {table:<40} {count:<10}")
    print(f"\n({len(rows)} tables)")


def main():
    """Build the CLI, parse ``sys.argv`` and run the chosen subcommand.

    Global connection options (host, port, user, password) precede the
    subcommand. Each subparser records its handler via ``set_defaults`` and
    the parsed namespace is passed straight to that handler.
    """
    parser = argparse.ArgumentParser(description="MS SQL client tool for ScadaLink test infrastructure")

    # Connection options shared by every subcommand, in help-output order.
    connection_options = (
        ("--host", {"default": DEFAULT_HOST, "help": f"SQL Server host (default: {DEFAULT_HOST})"}),
        ("--port", {"type": int, "default": DEFAULT_PORT, "help": f"Port (default: {DEFAULT_PORT})"}),
        ("--user", {"default": DEFAULT_USER, "help": f"Username (default: {DEFAULT_USER})"}),
        ("--password", {"default": DEFAULT_PASSWORD, "help": f"Password (default: {DEFAULT_PASSWORD})"}),
    )
    for flag, options in connection_options:
        parser.add_argument(flag, **options)

    subparsers = parser.add_subparsers(dest="command", required=True)

    check_parser = subparsers.add_parser("check", help="Connect and verify expected databases")
    check_parser.set_defaults(handler=cmd_check)

    setup_parser = subparsers.add_parser("setup", help="Execute a SQL script")
    setup_parser.add_argument("--script", required=True, help="Path to SQL script file")
    setup_parser.set_defaults(handler=cmd_setup)

    query_parser = subparsers.add_parser("query", help="Run an ad-hoc SQL query")
    query_parser.add_argument("--database", required=True, help="Database name")
    query_parser.add_argument("--sql", required=True, help="SQL query to execute")
    query_parser.set_defaults(handler=cmd_query)

    tables_parser = subparsers.add_parser("tables", help="List all tables in a database")
    tables_parser.add_argument("--database", required=True, help="Database name")
    tables_parser.set_defaults(handler=cmd_tables)

    parsed = parser.parse_args()
    parsed.handler(parsed)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()