mirror of
https://github.com/open-metadata/OpenMetadata
synced 2026-05-24 09:39:11 +00:00
feat(ingestion): introduce TagRegistry domain layer (#27991)
* feat(ingestion): introduce TagRegistry domain layer
Adds metadata.domain.tags with TagRegistry (per-Source bookkeeping) and
TagCanonicalizer (case-corrected name resolution against OM). Migrates
the Snowflake connector to the new architecture; other connectors stay
on the legacy context.tags flow (strangler pattern).
TagRegistry interns shared TagLabel instances by (classification, tag,
label_type, state) and rebinds the per-entity dict on scope clear. On a
schema with ~120k tag attachments and 21 unique tags, peak heap drops
from ~112 MB to ~24 MB.
Public get_*_tag_labels methods on the database service base are
unchanged; non-Snowflake DB connectors are not touched.
* fix(ingestion): tighten TagRegistry contract and address PR/CI feedback
- Drop None defaults from `TagRegistry.attach()` description params
(required `str | None`); normalize None -> "" inside
`_build_pending_record` so the OM schema's required-non-null Markdown
contract is owned at the registry boundary, not at every caller.
- Fix `yield_database_tag` missing registry drain (PR review).
- Fix `clear_scope` lock-order race that left labels visible after the
scope was marked cleared (PR review).
- Resolve basedpyright errors:
- `Either` and `TopologyNode` use Form 3 Pydantic v2 fields
(`Annotated[X, Field(...)] = default`) so static checkers see the
defaults.
- `cast("str", fqn.build(...))` at the 13 sites that feed
`fqn.build(...)` results into FQN-typed args.
- Scoped `# pyright: ignore[reportAttributeAccessIssue]` on
`TopologyContext` dynamic-attribute accesses (matches the codebase
pattern of grandfathering 8.4k+ such errors via baseline).
- Populate `stackTrace` on the three snowflake tag-error
`StackTraceError`s so the Status UI shows the trace, not just the
one-line summary.
- Rewrite three snowflake tag-inheritance tests to drive the real
registry attach + inheritance walk after `get_tag_label` was removed:
- `test_schema_tag_inheritance`
- `test_database_tag_inheritance`
- `test_tag_value_precedence` (one attach intentionally passes
`None` descriptions to exercise the registry's normalization)
* fix(ingestion): revert Either to Form 2 to avoid Generic-T inference cascade
Reverting `Either.left`/`Either.right` to the original
`Annotated[Optional[T], Field(default=None, ...)]` form. The Form 3
shape (`Annotated[T | None, Field(...)] = None`) introduced in the
prior commit caused pyright to eagerly bind `T` from the literal-None
default at every no-arg construction site like `Either(left=...)`,
resolving them to `Either[Unknown]`. Because `Either[T]` is invariant,
those failed to satisfy declared generator return types like
`Iterable[Either[CreateTableRequest]]` — surfacing 45 latent
reportReturnType errors across sample_data, dbt, sas, qliksense, sigma,
common_db_source, common_broker_source, amundsen, sink/metadata_rest.
Form 2 wraps the default inside `Annotated` metadata where pyright
treats it as opaque: it sees "this field has a default" (so construction
sites pass) but doesn't eagerly bind `T` from None. Context-driven
inference works, no cascade.
`TopologyNode` stays in Form 3 — its fields are concrete-typed
(`list[str] | None`, `bool`), no Generic-T to bind.
* fix(ingestion): close attach race, thread-safe lazy registries, satisfy CI pyright
Three related fixes from PR review + CI:
1. **TagRegistry.attach race fully closed (Copilot review).** The cleared-scope
check and the labels-mutation now happen under the same ``_scope_state_lock``,
so a concurrent ``clear_scope`` cannot interleave between the check and the
write. Moved ``_cleared_scopes`` ownership from ``_run_state_lock`` to
``_scope_state_lock`` accordingly. ``clear_scope`` now adds to
``_cleared_scopes`` and rebuilds ``_labels_by_entity`` atomically under the
single lock.
2. **Thread-safe lazy initialization of ``tags_registry`` and
``tag_canonicalizer`` (Copilot review).** ``functools.cached_property`` is
not thread-safe in Python 3.12+ — under parallel ``databaseSchema`` workers,
two threads could each construct their own registry and the second would
overwrite the first in ``__dict__``, losing tag-attaches that the orphaned
instance already received. Switched to ``@property`` + ``vars(self).setdefault``,
which is documented atomic under the GIL (Python 3.14 Thread Safety Guarantees)
and survives PEP 703 free-threaded mode via per-dict locks. Verified
thread-safe against 30 racing threads in an ad-hoc check (1 distinct
instance observed, all threads converge).
3. **Explicit ``right=None`` / ``left=None`` on Either error/success yields.**
Pyright (CI strict mode) treats Either's ``Annotated[Optional[T], Field(default=None, ...)]``
defaults as invisible, flagging every ``Either(left=...)`` and
``Either(right=...)`` as missing the counterpart. Passing the other side
explicitly satisfies the call contract without changing runtime behavior.
* test(ingestion): patch time.sleep instead of tenacity.nap.time.sleep
The previous patch target ``tenacity.nap.time.sleep`` reaches into
tenacity's internal module structure — fragile under any tenacity
refactor that renames or relocates ``nap.py``. Patching ``time.sleep``
directly is equally effective (tenacity does ``import time`` and calls
``time.sleep`` dynamically, so attribute-level patching intercepts it)
and depends only on stdlib, which is stable.
Addresses Copilot review on test_canonicalizer.py:25.
* refactor(ingestion): single lock in TagRegistry, intern after cleared-check
Two refinements from a fresh round of Copilot review:
1. **Merge ``_run_state_lock`` and ``_scope_state_lock`` into a single
``_lock``.** Two locks were a smell: every ``attach()`` already needed
both, the lock-acquisition order varied across methods (``attach`` does
one order, ``stats`` the reverse — fragile if any future code held both
simultaneously), and ``RLock`` was used defensively without any actual
re-entry. One lock = single mental model, no ordering invariant, deadlock
impossible.
2. **Move ``_intern_tag_label`` invocation inside the cleared-scope
gate** (Copilot review). With one lock, the cleared-check, intern,
label-append, and pending-update all happen in one atomic critical
section — so when a scope is already cleared, no TagLabel is interned
into ``_tag_label_cache``. Verified deterministically: 30 threads
attaching to a pre-cleared scope all raise and the cache stays empty.
3. Reword the inline ``vars(self).setdefault`` thread-safety comment in
``database_service.py`` to drop the misleading Python 3.14 reference
(project targets 3.10+) — just point at the official
``threadsafety.html`` doc.
* refactor(ingestion): tighten description contract; rename canonicalizer params
End-to-end tightening to prevent the empty-description / overwrite path
flagged in PR review. Three coordinated changes:
1. **``Canonical.description`` is now ``str`` (required)** instead of
``str | None``. The canonicalizer always seeds with the caller-provided
default and only overrides with a non-empty server value, so the
resolved description is invariably a real string. Removing the
Optional makes that invariant visible at the type level.
2. **Rename canonicalizer parameters to make their fallback semantics
honest**:
- ``classification(name, description)`` →
``classification(name, default_description)``
- ``tag(classification_name, tag_name, tag_description)`` →
``tag(classification_name, tag_name, default_tag_description)``
The parameter is used to seed the Canonical when no system match
exists in OM, and as a fallback when an OM match has an empty
description — never as the value to set on an existing entity. The
rename signals "if I have to invent this, here's what to write down"
rather than "set the description to this." Snowflake call sites use
the new kwargs explicitly for self-documenting reads.
3. **Tighten ``TagRegistry.attach`` and ``_build_pending_record`` to
required ``str`` for both descriptions** (was ``str | None``). With
``Canonical.description`` now ``str``, every caller can pass through
without an ``or ""`` shim. ``_build_pending_record`` drops the
defensive ``Markdown(x or "")`` for clean ``Markdown(x)``.
Net effect: no possible code path sends ``None`` or relies on an
``or ""`` fallback that would clobber a server-side description. The
schema's required-Markdown contract is enforced at every layer above
the wire.
This commit is contained in:
parent
8ad28268e8
commit
c4d9a86804
13 changed files with 1346 additions and 227 deletions
|
|
@ -165,6 +165,7 @@ base_requirements = {
|
|||
"sqlalchemy>=2.0.0,<3",
|
||||
"collate-sqllineage>=2.1.1",
|
||||
"tabulate==0.9.0",
|
||||
"tenacity>=8.0,<10",
|
||||
"typing-inspect",
|
||||
"packaging", # For version parsing
|
||||
"setuptools>=78.1.1,<81", # <81 required: pkg_resources removed in setuptools 81+
|
||||
|
|
|
|||
22
ingestion/src/metadata/domain/__init__.py
Normal file
22
ingestion/src/metadata/domain/__init__.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2025 Collate
|
||||
# Licensed under the Collate Community License, Version 1.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""OpenMetadata domain utilities.
|
||||
|
||||
In-memory helpers operating on OpenMetadata's data model, reusable across
|
||||
service-source bases and features. A module belongs here when it satisfies
|
||||
ALL of:
|
||||
|
||||
1. Knows OM concepts (operates on OM-generated types or OM-specific ideas).
|
||||
2. Owns no I/O infrastructure. May use an INJECTED OM client for read-only
|
||||
queries; the client's lifecycle is the caller's.
|
||||
3. Framework-independent — no topology, stages, or sinks.
|
||||
4. Cross-cutting — used by more than one service-source base or feature.
|
||||
"""
|
||||
21
ingestion/src/metadata/domain/tags/__init__.py
Normal file
21
ingestion/src/metadata/domain/tags/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright 2025 Collate
|
||||
# Licensed under the Collate Community License, Version 1.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tag and Classification domain utilities."""
|
||||
|
||||
from metadata.domain.tags.canonicalizer import Canonical, TagCanonicalizer
|
||||
from metadata.domain.tags.registry import ScopeAlreadyClearedError, TagRegistry
|
||||
|
||||
__all__ = [
|
||||
"Canonical",
|
||||
"ScopeAlreadyClearedError",
|
||||
"TagCanonicalizer",
|
||||
"TagRegistry",
|
||||
]
|
||||
145
ingestion/src/metadata/domain/tags/canonicalizer.py
Normal file
145
ingestion/src/metadata/domain/tags/canonicalizer.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
# Copyright 2025 Collate
|
||||
# Licensed under the Collate Community License, Version 1.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""TagCanonicalizer — case-corrected name resolution against OpenMetadata.
|
||||
|
||||
Resolves source-system Classification and Tag names to the canonical form
|
||||
of any matching system-provider entity in OM (e.g., source reports
|
||||
``pii.sensitive`` → returns ``PII.Sensitive``). Persistent ES failures
|
||||
raise after retry exhaustion.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, NamedTuple, cast
|
||||
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metadata.generated.schema.entity.classification.classification import Classification
|
||||
from metadata.generated.schema.entity.classification.tag import Tag
|
||||
from metadata.generated.schema.type.basic import ProviderType
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
from metadata.utils import fqn
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
|
||||
logger = ingestion_logger()
|
||||
|
||||
|
||||
_es_retry = retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_random_exponential(multiplier=2, max=30),
|
||||
reraise=True,
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
class Canonical(NamedTuple):
|
||||
"""Canonical (name, description) pair returned from OpenMetadata."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class TagCanonicalizer:
|
||||
"""Case-corrected name resolution for system Classifications and Tags.
|
||||
|
||||
Persistent ES failures raise; callers should wrap in ``Either`` to
|
||||
surface them to workflow status.
|
||||
"""
|
||||
|
||||
def __init__(self, metadata: OpenMetadata) -> None:
|
||||
self._metadata = metadata
|
||||
self._classification_cache: dict[str, Canonical] = {}
|
||||
self._tag_cache: dict[str, Canonical] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def classification(
|
||||
self,
|
||||
name: str,
|
||||
default_description: str,
|
||||
) -> Canonical:
|
||||
"""Return canonical classification name + description from OM, cached.
|
||||
|
||||
``default_description`` is used to seed the Canonical when no
|
||||
system-provider match exists in OM, and as a fallback when an
|
||||
OM match has an empty description. An OM-side description wins
|
||||
over the default whenever available.
|
||||
"""
|
||||
key = name.lower()
|
||||
with self._lock:
|
||||
cached = self._classification_cache.get(key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
results = self._es_search(Classification, name)
|
||||
canonical = Canonical(name=name, description=default_description)
|
||||
for entity in results:
|
||||
if entity.provider == ProviderType.system and entity.name.root.lower() == key:
|
||||
canonical = Canonical(
|
||||
name=entity.name.root,
|
||||
description=entity.description.root if entity.description else default_description,
|
||||
)
|
||||
break
|
||||
|
||||
with self._lock:
|
||||
self._classification_cache.setdefault(key, canonical)
|
||||
return canonical
|
||||
|
||||
def tag(
|
||||
self,
|
||||
classification_name: str,
|
||||
tag_name: str,
|
||||
default_tag_description: str,
|
||||
) -> Canonical:
|
||||
"""Return canonical tag name + description from OM, cached.
|
||||
|
||||
``classification_name`` must already be canonical (call ``classification`` first).
|
||||
``default_tag_description`` is used to seed the Canonical when no
|
||||
system-provider match exists in OM, and as a fallback when an
|
||||
OM match has an empty description.
|
||||
"""
|
||||
tag_fqn = cast(
|
||||
"str",
|
||||
fqn.build(None, Tag, classification_name=classification_name, tag_name=tag_name),
|
||||
)
|
||||
key = tag_fqn.lower()
|
||||
with self._lock:
|
||||
cached = self._tag_cache.get(key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
results = self._es_search(Tag, tag_fqn)
|
||||
canonical = Canonical(name=tag_name, description=default_tag_description)
|
||||
for entity in results:
|
||||
if (
|
||||
entity.provider == ProviderType.system
|
||||
and entity.classification.name == classification_name
|
||||
and entity.name.root.lower() == tag_name.lower()
|
||||
):
|
||||
canonical = Canonical(
|
||||
name=entity.name.root,
|
||||
description=entity.description.root if entity.description else default_tag_description,
|
||||
)
|
||||
break
|
||||
|
||||
with self._lock:
|
||||
self._tag_cache.setdefault(key, canonical)
|
||||
return canonical
|
||||
|
||||
@_es_retry
|
||||
def _es_search(self, entity_type: Any, search_string: str) -> Iterable[Any]:
|
||||
"""Run an ES search by FQN with retries."""
|
||||
return self._metadata.es_search_from_fqn(entity_type=entity_type, fqn_search_string=search_string) or []
|
||||
235
ingestion/src/metadata/domain/tags/registry.py
Normal file
235
ingestion/src/metadata/domain/tags/registry.py
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
# Copyright 2025 Collate
|
||||
# Licensed under the Collate Community License, Version 1.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""TagRegistry — per-Source bookkeeping for Tag and Classification ingestion.
|
||||
|
||||
Holds two concerns:
|
||||
|
||||
* a queue of classification/tag create-payloads bound for the sink
|
||||
(deduped by FQN, drained per scope), and
|
||||
* a per-entity-FQN lookup of ``TagLabel`` instances for inheritance
|
||||
reads, dropped at scope boundaries.
|
||||
|
||||
Dedup is case-sensitive, matching OpenMetadata's tag-identity rule.
|
||||
Safe for concurrent use across the topology's parallel schema workers.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from collections.abc import Iterable
|
||||
from typing import NamedTuple, cast
|
||||
|
||||
from metadata.generated.schema.api.classification.createClassification import (
|
||||
CreateClassificationRequest,
|
||||
)
|
||||
from metadata.generated.schema.api.classification.createTag import CreateTagRequest
|
||||
from metadata.generated.schema.entity.classification.tag import Tag
|
||||
from metadata.generated.schema.type.basic import (
|
||||
EntityName,
|
||||
FullyQualifiedEntityName,
|
||||
Markdown,
|
||||
)
|
||||
from metadata.generated.schema.type.tagLabel import (
|
||||
LabelType,
|
||||
State,
|
||||
TagFQN,
|
||||
TagLabel,
|
||||
TagSource,
|
||||
)
|
||||
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
|
||||
from metadata.ingestion.ometa.ometa_api import OpenMetadata
|
||||
from metadata.ingestion.ometa.utils import model_str
|
||||
from metadata.utils import fqn
|
||||
from metadata.utils.logger import ingestion_logger
|
||||
|
||||
logger = ingestion_logger()
|
||||
|
||||
|
||||
class _TagLabelKey(NamedTuple):
|
||||
"""Identity tuple for the TagLabel cache."""
|
||||
|
||||
classification_name: str
|
||||
tag_name: str
|
||||
label_type: LabelType
|
||||
state: State
|
||||
|
||||
|
||||
class ScopeAlreadyClearedError(RuntimeError):
|
||||
"""Raised when 'attach' is called for a previously cleared scope.
|
||||
|
||||
Surfaces topology lifecycle bug loudly rather than silently re-creating a cleared scope.
|
||||
"""
|
||||
|
||||
|
||||
class TagRegistry:
|
||||
"""Registry for Tag and Classification ingestion bookkeeping."""
|
||||
|
||||
def __init__(self, metadata: OpenMetadata) -> None:
|
||||
self._metadata = metadata
|
||||
|
||||
self._known_tag_fqns: set[str] = set()
|
||||
self._tag_label_cache: dict[_TagLabelKey, TagLabel] = {}
|
||||
self._pending: list[OMetaTagAndClassification] = []
|
||||
self._cleared_scopes: set[str] = set()
|
||||
self._labels_by_entity: dict[str, list[TagLabel]] = {}
|
||||
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _intern_tag_label_locked(
|
||||
self, *, classification_name: str, tag_name: str, label_type: LabelType, state: State
|
||||
) -> TagLabel:
|
||||
"""Return the shared ``TagLabel`` for the given key. Caller must hold ``self._lock``."""
|
||||
key = _TagLabelKey(classification_name, tag_name, label_type, state)
|
||||
cached = self._tag_label_cache.get(key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
tag_fqn = cast("str", fqn.build(None, Tag, classification_name=classification_name, tag_name=tag_name))
|
||||
cached = TagLabel( # pyright: ignore[reportCallIssue]
|
||||
tagFQN=TagFQN(tag_fqn),
|
||||
labelType=label_type,
|
||||
state=state,
|
||||
source=TagSource.Classification,
|
||||
)
|
||||
self._tag_label_cache[key] = cached
|
||||
return cached
|
||||
|
||||
def attach(
|
||||
self,
|
||||
*,
|
||||
scope_fqn: str,
|
||||
entity_fqn: str,
|
||||
classification_name: str,
|
||||
tag_name: str,
|
||||
classification_description: str,
|
||||
tag_description: str,
|
||||
label_type: LabelType = LabelType.Automated,
|
||||
state: State = State.Suggested,
|
||||
) -> None:
|
||||
"""Register a tag <-> entity association."""
|
||||
if not tag_name or not tag_name.strip():
|
||||
logger.debug("TagRegistry: skipping empty tag for classification %s", classification_name)
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
if scope_fqn in self._cleared_scopes:
|
||||
raise ScopeAlreadyClearedError(
|
||||
f"Tag attach called for cleared scope '{scope_fqn!r}' for entity '{entity_fqn!r}'"
|
||||
)
|
||||
tag_label = self._intern_tag_label_locked(
|
||||
classification_name=classification_name,
|
||||
tag_name=tag_name,
|
||||
label_type=label_type,
|
||||
state=state,
|
||||
)
|
||||
self._labels_by_entity.setdefault(entity_fqn, []).append(tag_label)
|
||||
|
||||
tag_fqn = model_str(tag_label.tagFQN)
|
||||
if tag_fqn not in self._known_tag_fqns:
|
||||
self._known_tag_fqns.add(tag_fqn)
|
||||
self._pending.append(
|
||||
self._build_pending_record(
|
||||
classification_name=classification_name,
|
||||
classification_description=classification_description,
|
||||
tag_name=tag_name,
|
||||
tag_description=tag_description,
|
||||
)
|
||||
)
|
||||
|
||||
def labels_for(self, entity_fqn: str) -> list[TagLabel]:
|
||||
"""Return tag labels attached to ``entity_fqn`` (idempotent; returns a copy)."""
|
||||
with self._lock:
|
||||
return list(self._labels_by_entity.get(entity_fqn, []))
|
||||
|
||||
def drain(self) -> Iterable[OMetaTagAndClassification]:
|
||||
"""Yield all queued create payloads and clear the queue."""
|
||||
with self._lock:
|
||||
pending, self._pending = self._pending, []
|
||||
|
||||
if pending:
|
||||
logger.debug("TagRegistry: drained %d pending tag payloads.", len(pending))
|
||||
yield from pending
|
||||
|
||||
def clear_scope(self, scope_fqn: str) -> None:
|
||||
"""Drop labels under ``scope_fqn`` and mark the scope cleared.
|
||||
|
||||
Subsequent ``attach`` calls for this scope will raise.
|
||||
"""
|
||||
prefix = scope_fqn + fqn.FQN_SEPARATOR
|
||||
|
||||
with self._lock:
|
||||
self._cleared_scopes.add(scope_fqn)
|
||||
kept = {k: v for k, v in self._labels_by_entity.items() if k != scope_fqn and not k.startswith(prefix)}
|
||||
dropped = len(self._labels_by_entity) - len(kept)
|
||||
self._labels_by_entity = kept
|
||||
if dropped:
|
||||
logger.debug("TagRegistry: cleared scope %s (%d entity labels dropped)", scope_fqn, dropped)
|
||||
|
||||
def is_known(self, tag_fqn: str) -> bool:
|
||||
"""Return True if the tag FQN has been recorded (case-sensitive match)."""
|
||||
with self._lock:
|
||||
return tag_fqn in self._known_tag_fqns
|
||||
|
||||
def ensure_known(self, tag_fqn: str) -> bool:
|
||||
"""Return True if the tag exists server-side, caching positive results.
|
||||
|
||||
Returns False (and does NOT cache) on 404 or transport error.
|
||||
"""
|
||||
if self.is_known(tag_fqn):
|
||||
return True
|
||||
|
||||
logger.debug("TagRegistry: cache miss for %s; fetching from OpenMetadata.", tag_fqn)
|
||||
try:
|
||||
entity = self._metadata.get_by_name(entity=Tag, fqn=tag_fqn)
|
||||
except Exception:
|
||||
logger.exception("TagRegistry: tag lookup failed for %s.", tag_fqn)
|
||||
return False
|
||||
|
||||
if entity is None:
|
||||
logger.warning(
|
||||
"TagRegistry: tag %s not found in OpenMetadata; labels referencing it will be skipped.", tag_fqn
|
||||
)
|
||||
return False
|
||||
|
||||
with self._lock:
|
||||
self._known_tag_fqns.add(tag_fqn)
|
||||
return True
|
||||
|
||||
def stats(self) -> dict[str, int]:
|
||||
"""Return current state counts for instrumentation."""
|
||||
with self._lock:
|
||||
return {
|
||||
"known_tag_fqns": len(self._known_tag_fqns),
|
||||
"tag_label_cache": len(self._tag_label_cache),
|
||||
"pending": len(self._pending),
|
||||
"cleared_scopes": len(self._cleared_scopes),
|
||||
"live_entities": len(self._labels_by_entity),
|
||||
"live_labels": sum(len(v) for v in self._labels_by_entity.values()),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_pending_record(
|
||||
*,
|
||||
classification_name: str,
|
||||
classification_description: str,
|
||||
tag_name: str,
|
||||
tag_description: str,
|
||||
) -> OMetaTagAndClassification:
|
||||
"""Compose the sink-bound create-payload for a classification + tag."""
|
||||
return OMetaTagAndClassification(
|
||||
fqn=None,
|
||||
classification_request=CreateClassificationRequest( # pyright: ignore[reportCallIssue]
|
||||
name=EntityName(classification_name),
|
||||
description=Markdown(classification_description),
|
||||
),
|
||||
tag_request=CreateTagRequest( # pyright: ignore[reportCallIssue]
|
||||
classification=FullyQualifiedEntityName(classification_name),
|
||||
name=EntityName(tag_name),
|
||||
description=Markdown(tag_description),
|
||||
),
|
||||
)
|
||||
|
|
@ -15,7 +15,7 @@ Defines the topology for ingesting sources
|
|||
import queue
|
||||
import threading
|
||||
from functools import cache, singledispatchmethod
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar # noqa: UP035
|
||||
from typing import Annotated, Any, Dict, Generic, List, Optional, Type, TypeVar # noqa: UP035
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model
|
||||
|
||||
|
|
@ -111,14 +111,18 @@ class TopologyNode(BaseModel):
|
|||
"Each stage accepts the producer results as an argument"
|
||||
),
|
||||
)
|
||||
children: Optional[List[str]] = Field(None, description="Nodes to execute next") # noqa: UP006, UP045
|
||||
post_process: Optional[List[str]] = Field( # noqa: UP006, UP045
|
||||
None, description="Method to be run after the node has been fully processed"
|
||||
)
|
||||
threads: bool = Field(
|
||||
False,
|
||||
description="Flag that defines if a node is open to MultiThreading processing.",
|
||||
)
|
||||
children: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="Nodes to execute next"),
|
||||
] = None
|
||||
post_process: Annotated[
|
||||
list[str] | None,
|
||||
Field(description="Method to be run after the node has been fully processed"),
|
||||
] = None
|
||||
threads: Annotated[
|
||||
bool,
|
||||
Field(description="Flag that defines if a node is open to MultiThreading processing."),
|
||||
] = False
|
||||
|
||||
|
||||
class ServiceTopology(BaseModel):
|
||||
|
|
|
|||
|
|
@ -14,12 +14,13 @@ Base class for ingesting database services
|
|||
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Iterable, List, Optional, Set, Tuple # noqa: UP035
|
||||
from typing import Any, Iterable, List, Optional, Set, Tuple, cast # noqa: UP035
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.engine import Inspector
|
||||
from typing_extensions import Annotated # noqa: UP035
|
||||
|
||||
from metadata.domain.tags import TagCanonicalizer, TagRegistry
|
||||
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
|
||||
from metadata.generated.schema.api.data.createDatabaseSchema import (
|
||||
CreateDatabaseSchemaRequest,
|
||||
|
|
@ -160,6 +161,7 @@ class DatabaseServiceTopology(ServiceTopology):
|
|||
"mark_schemas_as_deleted",
|
||||
"mark_tables_as_deleted",
|
||||
"mark_stored_procedures_as_deleted",
|
||||
"clear_database_tag_scope",
|
||||
],
|
||||
threads=True,
|
||||
)
|
||||
|
|
@ -186,6 +188,7 @@ class DatabaseServiceTopology(ServiceTopology):
|
|||
nullable=True,
|
||||
),
|
||||
],
|
||||
post_process=["clear_schema_tag_scope"],
|
||||
)
|
||||
stored_procedure: Annotated[TopologyNode, Field(description="Stored Procedure Node")] = TopologyNode(
|
||||
producer="get_stored_procedures",
|
||||
|
|
@ -224,6 +227,26 @@ class DatabaseServiceSource(TopologyRunnerMixin, Source, ABC): # pylint: disabl
|
|||
topology = DatabaseServiceTopology()
|
||||
context = TopologyContextManager(topology)
|
||||
|
||||
# ``vars(self).setdefault(...)`` for thread-safe lazy init.
|
||||
# See: https://docs.python.org/3/library/threadsafety.html
|
||||
@property
|
||||
def tags_registry(self) -> TagRegistry:
|
||||
"""Per-Source registry tracking tag/classification ingestion state."""
|
||||
instance_dict = vars(self)
|
||||
cached = instance_dict.get("tags_registry")
|
||||
if cached is not None:
|
||||
return cached
|
||||
return instance_dict.setdefault("tags_registry", TagRegistry(metadata=self.metadata))
|
||||
|
||||
@property
|
||||
def tag_canonicalizer(self) -> TagCanonicalizer:
|
||||
"""Per-Source canonicalizer for case-corrected tag/classification names."""
|
||||
instance_dict = vars(self)
|
||||
cached = instance_dict.get("tag_canonicalizer")
|
||||
if cached is not None:
|
||||
return cached
|
||||
return instance_dict.setdefault("tag_canonicalizer", TagCanonicalizer(metadata=self.metadata))
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.service_connection.type.name
|
||||
|
|
@ -811,6 +834,39 @@ class DatabaseServiceSource(TopologyRunnerMixin, Source, ABC): # pylint: disabl
|
|||
Get the life cycle data of the table
|
||||
"""
|
||||
|
||||
def clear_schema_tag_scope(self):
|
||||
"""Drop tag-registry state for the current schema scope."""
|
||||
schema_name = self.context.get().database_schema # pyright: ignore[reportAttributeAccessIssue]
|
||||
if schema_name:
|
||||
schema_fqn = cast(
|
||||
"str",
|
||||
fqn.build(
|
||||
self.metadata,
|
||||
entity_type=DatabaseSchema,
|
||||
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
|
||||
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
|
||||
schema_name=schema_name,
|
||||
),
|
||||
)
|
||||
self.tags_registry.clear_scope(schema_fqn)
|
||||
yield from ()
|
||||
|
||||
def clear_database_tag_scope(self):
|
||||
"""Drop tag-registry state for the current database scope."""
|
||||
database_name = self.context.get().database # pyright: ignore[reportAttributeAccessIssue]
|
||||
if database_name:
|
||||
database_fqn = cast(
|
||||
"str",
|
||||
fqn.build(
|
||||
self.metadata,
|
||||
entity_type=Database,
|
||||
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
|
||||
database_name=database_name,
|
||||
),
|
||||
)
|
||||
self.tags_registry.clear_scope(database_fqn)
|
||||
yield from ()
|
||||
|
||||
def yield_external_table_lineage(self) -> Iterable[Either[AddLineageRequest]]:
|
||||
"""
|
||||
Process external table lineage
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ Snowflake source module
|
|||
import json # noqa: I001
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Iterable, List, Optional, Tuple # noqa: UP035
|
||||
from typing import Iterable, List, Optional, Tuple, cast # noqa: UP035
|
||||
|
||||
import sqlalchemy.types as sqltypes
|
||||
import sqlparse
|
||||
|
|
@ -37,6 +37,7 @@ from metadata.generated.schema.entity.data.storedProcedure import (
|
|||
StoredProcedureType,
|
||||
)
|
||||
from metadata.generated.schema.entity.data.table import (
|
||||
Column,
|
||||
PartitionColumnDetails,
|
||||
PartitionIntervalTypes,
|
||||
Table,
|
||||
|
|
@ -54,7 +55,6 @@ from metadata.generated.schema.metadataIngestion.workflow import (
|
|||
)
|
||||
from metadata.generated.schema.type.basic import (
|
||||
EntityName,
|
||||
FullyQualifiedEntityName,
|
||||
SourceUrl,
|
||||
)
|
||||
from metadata.generated.schema.type.entityReferenceList import EntityReferenceList
|
||||
|
|
@ -135,7 +135,6 @@ from metadata.utils.sqlalchemy_utils import (
|
|||
get_all_table_ddls,
|
||||
get_all_view_definitions,
|
||||
)
|
||||
from metadata.utils.tag_utils import get_ometa_tag_and_classification, get_tag_label
|
||||
|
||||
|
||||
class MAP(StructuredType):
|
||||
|
|
@ -548,9 +547,20 @@ class SnowflakeSource(
|
|||
logger.debug(traceback.format_exc())
|
||||
logger.error(f"Failed to fetch tags due to [{inner_exc}]")
|
||||
|
||||
schema_fqn = cast(
|
||||
"str",
|
||||
fqn.build(
|
||||
self.metadata,
|
||||
entity_type=DatabaseSchema,
|
||||
service_name=self.context.get().database_service,
|
||||
database_name=self.context.get().database,
|
||||
schema_name=schema_name,
|
||||
),
|
||||
)
|
||||
for res in result:
|
||||
row = list(res)
|
||||
fqn_elements = [name for name in row[2:] if name]
|
||||
|
||||
# row[0] = TAG_NAME, row[1] = TAG_VALUE
|
||||
if not row[1]:
|
||||
logger.warning(
|
||||
|
|
@ -558,62 +568,113 @@ class SnowflakeSource(
|
|||
"TAG_VALUE is empty. Snowflake tags require a value to be ingested."
|
||||
)
|
||||
continue
|
||||
yield from get_ometa_tag_and_classification(
|
||||
tag_fqn=FullyQualifiedEntityName(
|
||||
fqn._build( # pylint: disable=protected-access
|
||||
self.context.get().database_service, *fqn_elements
|
||||
)
|
||||
),
|
||||
tags=[row[1]],
|
||||
classification_name=row[0],
|
||||
tag_description=SNOWFLAKE_TAG_DESCRIPTION,
|
||||
classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
|
||||
metadata=self.metadata,
|
||||
system_tags=True,
|
||||
)
|
||||
|
||||
entity_fqn = fqn._build(self.context.get().database_service, *fqn_elements) # pyright: ignore[reportAttributeAccessIssue]
|
||||
try:
|
||||
classification = self.tag_canonicalizer.classification(
|
||||
row[0], default_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION
|
||||
)
|
||||
tag = self.tag_canonicalizer.tag(
|
||||
classification.name, row[1], default_tag_description=SNOWFLAKE_TAG_DESCRIPTION
|
||||
)
|
||||
|
||||
self.tags_registry.attach(
|
||||
scope_fqn=schema_fqn,
|
||||
entity_fqn=entity_fqn,
|
||||
classification_name=classification.name,
|
||||
tag_name=tag.name,
|
||||
classification_description=classification.description,
|
||||
tag_description=tag.description,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(traceback.format_exc())
|
||||
yield Either(
|
||||
left=StackTraceError(
|
||||
name=f"{row[0]}.{row[1]}",
|
||||
error=f"Tag canonicalization failed for {row[0]}.{row[1]}: {exc}",
|
||||
stackTrace=traceback.format_exc(),
|
||||
),
|
||||
right=None,
|
||||
)
|
||||
|
||||
# Yield schema-level tags
|
||||
if schema_name in self.schema_tags_map:
|
||||
schema_fqn = fqn.build(
|
||||
self.metadata,
|
||||
entity_type=DatabaseSchema,
|
||||
service_name=self.context.get().database_service,
|
||||
database_name=self.context.get().database,
|
||||
schema_name=schema_name,
|
||||
)
|
||||
for tag_info in self.schema_tags_map[schema_name]:
|
||||
yield from get_ometa_tag_and_classification(
|
||||
tag_fqn=FullyQualifiedEntityName(schema_fqn),
|
||||
tags=[tag_info["tag_value"]],
|
||||
classification_name=tag_info["tag_name"],
|
||||
tag_description=SNOWFLAKE_TAG_DESCRIPTION,
|
||||
classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
|
||||
metadata=self.metadata,
|
||||
system_tags=True,
|
||||
)
|
||||
try:
|
||||
classification = self.tag_canonicalizer.classification(
|
||||
tag_info["tag_name"], default_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION
|
||||
)
|
||||
tag = self.tag_canonicalizer.tag(
|
||||
classification.name,
|
||||
tag_info["tag_value"],
|
||||
default_tag_description=SNOWFLAKE_TAG_DESCRIPTION,
|
||||
)
|
||||
|
||||
def yield_database_tag(self, database_entity: str) -> Iterable[Either[OMetaTagAndClassification]]:
|
||||
self.tags_registry.attach(
|
||||
scope_fqn=schema_fqn,
|
||||
entity_fqn=schema_fqn,
|
||||
classification_name=classification.name,
|
||||
tag_name=tag.name,
|
||||
classification_description=classification.description,
|
||||
tag_description=tag.description,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(traceback.format_exc())
|
||||
yield Either(
|
||||
left=StackTraceError(
|
||||
name=f"{tag_info['tag_name']}.{tag_info['tag_value']}",
|
||||
error=f"Tag canonicalization failed for {tag_info['tag_name']}.{tag_info['tag_value']}: {exc}",
|
||||
stackTrace=traceback.format_exc(),
|
||||
),
|
||||
right=None,
|
||||
)
|
||||
yield from (Either(left=None, right=record) for record in self.tags_registry.drain())
|
||||
|
||||
def yield_database_tag(self, database_name: str) -> Iterable[Either[OMetaTagAndClassification]]:
|
||||
"""Yield database-level tags for the topology."""
|
||||
if not self.source_config.includeTags:
|
||||
return
|
||||
|
||||
if database_entity in self.database_tags_map:
|
||||
database_fqn = fqn.build(
|
||||
if database_name not in self.database_tags_map:
|
||||
return
|
||||
|
||||
database_fqn = cast(
|
||||
"str",
|
||||
fqn.build(
|
||||
self.metadata,
|
||||
entity_type=Database,
|
||||
service_name=self.context.get().database_service,
|
||||
database_name=database_entity,
|
||||
)
|
||||
for tag_info in self.database_tags_map[database_entity]:
|
||||
yield from get_ometa_tag_and_classification(
|
||||
tag_fqn=FullyQualifiedEntityName(database_fqn),
|
||||
tags=[tag_info["tag_value"]],
|
||||
classification_name=tag_info["tag_name"],
|
||||
tag_description=SNOWFLAKE_TAG_DESCRIPTION,
|
||||
classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
|
||||
metadata=self.metadata,
|
||||
system_tags=True,
|
||||
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
|
||||
database_name=database_name,
|
||||
),
|
||||
)
|
||||
for tag_info in self.database_tags_map[database_name]:
|
||||
try:
|
||||
classification = self.tag_canonicalizer.classification(
|
||||
tag_info["tag_name"], default_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION
|
||||
)
|
||||
tag = self.tag_canonicalizer.tag(
|
||||
classification.name, tag_info["tag_value"], default_tag_description=SNOWFLAKE_TAG_DESCRIPTION
|
||||
)
|
||||
|
||||
self.tags_registry.attach(
|
||||
scope_fqn=database_fqn,
|
||||
entity_fqn=database_fqn,
|
||||
classification_name=classification.name,
|
||||
tag_name=tag.name,
|
||||
classification_description=classification.description,
|
||||
tag_description=tag.description,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(traceback.format_exc())
|
||||
yield Either(
|
||||
left=StackTraceError(
|
||||
name=f"{tag_info['tag_name']}.{tag_info['tag_value']}",
|
||||
error=f"Tag canonicalization failed for {tag_info['tag_name']}.{tag_info['tag_value']}: {exc}",
|
||||
stackTrace=traceback.format_exc(),
|
||||
),
|
||||
right=None,
|
||||
)
|
||||
yield from (Either(left=None, right=record) for record in self.tags_registry.drain())
|
||||
|
||||
def _get_table_names_and_types(
|
||||
self, schema_name: str, table_type: TableType = TableType.Regular
|
||||
|
|
@ -1049,42 +1110,72 @@ class SnowflakeSource(
|
|||
return True
|
||||
return False
|
||||
|
||||
def get_database_tag_labels(self, database_name: str) -> Optional[List[TagLabel]]: # noqa: UP006, UP045
|
||||
"""Return tags for the database entity from registry."""
|
||||
database_fqn = cast(
|
||||
"str",
|
||||
fqn.build(
|
||||
self.metadata,
|
||||
entity_type=Database,
|
||||
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
|
||||
database_name=database_name,
|
||||
),
|
||||
)
|
||||
return self.tags_registry.labels_for(database_fqn) or None
|
||||
|
||||
def get_column_tag_labels(self, table_name: str, column: dict) -> Optional[List[TagLabel]]: # noqa: UP006, UP045
|
||||
"""Return tags for a column entity from the registry.
|
||||
|
||||
Column tags don't inherit from parent entities (table/schema/database)
|
||||
— those have separate semantic meaning at their own level. Direct
|
||||
lookup is sufficient.
|
||||
"""
|
||||
col_fqn = cast(
|
||||
"str",
|
||||
fqn.build(
|
||||
self.metadata,
|
||||
entity_type=Column,
|
||||
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
|
||||
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
|
||||
schema_name=self.context.get().database_schema, # pyright: ignore[reportAttributeAccessIssue]
|
||||
table_name=table_name,
|
||||
column_name=column["name"],
|
||||
),
|
||||
)
|
||||
return self.tags_registry.labels_for(col_fqn) or None
|
||||
|
||||
def get_schema_tag_labels(self, schema_name: str) -> Optional[List[TagLabel]]: # noqa: UP006, UP045
|
||||
"""
|
||||
Return tags for schema entity including:
|
||||
1. Snowflake schema-level tags
|
||||
2. Inherited database-level tags (only if no tag with same classification exists)
|
||||
"""
|
||||
schema_tags = []
|
||||
schema_fqn = cast(
|
||||
"str",
|
||||
fqn.build(
|
||||
self.metadata,
|
||||
entity_type=DatabaseSchema,
|
||||
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
|
||||
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
|
||||
schema_name=schema_name,
|
||||
),
|
||||
)
|
||||
database_fqn = cast(
|
||||
"str",
|
||||
fqn.build(
|
||||
self.metadata,
|
||||
entity_type=Database,
|
||||
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
|
||||
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
|
||||
),
|
||||
)
|
||||
|
||||
if schema_name in self.schema_tags_map:
|
||||
for tag_info in self.schema_tags_map[schema_name]:
|
||||
tag_label = get_tag_label(
|
||||
metadata=self.metadata,
|
||||
tag_name=tag_info["tag_value"],
|
||||
classification_name=tag_info["tag_name"],
|
||||
)
|
||||
if tag_label:
|
||||
schema_tags.append(tag_label)
|
||||
schema_tags = self.tags_registry.labels_for(schema_fqn)
|
||||
|
||||
# Add inherited database tags (only if classification doesn't already exist)
|
||||
database_name = self.context.get().database
|
||||
if database_name and database_name in self.database_tags_map:
|
||||
for tag_info in self.database_tags_map[database_name]:
|
||||
if not self._has_classification(tag_info["tag_name"], schema_tags):
|
||||
tag_label = get_tag_label(
|
||||
metadata=self.metadata,
|
||||
tag_name=tag_info["tag_value"],
|
||||
classification_name=tag_info["tag_name"],
|
||||
)
|
||||
if tag_label:
|
||||
schema_tags.append(tag_label)
|
||||
|
||||
# Include parent tags from context
|
||||
parent_tags = super().get_schema_tag_labels(schema_name) or []
|
||||
for tag in parent_tags:
|
||||
if not self._has_classification(self._get_classification_name(tag), schema_tags):
|
||||
schema_tags.append(tag)
|
||||
for label in self.tags_registry.labels_for(database_fqn):
|
||||
if not self._has_classification(self._get_classification_name(label), schema_tags):
|
||||
schema_tags.append(label)
|
||||
|
||||
return schema_tags if schema_tags else None
|
||||
|
||||
|
|
@ -1098,32 +1189,48 @@ class SnowflakeSource(
|
|||
|
||||
Tag values at lower levels take precedence over inherited values.
|
||||
"""
|
||||
table_tags = super().get_tag_labels(table_name) or []
|
||||
table_fqn = cast(
|
||||
"str",
|
||||
fqn.build(
|
||||
self.metadata,
|
||||
entity_type=Table,
|
||||
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
|
||||
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
|
||||
schema_name=self.context.get().database_schema, # pyright: ignore[reportAttributeAccessIssue]
|
||||
table_name=table_name,
|
||||
skip_es_search=True,
|
||||
),
|
||||
)
|
||||
schema_fqn = cast(
|
||||
"str",
|
||||
fqn.build(
|
||||
self.metadata,
|
||||
entity_type=DatabaseSchema,
|
||||
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
|
||||
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
|
||||
schema_name=self.context.get().database_schema, # pyright: ignore[reportAttributeAccessIssue]
|
||||
),
|
||||
)
|
||||
database_fqn = cast(
|
||||
"str",
|
||||
fqn.build(
|
||||
self.metadata,
|
||||
entity_type=Database,
|
||||
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
|
||||
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
|
||||
),
|
||||
)
|
||||
|
||||
table_tags = self.tags_registry.labels_for(table_fqn)
|
||||
|
||||
# Add inherited schema tags (only if classification doesn't already exist)
|
||||
schema_name = self.context.get().database_schema
|
||||
if schema_name and schema_name in self.schema_tags_map:
|
||||
for tag_info in self.schema_tags_map[schema_name]:
|
||||
if not self._has_classification(tag_info["tag_name"], table_tags):
|
||||
tag_label = get_tag_label(
|
||||
metadata=self.metadata,
|
||||
tag_name=tag_info["tag_value"],
|
||||
classification_name=tag_info["tag_name"],
|
||||
)
|
||||
if tag_label:
|
||||
table_tags.append(tag_label)
|
||||
for label in self.tags_registry.labels_for(schema_fqn):
|
||||
if not self._has_classification(self._get_classification_name(label), table_tags):
|
||||
table_tags.append(label)
|
||||
|
||||
# Add inherited database tags (only if classification doesn't already exist)
|
||||
database_name = self.context.get().database
|
||||
if database_name and database_name in self.database_tags_map:
|
||||
for tag_info in self.database_tags_map[database_name]:
|
||||
if not self._has_classification(tag_info["tag_name"], table_tags):
|
||||
tag_label = get_tag_label(
|
||||
metadata=self.metadata,
|
||||
tag_name=tag_info["tag_value"],
|
||||
classification_name=tag_info["tag_name"],
|
||||
)
|
||||
if tag_label:
|
||||
table_tags.append(tag_label)
|
||||
for label in self.tags_registry.labels_for(database_fqn):
|
||||
if not self._has_classification(self._get_classification_name(label), table_tags):
|
||||
table_tags.append(label)
|
||||
|
||||
return table_tags if table_tags else None
|
||||
|
|
|
|||
0
ingestion/tests/unit/domain/__init__.py
Normal file
0
ingestion/tests/unit/domain/__init__.py
Normal file
0
ingestion/tests/unit/domain/tags/__init__.py
Normal file
0
ingestion/tests/unit/domain/tags/__init__.py
Normal file
161
ingestion/tests/unit/domain/tags/test_canonicalizer.py
Normal file
161
ingestion/tests/unit/domain/tags/test_canonicalizer.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
# Copyright 2025 Collate
|
||||
# Licensed under the Collate Community License, Version 1.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for ``metadata.domain.tags.TagCanonicalizer``."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from metadata.domain.tags import Canonical, TagCanonicalizer
|
||||
from metadata.generated.schema.entity.classification.classification import Classification
|
||||
from metadata.generated.schema.type.basic import ProviderType
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _no_retry_sleep(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Skip tenacity's between-retry sleeps so retry-tests run instantly."""
|
||||
monkeypatch.setattr("time.sleep", lambda *_args, **_kwargs: None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_metadata() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def canonicalizer(mock_metadata: MagicMock) -> TagCanonicalizer:
|
||||
return TagCanonicalizer(metadata=mock_metadata)
|
||||
|
||||
|
||||
def _system_classification(name: str, description: str = "") -> MagicMock:
|
||||
m = MagicMock()
|
||||
m.provider = ProviderType.system
|
||||
m.name.root = name
|
||||
if description:
|
||||
m.description.root = description
|
||||
else:
|
||||
m.description = None
|
||||
return m
|
||||
|
||||
|
||||
def _system_tag(classification: str, name: str, description: str = "") -> MagicMock:
|
||||
m = MagicMock()
|
||||
m.provider = ProviderType.system
|
||||
m.classification.name = classification
|
||||
m.name.root = name
|
||||
if description:
|
||||
m.description.root = description
|
||||
else:
|
||||
m.description = None
|
||||
return m
|
||||
|
||||
|
||||
class TestClassification:
|
||||
def test_no_match_returns_source_unchanged(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
|
||||
mock_metadata.es_search_from_fqn.return_value = []
|
||||
result = canonicalizer.classification("MyClass", "Source desc")
|
||||
assert result == Canonical(name="MyClass", description="Source desc")
|
||||
|
||||
def test_system_match_uses_canonical_case(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
|
||||
mock_metadata.es_search_from_fqn.return_value = [_system_classification("PII", "Canonical desc")]
|
||||
result = canonicalizer.classification("pii", "Source desc")
|
||||
assert result == Canonical(name="PII", description="Canonical desc")
|
||||
|
||||
def test_caches_per_case_insensitive_key(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
|
||||
mock_metadata.es_search_from_fqn.return_value = [_system_classification("PII", "Canonical desc")]
|
||||
canonicalizer.classification("pii", "Source desc")
|
||||
canonicalizer.classification("PII", "Source desc")
|
||||
canonicalizer.classification("Pii", "Source desc")
|
||||
# Three case variants share the same case-insensitive cache key
|
||||
assert mock_metadata.es_search_from_fqn.call_count == 1
|
||||
|
||||
def test_non_system_match_ignored(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
|
||||
non_system = _system_classification("PII", "Canonical desc")
|
||||
non_system.provider = ProviderType.user
|
||||
mock_metadata.es_search_from_fqn.return_value = [non_system]
|
||||
result = canonicalizer.classification("pii", "Source desc")
|
||||
assert result == Canonical(name="pii", description="Source desc")
|
||||
|
||||
def test_classification_es_called_with_correct_args(
|
||||
self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
|
||||
):
|
||||
mock_metadata.es_search_from_fqn.return_value = []
|
||||
canonicalizer.classification("Foo", "Source desc")
|
||||
mock_metadata.es_search_from_fqn.assert_called_once_with(entity_type=Classification, fqn_search_string="Foo")
|
||||
|
||||
|
||||
class TestTag:
|
||||
def test_no_match_returns_source_unchanged(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
|
||||
mock_metadata.es_search_from_fqn.return_value = []
|
||||
result = canonicalizer.tag("PII", "MyTag", "Source desc")
|
||||
assert result == Canonical(name="MyTag", description="Source desc")
|
||||
|
||||
def test_system_match_uses_canonical_case(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
|
||||
mock_metadata.es_search_from_fqn.return_value = [_system_tag("PII", "Sensitive", "Canonical desc")]
|
||||
result = canonicalizer.tag("PII", "sensitive", "Source desc")
|
||||
assert result == Canonical(name="Sensitive", description="Canonical desc")
|
||||
|
||||
def test_caches_per_case_insensitive_key(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
|
||||
mock_metadata.es_search_from_fqn.return_value = [_system_tag("PII", "Sensitive", "")]
|
||||
canonicalizer.tag("PII", "sensitive", "Source desc")
|
||||
canonicalizer.tag("PII", "SENSITIVE", "Source desc")
|
||||
canonicalizer.tag("PII", "Sensitive", "Source desc")
|
||||
# Three case variants share the same case-insensitive cache key
|
||||
assert mock_metadata.es_search_from_fqn.call_count == 1
|
||||
|
||||
def test_match_requires_classification_match(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
|
||||
# ES returns a tag but for a different classification — no canonicalization
|
||||
wrong_class_tag = _system_tag("OtherClass", "Sensitive", "Canonical desc")
|
||||
mock_metadata.es_search_from_fqn.return_value = [wrong_class_tag]
|
||||
result = canonicalizer.tag("PII", "sensitive", "Source desc")
|
||||
assert result == Canonical(name="sensitive", description="Source desc")
|
||||
|
||||
def test_non_system_match_ignored(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
|
||||
non_system = _system_tag("PII", "Sensitive", "Canonical desc")
|
||||
non_system.provider = ProviderType.user
|
||||
mock_metadata.es_search_from_fqn.return_value = [non_system]
|
||||
result = canonicalizer.tag("PII", "sensitive", "Source desc")
|
||||
assert result == Canonical(name="sensitive", description="Source desc")
|
||||
|
||||
|
||||
class TestRetryAndFailure:
|
||||
def test_transient_failure_recovers_within_retry_budget(
|
||||
self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
|
||||
):
|
||||
# First two ES calls raise; third succeeds.
|
||||
mock_metadata.es_search_from_fqn.side_effect = [
|
||||
RuntimeError("transient 1"),
|
||||
RuntimeError("transient 2"),
|
||||
[_system_classification("PII", "Canonical desc")],
|
||||
]
|
||||
result = canonicalizer.classification("pii", "Source desc")
|
||||
assert result == Canonical(name="PII", description="Canonical desc")
|
||||
assert mock_metadata.es_search_from_fqn.call_count == 3
|
||||
|
||||
def test_persistent_failure_raises_after_retries_exhaust(
|
||||
self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
|
||||
):
|
||||
mock_metadata.es_search_from_fqn.side_effect = RuntimeError("persistent")
|
||||
with pytest.raises(RuntimeError, match="persistent"):
|
||||
canonicalizer.classification("MyClass", "Source desc")
|
||||
assert mock_metadata.es_search_from_fqn.call_count == 5
|
||||
|
||||
def test_persistent_failure_does_not_poison_cache(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
|
||||
# First call: ES persistently fails -> raises.
|
||||
mock_metadata.es_search_from_fqn.side_effect = RuntimeError("persistent")
|
||||
with pytest.raises(RuntimeError):
|
||||
canonicalizer.classification("MyClass", "Source desc")
|
||||
|
||||
# ES recovers; subsequent call must reach ES again, not return a cached fallback.
|
||||
mock_metadata.es_search_from_fqn.side_effect = None
|
||||
mock_metadata.es_search_from_fqn.return_value = [_system_classification("MyClass", "Canonical desc")]
|
||||
result = canonicalizer.classification("MyClass", "Source desc")
|
||||
assert result == Canonical(name="MyClass", description="Canonical desc")
|
||||
375
ingestion/tests/unit/domain/tags/test_registry.py
Normal file
375
ingestion/tests/unit/domain/tags/test_registry.py
Normal file
|
|
@ -0,0 +1,375 @@
|
|||
# Copyright 2025 Collate
|
||||
# Licensed under the Collate Community License, Version 1.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for ``metadata.domain.tags.TagRegistry``.
|
||||
|
||||
Covers attach/labels_for/drain/clear_scope/ensure_known semantics plus
|
||||
basic thread-safety stress scenarios. The OM client is mocked; no
|
||||
network or schema validation against a real backend.
|
||||
"""
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from metadata.domain.tags import ScopeAlreadyClearedError, TagRegistry
|
||||
from metadata.generated.schema.type.tagLabel import LabelType, State
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_metadata() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def registry(mock_metadata: MagicMock) -> TagRegistry:
|
||||
return TagRegistry(metadata=mock_metadata)
|
||||
|
||||
|
||||
def _attach_kwargs(
|
||||
scope: str,
|
||||
entity: str,
|
||||
classification: str = "TestClass",
|
||||
tag: str = "TestTag",
|
||||
) -> dict:
|
||||
return {
|
||||
"scope_fqn": scope,
|
||||
"entity_fqn": entity,
|
||||
"classification_name": classification,
|
||||
"tag_name": tag,
|
||||
"classification_description": "test classification",
|
||||
"tag_description": "test tag",
|
||||
}
|
||||
|
||||
|
||||
class TestAttachAndLabelsFor:
|
||||
def test_attach_then_labels_for_returns_one_label(self, registry: TagRegistry):
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table"))
|
||||
labels = registry.labels_for("svc.db.schema.table")
|
||||
assert len(labels) == 1
|
||||
|
||||
def test_attach_multiple_tags_same_entity_returns_all(self, registry: TagRegistry):
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table", tag="Tag1"))
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table", tag="Tag2"))
|
||||
labels = registry.labels_for("svc.db.schema.table")
|
||||
assert len(labels) == 2
|
||||
|
||||
def test_labels_for_unattached_entity_returns_empty_list(self, registry: TagRegistry):
|
||||
assert registry.labels_for("svc.db.schema.unknown") == []
|
||||
|
||||
def test_labels_for_is_idempotent(self, registry: TagRegistry):
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table"))
|
||||
first = registry.labels_for("svc.db.schema.table")
|
||||
second = registry.labels_for("svc.db.schema.table")
|
||||
# Read-and-leave: both reads return the same labels.
|
||||
# Cleanup is the responsibility of clear_scope, not labels_for.
|
||||
assert len(first) == 1
|
||||
assert second == first
|
||||
|
||||
def test_labels_for_returns_copy_not_internal_list(self, registry: TagRegistry):
|
||||
# Mutating the returned list must not affect registry state.
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table"))
|
||||
first = registry.labels_for("svc.db.schema.table")
|
||||
first.clear()
|
||||
second = registry.labels_for("svc.db.schema.table")
|
||||
assert len(second) == 1
|
||||
|
||||
|
||||
class TestDrain:
|
||||
def test_drain_yields_pending_then_clears(self, registry: TagRegistry):
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_a"))
|
||||
first = list(registry.drain())
|
||||
second = list(registry.drain())
|
||||
assert len(first) == 1
|
||||
assert second == []
|
||||
|
||||
def test_drain_dedupes_same_tag_across_entities(self, registry: TagRegistry):
|
||||
for i in range(100):
|
||||
registry.attach(**_attach_kwargs("svc.db", f"svc.db.schema.tbl_{i}"))
|
||||
pending = list(registry.drain())
|
||||
assert len(pending) == 1
|
||||
|
||||
def test_drain_yields_distinct_payloads_for_distinct_tags(self, registry: TagRegistry):
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_1", tag="TagA"))
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_2", tag="TagB"))
|
||||
pending = list(registry.drain())
|
||||
assert len(pending) == 2
|
||||
|
||||
def test_drain_does_not_dedup_across_case_variants(self, registry: TagRegistry):
|
||||
# OM stores tags case-sensitively; our dedup must follow that rule.
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.t1", tag="Sensitive"))
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.t2", tag="sensitive"))
|
||||
pending = list(registry.drain())
|
||||
assert len(pending) == 2 # both must PUT — they're distinct tags server-side
|
||||
|
||||
def test_drain_dedupes_same_fqn_across_label_types(self, registry: TagRegistry):
|
||||
# Different cache keys (label_type varies) but identical tag_fqn → ONE PUT.
|
||||
# Cache key is (class, tag, label_type, state); tag_fqn is class.tag.
|
||||
registry.attach(
|
||||
**_attach_kwargs("svc.db", "svc.db.t1"),
|
||||
label_type=LabelType.Manual,
|
||||
)
|
||||
registry.attach(
|
||||
**_attach_kwargs("svc.db", "svc.db.t2"),
|
||||
label_type=LabelType.Automated,
|
||||
)
|
||||
pending = list(registry.drain())
|
||||
assert len(pending) == 1, "fqn-level dedup must collapse PUTs across label_type variants"
|
||||
|
||||
|
||||
class TestClearScope:
|
||||
def test_clear_scope_drops_descendant_labels(self, registry: TagRegistry):
|
||||
registry.attach(**_attach_kwargs("svc.db.schema", "svc.db.schema.tbl_1"))
|
||||
registry.attach(**_attach_kwargs("svc.db.schema", "svc.db.schema.tbl_2"))
|
||||
registry.clear_scope("svc.db.schema")
|
||||
assert registry.labels_for("svc.db.schema.tbl_1") == []
|
||||
assert registry.labels_for("svc.db.schema.tbl_2") == []
|
||||
|
||||
def test_clear_scope_drops_scope_itself(self, registry: TagRegistry):
|
||||
registry.attach(**_attach_kwargs("svc.db.schema", "svc.db.schema"))
|
||||
registry.clear_scope("svc.db.schema")
|
||||
assert registry.labels_for("svc.db.schema") == []
|
||||
|
||||
def test_clear_scope_preserves_other_scopes(self, registry: TagRegistry):
|
||||
registry.attach(**_attach_kwargs("svc.db.schema_a", "svc.db.schema_a.tbl"))
|
||||
registry.attach(**_attach_kwargs("svc.db.schema_b", "svc.db.schema_b.tbl"))
|
||||
registry.clear_scope("svc.db.schema_a")
|
||||
assert registry.labels_for("svc.db.schema_a.tbl") == []
|
||||
assert len(registry.labels_for("svc.db.schema_b.tbl")) == 1
|
||||
|
||||
def test_clear_scope_no_false_prefix_match(self, registry: TagRegistry):
|
||||
# 'schema_a' is NOT a prefix of 'schema_alpha' once the FQN
|
||||
# separator is taken into account.
|
||||
registry.attach(**_attach_kwargs("svc.db.schema_alpha", "svc.db.schema_alpha.tbl"))
|
||||
registry.clear_scope("svc.db.schema_a")
|
||||
assert len(registry.labels_for("svc.db.schema_alpha.tbl")) == 1
|
||||
|
||||
def test_clear_scope_idempotent_on_unattached_scope(self, registry: TagRegistry):
|
||||
registry.clear_scope("svc.db.never_attached") # must not raise
|
||||
|
||||
def test_attach_after_clear_raises(self, registry: TagRegistry):
|
||||
registry.clear_scope("svc.db.schema")
|
||||
with pytest.raises(ScopeAlreadyClearedError):
|
||||
registry.attach(**_attach_kwargs("svc.db.schema", "svc.db.schema.tbl"))
|
||||
|
||||
|
||||
class TestEnsureKnown:
|
||||
def test_is_known_empty_returns_false(self, registry: TagRegistry):
|
||||
assert registry.is_known("Class.Tag") is False
|
||||
|
||||
def test_is_known_after_attach_returns_true(self, registry: TagRegistry):
|
||||
registry.attach(
|
||||
**_attach_kwargs(
|
||||
"svc.db",
|
||||
"svc.db.schema.tbl",
|
||||
classification="Class",
|
||||
tag="Tag",
|
||||
)
|
||||
)
|
||||
assert registry.is_known("Class.Tag") is True
|
||||
|
||||
def test_is_known_is_case_sensitive(self, registry: TagRegistry):
|
||||
# Reflects OM's case-sensitive identity rule.
|
||||
registry.attach(
|
||||
**_attach_kwargs(
|
||||
"svc.db",
|
||||
"svc.db.schema.tbl",
|
||||
classification="Class",
|
||||
tag="Tag",
|
||||
)
|
||||
)
|
||||
assert registry.is_known("Class.Tag") is True
|
||||
assert registry.is_known("class.tag") is False # different tag server-side
|
||||
|
||||
def test_ensure_known_cache_hit_skips_io(self, registry: TagRegistry, mock_metadata: MagicMock):
|
||||
registry.attach(
|
||||
**_attach_kwargs(
|
||||
"svc.db",
|
||||
"svc.db.schema.tbl",
|
||||
classification="Class",
|
||||
tag="Tag",
|
||||
)
|
||||
)
|
||||
assert registry.ensure_known("Class.Tag") is True
|
||||
mock_metadata.get_by_name.assert_not_called()
|
||||
|
||||
def test_ensure_known_cache_miss_calls_get_by_name_once(self, registry: TagRegistry, mock_metadata: MagicMock):
|
||||
mock_metadata.get_by_name.return_value = MagicMock()
|
||||
assert registry.ensure_known("Other.Tag") is True
|
||||
assert registry.ensure_known("Other.Tag") is True # cached now
|
||||
assert mock_metadata.get_by_name.call_count == 1
|
||||
|
||||
def test_ensure_known_404_returns_false_and_does_not_cache(self, registry: TagRegistry, mock_metadata: MagicMock):
|
||||
mock_metadata.get_by_name.return_value = None
|
||||
assert registry.ensure_known("Missing.Tag") is False
|
||||
assert registry.ensure_known("Missing.Tag") is False
|
||||
# Re-queries on each miss; not cached.
|
||||
assert mock_metadata.get_by_name.call_count == 2
|
||||
|
||||
def test_ensure_known_swallows_exception(self, registry: TagRegistry, mock_metadata: MagicMock):
|
||||
mock_metadata.get_by_name.side_effect = RuntimeError("network down")
|
||||
assert registry.ensure_known("Crashed.Tag") is False
|
||||
|
||||
|
||||
class TestThreadSafety:
|
||||
def test_concurrent_attach_same_tag_dedupes_pending(self, registry: TagRegistry):
|
||||
def worker(thread_idx: int) -> None:
|
||||
for i in range(100):
|
||||
registry.attach(
|
||||
**_attach_kwargs(
|
||||
"svc.db",
|
||||
f"svc.db.schema.tbl_{thread_idx}_{i}",
|
||||
)
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=8) as pool:
|
||||
list(pool.map(worker, range(8)))
|
||||
|
||||
pending = list(registry.drain())
|
||||
assert len(pending) == 1
|
||||
|
||||
def test_concurrent_disjoint_scopes_no_label_loss(self, registry: TagRegistry):
|
||||
def worker(scope_idx: int) -> None:
|
||||
scope = f"svc.db.schema_{scope_idx}"
|
||||
for i in range(50):
|
||||
registry.attach(
|
||||
**_attach_kwargs(
|
||||
scope,
|
||||
f"{scope}.tbl_{i}",
|
||||
tag=f"Tag_{scope_idx}_{i}",
|
||||
)
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=8) as pool:
|
||||
list(pool.map(worker, range(8)))
|
||||
|
||||
for scope_idx in range(8):
|
||||
scope = f"svc.db.schema_{scope_idx}"
|
||||
for i in range(50):
|
||||
entity = f"{scope}.tbl_{i}"
|
||||
labels = registry.labels_for(entity)
|
||||
assert len(labels) == 1, f"missing label for {entity}"
|
||||
|
||||
|
||||
class TestStats:
|
||||
def test_initial_stats_all_zero(self, registry: TagRegistry):
|
||||
assert registry.stats() == {
|
||||
"known_tag_fqns": 0,
|
||||
"tag_label_cache": 0,
|
||||
"pending": 0,
|
||||
"cleared_scopes": 0,
|
||||
"live_entities": 0,
|
||||
"live_labels": 0,
|
||||
}
|
||||
|
||||
def test_stats_reflect_attach(self, registry: TagRegistry):
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_1"))
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_2"))
|
||||
s = registry.stats()
|
||||
# Both attaches share the same tag — known + pending dedup to 1
|
||||
assert s["known_tag_fqns"] == 1
|
||||
assert s["pending"] == 1
|
||||
# Two entities, each with one label
|
||||
assert s["live_entities"] == 2
|
||||
assert s["live_labels"] == 2
|
||||
|
||||
def test_labels_for_does_not_decrease_live_state(self, registry: TagRegistry):
|
||||
# labels_for is idempotent (read-and-leave); clear_scope is the
|
||||
# only mechanism that reduces live state.
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl"))
|
||||
registry.labels_for("svc.db.schema.tbl")
|
||||
s = registry.stats()
|
||||
assert s["live_entities"] == 1
|
||||
assert s["live_labels"] == 1
|
||||
assert s["known_tag_fqns"] == 1
|
||||
assert s["pending"] == 1
|
||||
|
||||
def test_drain_decreases_pending_only(self, registry: TagRegistry):
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl"))
|
||||
list(registry.drain())
|
||||
s = registry.stats()
|
||||
assert s["pending"] == 0
|
||||
assert s["known_tag_fqns"] == 1 # still tracked for dedup
|
||||
|
||||
def test_clear_scope_zeroes_live_state_for_scope(self, registry: TagRegistry):
|
||||
# Critical invariant: after clear_scope, no live_entities for that scope.
|
||||
for i in range(50):
|
||||
registry.attach(**_attach_kwargs("svc.db.schema", f"svc.db.schema.tbl_{i}"))
|
||||
assert registry.stats()["live_entities"] == 50
|
||||
|
||||
registry.clear_scope("svc.db.schema")
|
||||
s = registry.stats()
|
||||
assert s["live_entities"] == 0
|
||||
assert s["live_labels"] == 0
|
||||
assert s["cleared_scopes"] == 1
|
||||
|
||||
|
||||
class TestInterning:
|
||||
"""TagLabel interning — multiple attaches with the same key share one
|
||||
underlying ``TagLabel`` instance. Memory bound depends on this; the
|
||||
`is`-identity assertion is the load-bearing check."""
|
||||
|
||||
def test_attach_interns_identical_tag_labels(self, registry: TagRegistry):
|
||||
# Same (classification, tag, label_type, state) across two entities
|
||||
# must return the exact same TagLabel object — not just an equal one.
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_1"))
|
||||
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_2"))
|
||||
|
||||
label_1 = registry.labels_for("svc.db.schema.tbl_1")[0]
|
||||
label_2 = registry.labels_for("svc.db.schema.tbl_2")[0]
|
||||
|
||||
assert label_1 is label_2, "expected shared TagLabel instance via interning"
|
||||
|
||||
def test_attach_does_not_intern_across_label_types(self, registry: TagRegistry):
|
||||
# Cache key includes label_type — non-default values must not collide.
|
||||
registry.attach(
|
||||
**_attach_kwargs("svc.db", "svc.db.schema.tbl_1"),
|
||||
label_type=LabelType.Manual,
|
||||
)
|
||||
registry.attach(
|
||||
**_attach_kwargs("svc.db", "svc.db.schema.tbl_2"),
|
||||
label_type=LabelType.Automated,
|
||||
)
|
||||
|
||||
label_manual = registry.labels_for("svc.db.schema.tbl_1")[0]
|
||||
label_auto = registry.labels_for("svc.db.schema.tbl_2")[0]
|
||||
|
||||
assert label_manual is not label_auto
|
||||
assert label_manual.labelType == LabelType.Manual
|
||||
assert label_auto.labelType == LabelType.Automated
|
||||
|
||||
def test_attach_does_not_intern_across_states(self, registry: TagRegistry):
|
||||
registry.attach(
|
||||
**_attach_kwargs("svc.db", "svc.db.schema.tbl_1"),
|
||||
state=State.Suggested,
|
||||
)
|
||||
registry.attach(
|
||||
**_attach_kwargs("svc.db", "svc.db.schema.tbl_2"),
|
||||
state=State.Confirmed,
|
||||
)
|
||||
|
||||
label_suggested = registry.labels_for("svc.db.schema.tbl_1")[0]
|
||||
label_confirmed = registry.labels_for("svc.db.schema.tbl_2")[0]
|
||||
|
||||
assert label_suggested is not label_confirmed
|
||||
|
||||
def test_intern_cache_survives_clear_scope(self, registry: TagRegistry):
|
||||
# Cache lifetime is registry lifetime, NOT scope lifetime — next scope
|
||||
# reuses the same TagLabel instance for the same (class, tag, ...) key.
|
||||
registry.attach(**_attach_kwargs("svc.db.schema_1", "svc.db.schema_1.tbl"))
|
||||
label_first = registry.labels_for("svc.db.schema_1.tbl")[0]
|
||||
|
||||
registry.clear_scope("svc.db.schema_1")
|
||||
|
||||
registry.attach(**_attach_kwargs("svc.db.schema_2", "svc.db.schema_2.tbl"))
|
||||
label_second = registry.labels_for("svc.db.schema_2.tbl")[0]
|
||||
|
||||
assert label_first is label_second, "intern cache should survive clear_scope"
|
||||
|
|
@ -19,7 +19,9 @@ from unittest.mock import MagicMock, Mock, PropertyMock, patch
|
|||
|
||||
import sqlalchemy.types as sqltypes
|
||||
|
||||
from metadata.generated.schema.entity.data.table import TableType
|
||||
from metadata.generated.schema.entity.data.database import Database
|
||||
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
|
||||
from metadata.generated.schema.entity.data.table import Table, TableType
|
||||
from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipeline import (
|
||||
PipelineStatus,
|
||||
)
|
||||
|
|
@ -27,14 +29,9 @@ from metadata.generated.schema.metadataIngestion.workflow import (
|
|||
OpenMetadataWorkflowConfig,
|
||||
)
|
||||
from metadata.generated.schema.type.filterPattern import FilterPattern
|
||||
from metadata.generated.schema.type.tagLabel import (
|
||||
LabelType,
|
||||
State,
|
||||
TagLabel,
|
||||
TagSource,
|
||||
)
|
||||
from metadata.ingestion.source.database.snowflake.metadata import MAP, SnowflakeSource
|
||||
from metadata.ingestion.source.database.snowflake.models import SnowflakeStoredProcedure
|
||||
from metadata.utils import fqn
|
||||
|
||||
SNOWFLAKE_CONFIGURATION = {
|
||||
"source": {
|
||||
|
|
@ -491,18 +488,39 @@ class SnowflakeUnitTest(TestCase):
|
|||
self.assertEqual(map_type.value_type, sqltypes.VARCHAR) # default
|
||||
self.assertFalse(map_type.not_null) # default
|
||||
|
||||
@patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_tag_labels")
|
||||
@patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_schema_tag_labels")
|
||||
@patch("metadata.ingestion.source.database.snowflake.metadata.get_tag_label")
|
||||
def test_schema_tag_inheritance(
|
||||
self,
|
||||
mock_get_tag_label,
|
||||
mock_parent_get_schema_tag_labels,
|
||||
mock_parent_get_tag_labels,
|
||||
):
|
||||
"""Test schema tag inheritance"""
|
||||
def _setup_tag_context(self, source, service_name="local_snowflake"):
|
||||
"""Populate the topology context for schema-stage tag tests and return the FQN trio."""
|
||||
source.context.get().__dict__["database_service"] = service_name
|
||||
source.context.get().__dict__["database"] = "TEST_DATABASE"
|
||||
source.context.get().__dict__["database_schema"] = "TEST_SCHEMA"
|
||||
|
||||
database_fqn = fqn.build(
|
||||
source.metadata,
|
||||
entity_type=Database,
|
||||
service_name=service_name,
|
||||
database_name="TEST_DATABASE",
|
||||
)
|
||||
schema_fqn = fqn.build(
|
||||
source.metadata,
|
||||
entity_type=DatabaseSchema,
|
||||
service_name=service_name,
|
||||
database_name="TEST_DATABASE",
|
||||
schema_name="TEST_SCHEMA",
|
||||
)
|
||||
table_fqn = fqn.build(
|
||||
source.metadata,
|
||||
entity_type=Table,
|
||||
service_name=service_name,
|
||||
database_name="TEST_DATABASE",
|
||||
schema_name="TEST_SCHEMA",
|
||||
table_name="TEST_TABLE",
|
||||
skip_es_search=True,
|
||||
)
|
||||
return database_fqn, schema_fqn, table_fqn
|
||||
|
||||
def test_schema_tag_inheritance(self):
|
||||
"""Schema tags propagate to tables; classification dedup is preserved."""
|
||||
for source in self.sources.values():
|
||||
# Verify tags are fetched and stored
|
||||
mock_schema_tags = [
|
||||
Mock(SCHEMA_NAME="TEST_SCHEMA", TAG_NAME="SCHEMA_TAG", TAG_VALUE="VALUE"),
|
||||
]
|
||||
|
|
@ -519,48 +537,39 @@ class SnowflakeUnitTest(TestCase):
|
|||
{"tag_name": "SCHEMA_TAG", "tag_value": "VALUE"},
|
||||
)
|
||||
|
||||
# Verify schema tag labels
|
||||
mock_get_tag_label.return_value = TagLabel(
|
||||
tagFQN="SnowflakeTag.SCHEMA_TAG",
|
||||
labelType=LabelType.Automated,
|
||||
state=State.Suggested,
|
||||
source=TagSource.Classification,
|
||||
_, schema_fqn, table_fqn = self._setup_tag_context(source)
|
||||
|
||||
source.tags_registry.attach(
|
||||
scope_fqn=schema_fqn,
|
||||
entity_fqn=schema_fqn,
|
||||
classification_name="SCHEMA_CLASSIFICATION",
|
||||
tag_name="SCHEMA_TAG",
|
||||
classification_description="",
|
||||
tag_description="",
|
||||
)
|
||||
source.tags_registry.attach(
|
||||
scope_fqn=schema_fqn,
|
||||
entity_fqn=table_fqn,
|
||||
classification_name="TABLE_CLASSIFICATION",
|
||||
tag_name="TABLE_TAG",
|
||||
classification_description="",
|
||||
tag_description="",
|
||||
)
|
||||
mock_parent_get_schema_tag_labels.return_value = None
|
||||
|
||||
schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA")
|
||||
self.assertIsNotNone(schema_labels)
|
||||
self.assertEqual(len(schema_labels), 1)
|
||||
|
||||
# Verify tag inheritance
|
||||
source.context.get().__dict__["database_schema"] = "TEST_SCHEMA"
|
||||
mock_parent_get_tag_labels.return_value = [
|
||||
TagLabel(
|
||||
tagFQN="SnowflakeTag.TABLE_TAG",
|
||||
labelType=LabelType.Automated,
|
||||
state=State.Suggested,
|
||||
source=TagSource.Classification,
|
||||
)
|
||||
]
|
||||
self.assertEqual(schema_labels[0].tagFQN.root, "SCHEMA_CLASSIFICATION.SCHEMA_TAG")
|
||||
|
||||
table_labels = source.get_tag_labels(table_name="TEST_TABLE")
|
||||
self.assertEqual(len(table_labels), 2)
|
||||
tag_fqns = [tag.tagFQN.root for tag in table_labels]
|
||||
self.assertIn("SnowflakeTag.SCHEMA_TAG", tag_fqns)
|
||||
self.assertIn("SnowflakeTag.TABLE_TAG", tag_fqns)
|
||||
self.assertIn("SCHEMA_CLASSIFICATION.SCHEMA_TAG", tag_fqns)
|
||||
self.assertIn("TABLE_CLASSIFICATION.TABLE_TAG", tag_fqns)
|
||||
|
||||
@patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_tag_labels")
|
||||
@patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_schema_tag_labels")
|
||||
@patch("metadata.ingestion.source.database.snowflake.metadata.get_tag_label")
|
||||
def test_database_tag_inheritance(
|
||||
self,
|
||||
mock_get_tag_label,
|
||||
mock_parent_get_schema_tag_labels,
|
||||
mock_parent_get_tag_labels,
|
||||
):
|
||||
"""Test database tag inheritance to schemas and tables"""
|
||||
def test_database_tag_inheritance(self):
|
||||
"""Database tags propagate to schemas and tables when classifications don't overlap."""
|
||||
for source in self.sources.values():
|
||||
# Setup mock database tags
|
||||
mock_database_tags = [
|
||||
Mock(
|
||||
DATABASE_NAME="TEST_DATABASE",
|
||||
|
|
@ -574,7 +583,6 @@ class SnowflakeUnitTest(TestCase):
|
|||
source.engine.connect.return_value.__enter__ = MagicMock(return_value=mock_conn)
|
||||
source.engine.connect.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# Test set_database_tags_map
|
||||
source.set_database_tags_map("TEST_DATABASE")
|
||||
self.assertEqual(len(source.database_tags_map["TEST_DATABASE"]), 1)
|
||||
self.assertEqual(
|
||||
|
|
@ -582,23 +590,33 @@ class SnowflakeUnitTest(TestCase):
|
|||
{"tag_name": "DATABASE_TAG", "tag_value": "DB_VALUE"},
|
||||
)
|
||||
|
||||
# Setup schema tags for combined testing
|
||||
source.schema_tags_map = {"TEST_SCHEMA": [{"tag_name": "SCHEMA_TAG", "tag_value": "SCHEMA_VALUE"}]}
|
||||
database_fqn, schema_fqn, table_fqn = self._setup_tag_context(source)
|
||||
|
||||
# Mock tag label creation
|
||||
def mock_tag_label_side_effect(metadata, tag_name, classification_name):
|
||||
return TagLabel(
|
||||
tagFQN=f"{classification_name}.{tag_name}",
|
||||
labelType=LabelType.Automated,
|
||||
state=State.Suggested,
|
||||
source=TagSource.Classification,
|
||||
)
|
||||
source.tags_registry.attach(
|
||||
scope_fqn=database_fqn,
|
||||
entity_fqn=database_fqn,
|
||||
classification_name="DATABASE_TAG",
|
||||
tag_name="DB_VALUE",
|
||||
classification_description="",
|
||||
tag_description="",
|
||||
)
|
||||
source.tags_registry.attach(
|
||||
scope_fqn=schema_fqn,
|
||||
entity_fqn=schema_fqn,
|
||||
classification_name="SCHEMA_TAG",
|
||||
tag_name="SCHEMA_VALUE",
|
||||
classification_description="",
|
||||
tag_description="",
|
||||
)
|
||||
source.tags_registry.attach(
|
||||
scope_fqn=schema_fqn,
|
||||
entity_fqn=table_fqn,
|
||||
classification_name="TABLE_TAG",
|
||||
tag_name="TABLE_VALUE",
|
||||
classification_description="",
|
||||
tag_description="",
|
||||
)
|
||||
|
||||
mock_get_tag_label.side_effect = mock_tag_label_side_effect
|
||||
mock_parent_get_schema_tag_labels.return_value = None
|
||||
|
||||
# Test schema inherits database tags
|
||||
source.context.get().__dict__["database"] = "TEST_DATABASE"
|
||||
schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA")
|
||||
self.assertIsNotNone(schema_labels)
|
||||
self.assertEqual(len(schema_labels), 2)
|
||||
|
|
@ -606,17 +624,6 @@ class SnowflakeUnitTest(TestCase):
|
|||
self.assertIn("SCHEMA_TAG.SCHEMA_VALUE", tag_fqns)
|
||||
self.assertIn("DATABASE_TAG.DB_VALUE", tag_fqns)
|
||||
|
||||
# Test table inherits both schema and database tags
|
||||
source.context.get().__dict__["database_schema"] = "TEST_SCHEMA"
|
||||
mock_parent_get_tag_labels.return_value = [
|
||||
TagLabel(
|
||||
tagFQN="TABLE_TAG.TABLE_VALUE",
|
||||
labelType=LabelType.Automated,
|
||||
state=State.Suggested,
|
||||
source=TagSource.Classification,
|
||||
)
|
||||
]
|
||||
|
||||
table_labels = source.get_tag_labels(table_name="TEST_TABLE")
|
||||
self.assertEqual(len(table_labels), 3)
|
||||
tag_fqns = [tag.tagFQN.root for tag in table_labels]
|
||||
|
|
@ -624,59 +631,44 @@ class SnowflakeUnitTest(TestCase):
|
|||
self.assertIn("SCHEMA_TAG.SCHEMA_VALUE", tag_fqns)
|
||||
self.assertIn("DATABASE_TAG.DB_VALUE", tag_fqns)
|
||||
|
||||
@patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_tag_labels")
|
||||
@patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_schema_tag_labels")
|
||||
@patch("metadata.ingestion.source.database.snowflake.metadata.get_tag_label")
|
||||
def test_tag_value_precedence(
|
||||
self,
|
||||
mock_get_tag_label,
|
||||
mock_parent_get_schema_tag_labels,
|
||||
mock_parent_get_tag_labels,
|
||||
):
|
||||
"""Test that tag values at lower levels take precedence over inherited values.
|
||||
def test_tag_value_precedence(self):
|
||||
"""Lower-level tags override inherited values for the same classification.
|
||||
|
||||
When database, schema, and table all have the same tag name (classification)
|
||||
but different values, the object's own value should take precedence.
|
||||
Database: ENV=dev, Schema: ENV=staging, Table: ENV=production.
|
||||
Schema lookup must return only ENV.staging; table lookup only ENV.production.
|
||||
"""
|
||||
for source in self.sources.values():
|
||||
# Setup: Database, schema, and table all have ENV tag with different values
|
||||
# Database: ENV=dev
|
||||
# Schema: ENV=staging
|
||||
# Table: ENV=production
|
||||
database_fqn, schema_fqn, table_fqn = self._setup_tag_context(source)
|
||||
|
||||
source.database_tags_map = {"TEST_DATABASE": [{"tag_name": "ENV", "tag_value": "dev"}]}
|
||||
source.tags_registry.attach(
|
||||
scope_fqn=database_fqn,
|
||||
entity_fqn=database_fqn,
|
||||
classification_name="ENV",
|
||||
tag_name="dev",
|
||||
classification_description="",
|
||||
tag_description="",
|
||||
)
|
||||
source.tags_registry.attach(
|
||||
scope_fqn=schema_fqn,
|
||||
entity_fqn=schema_fqn,
|
||||
classification_name="ENV",
|
||||
tag_name="staging",
|
||||
classification_description="",
|
||||
tag_description="",
|
||||
)
|
||||
source.tags_registry.attach(
|
||||
scope_fqn=schema_fqn,
|
||||
entity_fqn=table_fqn,
|
||||
classification_name="ENV",
|
||||
tag_name="production",
|
||||
classification_description="env classification",
|
||||
tag_description="production tag",
|
||||
)
|
||||
|
||||
source.schema_tags_map = {"TEST_SCHEMA": [{"tag_name": "ENV", "tag_value": "staging"}]}
|
||||
|
||||
def mock_tag_label_side_effect(metadata, tag_name, classification_name):
|
||||
return TagLabel(
|
||||
tagFQN=f"{classification_name}.{tag_name}",
|
||||
labelType=LabelType.Automated,
|
||||
state=State.Suggested,
|
||||
source=TagSource.Classification,
|
||||
)
|
||||
|
||||
mock_get_tag_label.side_effect = mock_tag_label_side_effect
|
||||
mock_parent_get_schema_tag_labels.return_value = None
|
||||
|
||||
source.context.get().__dict__["database"] = "TEST_DATABASE"
|
||||
source.context.get().__dict__["database_schema"] = "TEST_SCHEMA"
|
||||
|
||||
# Test schema level: schema's own value takes precedence over database
|
||||
schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA")
|
||||
self.assertEqual(len(schema_labels), 1)
|
||||
self.assertEqual(schema_labels[0].tagFQN.root, "ENV.staging")
|
||||
|
||||
# Test table level: table's own value takes precedence over schema and database
|
||||
mock_parent_get_tag_labels.return_value = [
|
||||
TagLabel(
|
||||
tagFQN="ENV.production",
|
||||
labelType=LabelType.Automated,
|
||||
state=State.Suggested,
|
||||
source=TagSource.Classification,
|
||||
)
|
||||
]
|
||||
|
||||
table_labels = source.get_tag_labels(table_name="TEST_TABLE")
|
||||
self.assertEqual(len(table_labels), 1)
|
||||
self.assertEqual(table_labels[0].tagFQN.root, "ENV.production")
|
||||
|
|
|
|||
Loading…
Reference in a new issue