196 lines
6.5 KiB
Python
196 lines
6.5 KiB
Python
"""
|
|
UniFi DNS policy updater.
|
|
|
|
Updates A and AAAA records in UniFi DNS policies using an API token.
|
|
"""
|
|
|
|
import ipaddress
import logging
from typing import TypedDict

import requests
|
|
|
|
# Log line layout: timestamp with milliseconds, level, module and function.
_LOG_FORMAT = "%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

# Configures the root logger at import time. Per the logging docs,
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATE_FORMAT,
    level=logging.INFO,
)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class UnifiRecordType(TypedDict):
|
|
record: str
|
|
ttl_seconds: int
|
|
skip_ipv4: bool | None
|
|
skip_ipv6: bool | None
|
|
|
|
|
|
class _UnifiConfigRequired(TypedDict):
    """Keys every UniFi configuration must provide."""

    # Controller base URL; the integration API path is appended to it.
    host: str
    # UniFi site identifier used in every API path.
    site_id: str
    # API key sent in the X-API-Key header.
    api_token: str
    # Records to keep in sync.
    records: list[UnifiRecordType]


class UnifiConfig(_UnifiConfigRequired, total=False):
    """Connection settings plus the records managed on one UniFi site.

    ``verify_ssl`` is optional: ``update_records`` reads it with
    ``.get("verify_ssl", False)``, so omitting it disables TLS verification.
    """

    # Whether to verify the controller's TLS certificate.
    verify_ssl: bool
|
|
|
|
|
|
def _get_session(
    base_url: str,
    api_token: str,
    verify_ssl: bool = False,
) -> requests.Session:
    """Build a ``requests`` session that authenticates to the UniFi API.

    Args:
        base_url: Controller URL. It is only logged here; callers build the
            full request URLs themselves.
        api_token: UniFi API key, sent on every request in the ``X-API-Key``
            header.
        verify_ssl: Whether to verify the controller's TLS certificate.

    Returns:
        A new session. The caller owns it and should close it when done.
    """
    logger.debug("Creating UniFi session: host=%s, verify_ssl=%s", base_url, verify_ssl)
    session = requests.Session()
    session.verify = verify_ssl
    session.headers.update({"X-API-Key": api_token})
    # The API key is the only auth header set here; no CSRF token is involved.
    logger.debug("Session created with X-API-Key header")
    return session
|
|
|
|
|
|
def list_dns_policies(
    session: requests.Session,
    api_base: str,
    site_id: str,
    timeout: float = 30.0,
) -> list[dict]:
    """Return the DNS policies of *site_id*, reduced to the fields this module uses.

    Pages are fetched with an increasing ``offset`` query parameter until the
    number of collected policies reaches the response's ``totalCount``. This
    assumes the endpoint uses offset/totalCount pagination (TODO confirm
    against the controller's API docs). A response without an integer
    ``totalCount`` is treated as complete, which is the old single-request
    behaviour.

    Args:
        session: Authenticated session (see ``_get_session``).
        api_base: Integration API root, e.g.
            ``https://host/proxy/network/integration/v1``.
        site_id: UniFi site identifier.
        timeout: Per-request timeout in seconds, so an unresponsive
            controller cannot hang the caller forever.

    Returns:
        One dict per policy with keys ``id``, ``type``, ``domain``,
        ``ipv4Address``, ``ipv6Address``, ``ttlSeconds`` and ``enabled``.
        ``enabled`` defaults to ``True`` when the controller omits it.

    Raises:
        requests.HTTPError: if any page request fails.
    """
    url = f"{api_base}/sites/{site_id}/dns/policies"
    raw_policies: list[dict] = []

    while True:
        offset = len(raw_policies)
        logger.debug("Fetching DNS policies from %s (offset=%d)", url, offset)
        # The first request carries no query string, exactly as before.
        params = {"offset": offset} if offset else None
        # verify is passed explicitly on purpose. When it is left to the
        # session, requests lets REQUESTS_CA_BUNDLE / CURL_CA_BUNDLE override
        # a session-level verify=False.
        response = session.get(url, params=params, verify=session.verify, timeout=timeout)
        response.raise_for_status()
        body = response.json()

        page = body.get("data") or []
        raw_policies.extend(page)

        total = body.get("totalCount")
        # Stop in any of three cases: an empty page (guards against looping
        # forever), no reported total, or everything has been collected.
        if not page or not isinstance(total, int) or len(raw_policies) >= total:
            break

    logger.info("Fetched %d existing DNS policy/policies from UniFi", len(raw_policies))

    return [
        {
            "id": policy.get("id"),
            "type": policy.get("type"),
            "domain": policy.get("domain"),
            "ipv4Address": policy.get("ipv4Address"),
            "ipv6Address": policy.get("ipv6Address"),
            "ttlSeconds": policy.get("ttlSeconds"),
            "enabled": policy.get("enabled", True),
        }
        for policy in raw_policies
    ]
|
|
|
|
|
|
def _get_policy_key(policy: dict) -> str:
|
|
domain = policy.get("domain", "")
|
|
ptype = policy.get("type", "")
|
|
return f"{domain}:{ptype}"
|
|
|
|
|
|
def _get_policy_map(policies: list[dict]) -> dict[str, dict]:
|
|
policy_map: dict[str, dict] = {}
|
|
for policy in policies:
|
|
key = _get_policy_key(policy)
|
|
if key:
|
|
policy_map[key] = policy
|
|
return policy_map
|
|
|
|
|
|
# Payload/policy field that carries the address for each supported record type.
_ADDRESS_FIELD_BY_TYPE = {
    "A_RECORD": "ipv4Address",
    "AAAA_RECORD": "ipv6Address",
}


def _same_address(current: str | None, desired: str) -> bool:
    """Return True when *current* and *desired* denote the same IP address.

    Parsed addresses are compared, so equivalent spellings (for example
    ``2001:db8::1`` and ``2001:0db8:0:0:0:0:0:1``) are not treated as a
    change. If either side fails to parse, this falls back to exact string
    comparison.
    """
    if current is None:
        return False
    try:
        return ipaddress.ip_address(current) == ipaddress.ip_address(desired)
    except ValueError:
        return current == desired


def _create_or_update_policy(
    session: requests.Session,
    api_base: str,
    site_id: str,
    record_type: str,
    domain: str,
    ip_address: str,
    ttl_seconds: int,
    existing_policy: dict | None,
    timeout: float = 30.0,
) -> bool:
    """Make *domain* resolve to *ip_address* via an enabled UniFi DNS policy.

    If *existing_policy* has an ``id``, that policy is updated with PUT unless
    it already matches. "Matches" means the same address, the same TTL and not
    disabled. Otherwise a new policy is created with POST.

    Args:
        session: Authenticated session (see ``_get_session``).
        api_base: Integration API root URL.
        site_id: UniFi site identifier.
        record_type: ``"A_RECORD"`` or ``"AAAA_RECORD"``.
        domain: Domain the policy answers for.
        ip_address: Address to publish.
        ttl_seconds: TTL to set on the policy.
        existing_policy: Current policy for this domain/type, or ``None``.
        timeout: Per-request timeout in seconds.

    Returns:
        True if a POST or PUT was sent and succeeded. False if the existing
        policy was already in sync and nothing was sent.

    Raises:
        ValueError: if *record_type* is not a supported record type.
        requests.HTTPError: if the controller rejects the request.
    """
    address_field = _ADDRESS_FIELD_BY_TYPE.get(record_type)
    if address_field is None:
        raise ValueError(f"Unsupported DNS record type: {record_type!r}")

    payload: dict = {
        "type": record_type,
        "enabled": True,
        "domain": domain,
        "ttlSeconds": ttl_seconds,
        address_field: ip_address,
    }
    logger.debug("Payload for %s on %s: %s", record_type, domain, payload)

    if existing_policy and existing_policy.get("id"):
        current_ip = existing_policy.get(address_field)
        current_ttl = existing_policy.get("ttlSeconds")
        in_sync = (
            _same_address(current_ip, ip_address)
            # An unreported TTL is not treated as drift. Otherwise every run
            # would rewrite the policy.
            and (current_ttl is None or current_ttl == ttl_seconds)
            # Only an explicit False counts as disabled.
            and existing_policy.get("enabled") is not False
        )
        if in_sync:
            logger.info("%s policy for %s is already %s, no update needed", record_type, domain, ip_address)
            return False

        policy_id = existing_policy["id"]
        logger.info(
            "Updating existing %s policy for %s (id=%s, current_ip=%s, current_ttl=%s)",
            record_type, domain, policy_id, current_ip, current_ttl,
        )
        url = f"{api_base}/sites/{site_id}/dns/policies/{policy_id}"
        logger.debug("Sending PUT to %s", url)
        # verify is passed explicitly on purpose. When it is left to the
        # session, requests lets REQUESTS_CA_BUNDLE / CURL_CA_BUNDLE override
        # a session-level verify=False.
        response = session.put(url, json=payload, verify=session.verify, timeout=timeout)
        action = "updated"
    else:
        logger.info("Creating new %s policy for %s", record_type, domain)
        url = f"{api_base}/sites/{site_id}/dns/policies"
        logger.debug("Sending POST to %s", url)
        response = session.post(url, json=payload, verify=session.verify, timeout=timeout)
        action = "created"

    response.raise_for_status()
    logger.info("Successfully %s %s policy for %s -> %s", action, record_type, domain, ip_address)
    return True
|
|
|
|
|
|
def update_records(
    unifi_config: UnifiConfig,
    ipv4: str | None = None,
    ipv6: str | None = None,
) -> tuple[set[str], set[str]]:
    """Sync the configured UniFi DNS records to the given addresses.

    For each entry in ``unifi_config["records"]``:
      * an ``A_RECORD`` policy is created or updated to point at *ipv4*;
      * an ``AAAA_RECORD`` policy is created or updated to point at *ipv6*.

    A family is left alone when its address is falsy, or when the record sets
    ``skip_ipv4`` / ``skip_ipv6``.

    Args:
        unifi_config: Controller connection settings and the records to manage.
        ipv4: IPv4 address to publish, or ``None`` to leave A policies alone.
        ipv6: IPv6 address to publish, or ``None`` to leave AAAA policies alone.

    Returns:
        ``(updated_ipv4, updated_ipv6)``: the domains whose A or AAAA policy
        was created or changed by this call.

    Raises:
        requests.HTTPError: if any controller request fails. Processing stops
            at the first failure.
    """
    base_url = unifi_config["host"]
    site_id = unifi_config["site_id"]
    api_token = unifi_config["api_token"]
    verify_ssl = unifi_config.get("verify_ssl", False)
    records = unifi_config["records"]

    logger.info("Connecting to UniFi controller: %s (site=%s)", base_url, site_id)
    api_base = f"{base_url.rstrip('/')}/proxy/network/integration/v1"

    updated_ipv4: set[str] = set()
    updated_ipv6: set[str] = set()
    # One entry per address family:
    # (label, address to publish, skip flag, record type, address field, result set)
    families = (
        ("IPv4", ipv4, "skip_ipv4", "A_RECORD", "ipv4Address", updated_ipv4),
        ("IPv6", ipv6, "skip_ipv6", "AAAA_RECORD", "ipv6Address", updated_ipv6),
    )

    # The with-block closes the session and releases its pooled connections,
    # including when a request raises part-way through the records.
    with _get_session(base_url, api_token, verify_ssl) as session:
        policies = list_dns_policies(session, api_base, site_id)
        policy_map = _get_policy_map(policies)

        for record in records:
            domain = record["record"]
            ttl = record.get("ttl_seconds", 14400)
            logger.info("=== Processing UniFi record: %s (ttl=%s) ===", domain, ttl)

            for label, address, skip_flag, record_type, address_field, updated in families:
                if not address:
                    continue  # nothing to publish for this family
                if record.get(skip_flag):
                    logger.info("Skipping %s for %s (%s=true)", label, domain, skip_flag)
                    continue

                existing = policy_map.get(f"{domain}:{record_type}")
                if existing:
                    logger.debug(
                        "Found existing %s policy for %s: id=%s, ip=%s",
                        record_type, domain, existing.get("id"), existing.get(address_field),
                    )
                else:
                    logger.debug("No existing %s policy for %s, will create new", record_type, domain)

                if _create_or_update_policy(
                    session, api_base, site_id,
                    record_type, domain, address, ttl,
                    existing,
                ):
                    updated.add(domain)

            logger.info("=== Done processing UniFi record: %s ===", domain)

    return updated_ipv4, updated_ipv6
|