feat(ingestion): introduce TagRegistry domain layer (#27991)

* feat(ingestion): introduce TagRegistry domain layer

Adds metadata.domain.tags with TagRegistry (per-Source bookkeeping) and
TagCanonicalizer (case-corrected name resolution against OM). Migrates
the Snowflake connector to the new architecture; other connectors stay
on the legacy context.tags flow (strangler pattern).

TagRegistry interns shared TagLabel instances by (classification, tag,
label_type, state) and rebinds the per-entity dict on scope clear. On a
schema with ~120k tag attachments and 21 unique tags, peak heap drops
from ~112 MB to ~24 MB.

Public get_*_tag_labels methods on the database service base are
unchanged; non-Snowflake DB connectors are not touched.

* fix(ingestion): tighten TagRegistry contract and address PR/CI feedback

- Drop None defaults from `TagRegistry.attach()` description params
  (required `str | None`); normalize None -> "" inside
  `_build_pending_record` so the OM schema's required-non-null Markdown
  contract is owned at the registry boundary, not at every caller.
- Fix `yield_database_tag` missing registry drain (PR review).
- Fix `clear_scope` lock-order race that left labels visible after the
  scope was marked cleared (PR review).
- Resolve basedpyright errors:
  - `Either` and `TopologyNode` use Form 3 Pydantic v2 fields
    (`Annotated[X, Field(...)] = default`) so static checkers see the
    defaults.
  - `cast("str", fqn.build(...))` at the 13 sites that feed
    `fqn.build(...)` results into FQN-typed args.
  - Scoped `# pyright: ignore[reportAttributeAccessIssue]` on
    `TopologyContext` dynamic-attribute accesses (matches the codebase
    pattern of grandfathering 8.4k+ such errors via baseline).
- Populate `stackTrace` on the three snowflake tag-error
  `StackTraceError`s so the Status UI shows the trace, not just the
  one-line summary.
- Rewrite three snowflake tag-inheritance tests to drive the real
  registry attach + inheritance walk after `get_tag_label` was removed:
  - `test_schema_tag_inheritance`
  - `test_database_tag_inheritance`
  - `test_tag_value_precedence` (one attach intentionally passes
    `None` descriptions to exercise the registry's normalization)

* fix(ingestion): revert Either to Form 2 to avoid Generic-T inference cascade

Reverting `Either.left`/`Either.right` to the original
`Annotated[Optional[T], Field(default=None, ...)]` form. The Form 3
shape (`Annotated[T | None, Field(...)] = None`) introduced in the
prior commit caused pyright to eagerly bind `T` from the literal-None
default at every no-arg construction site like `Either(left=...)`,
resolving them to `Either[Unknown]`. Because `Either[T]` is invariant,
those failed to satisfy declared generator return types like
`Iterable[Either[CreateTableRequest]]` — surfacing 45 latent
reportReturnType errors across sample_data, dbt, sas, qliksense, sigma,
common_db_source, common_broker_source, amundsen, sink/metadata_rest.

Form 2 wraps the default inside `Annotated` metadata where pyright
treats it as opaque: it sees "this field has a default" (so construction
sites pass) but doesn't eagerly bind `T` from None. Context-driven
inference works, no cascade.

`TopologyNode` stays in Form 3 — its fields are concrete-typed
(`list[str] | None`, `bool`), no Generic-T to bind.

* fix(ingestion): close attach race, thread-safe lazy registries, satisfy CI pyright

Three related fixes from PR review + CI:

1. **TagRegistry.attach race fully closed (Copilot review).** The cleared-scope
   check and the labels-mutation now happen under the same ``_scope_state_lock``,
   so a concurrent ``clear_scope`` cannot interleave between the check and the
   write. Moved ``_cleared_scopes`` ownership from ``_run_state_lock`` to
   ``_scope_state_lock`` accordingly. ``clear_scope`` now adds to
   ``_cleared_scopes`` and rebuilds ``_labels_by_entity`` atomically under the
   single lock.

2. **Thread-safe lazy initialization of ``tags_registry`` and
   ``tag_canonicalizer`` (Copilot review).** ``functools.cached_property`` is
   not thread-safe in Python 3.12+ — under parallel ``databaseSchema`` workers,
   two threads could each construct their own registry and the second would
   overwrite the first in ``__dict__``, losing tag-attaches that the orphaned
   instance already received. Switched to ``@property`` + ``vars(self).setdefault``,
   which is documented atomic under the GIL (Python 3.14 Thread Safety Guarantees)
   and survives PEP 703 free-threaded mode via per-dict locks. Verified
   thread-safe against 30 racing threads in an ad-hoc check (1 distinct
   instance observed, all threads converge).

3. **Explicit ``right=None`` / ``left=None`` on Either error/success yields.**
   Pyright (CI strict mode) treats Either's ``Annotated[Optional[T], Field(default=None, ...)]``
   defaults as invisible, flagging every ``Either(left=...)`` and
   ``Either(right=...)`` as missing the counterpart. Passing the other side
   explicitly satisfies the call contract without changing runtime behavior.

* test(ingestion): patch time.sleep instead of tenacity.nap.time.sleep

The previous patch target ``tenacity.nap.time.sleep`` reaches into
tenacity's internal module structure — fragile under any tenacity
refactor that renames or relocates ``nap.py``. Patching ``time.sleep``
directly is equally effective (tenacity does ``import time`` and calls
``time.sleep`` dynamically, so attribute-level patching intercepts it)
and depends only on stdlib, which is stable.

Addresses Copilot review on test_canonicalizer.py:25.

* refactor(ingestion): single lock in TagRegistry, intern after cleared-check

Two refinements from a fresh round of Copilot review:

1. **Merge ``_run_state_lock`` and ``_scope_state_lock`` into a single
   ``_lock``.** Two locks were a smell: every ``attach()`` already needed
   both, the lock-acquisition order varied across methods (``attach`` does
   one order, ``stats`` the reverse — fragile if any future code held both
   simultaneously), and ``RLock`` was used defensively without any actual
   re-entry. One lock = single mental model, no ordering invariant, deadlock
   impossible.

2. **Move ``_intern_tag_label`` invocation inside the cleared-scope
   gate** (Copilot review). With one lock, the cleared-check, intern,
   label-append, and pending-update all happen in one atomic critical
   section — so when a scope is already cleared, no TagLabel is interned
   into ``_tag_label_cache``. Verified deterministically: 30 threads
   attaching to a pre-cleared scope all raise and the cache stays empty.

3. Reword the inline ``vars(self).setdefault`` thread-safety comment in
   ``database_service.py`` to drop the misleading Python 3.14 reference
   (project targets 3.10+) — just point at the official
   ``threadsafety.html`` doc.

* refactor(ingestion): tighten description contract; rename canonicalizer params

End-to-end tightening to prevent the empty-description / overwrite path
flagged in PR review. Three coordinated changes:

1. **``Canonical.description`` is now ``str`` (required)** instead of
   ``str | None``. The canonicalizer always seeds with the caller-provided
   default and only overrides with a non-empty server value, so the
   resolved description is invariably a real string. Removing the
   Optional makes that invariant visible at the type level.

2. **Rename canonicalizer parameters to make their fallback semantics
   honest**:
   - ``classification(name, description)`` →
     ``classification(name, default_description)``
   - ``tag(classification_name, tag_name, tag_description)`` →
     ``tag(classification_name, tag_name, default_tag_description)``

   The parameter is used to seed the Canonical when no system match
   exists in OM, and as a fallback when an OM match has an empty
   description — never as the value to set on an existing entity. The
   rename signals "if I have to invent this, here's what to write down"
   rather than "set the description to this." Snowflake call sites use
   the new kwargs explicitly for self-documenting reads.

3. **Tighten ``TagRegistry.attach`` and ``_build_pending_record`` to
   required ``str`` for both descriptions** (was ``str | None``). With
   ``Canonical.description`` now ``str``, every caller can pass through
   without an ``or ""`` shim. ``_build_pending_record`` drops the
   defensive ``Markdown(x or "")`` for clean ``Markdown(x)``.

Net effect: no possible code path sends ``None`` or relies on an
``or ""`` fallback that would clobber a server-side description. The
schema's required-Markdown contract is enforced at every layer above
the wire.
This commit is contained in:
IceS2 2026-05-11 17:56:31 +02:00 committed by GitHub
parent 8ad28268e8
commit c4d9a86804
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 1346 additions and 227 deletions

View file

@ -165,6 +165,7 @@ base_requirements = {
"sqlalchemy>=2.0.0,<3",
"collate-sqllineage>=2.1.1",
"tabulate==0.9.0",
"tenacity>=8.0,<10",
"typing-inspect",
"packaging", # For version parsing
"setuptools>=78.1.1,<81", # <81 required: pkg_resources removed in setuptools 81+

View file

@ -0,0 +1,22 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OpenMetadata domain utilities.
In-memory helpers operating on OpenMetadata's data model, reusable across
service-source bases and features. A module belongs here when it satisfies
ALL of:
1. Knows OM concepts (operates on OM-generated types or OM-specific ideas).
2. Owns no I/O infrastructure. May use an INJECTED OM client for read-only
queries; the client's lifecycle is the caller's.
3. Framework-independent no topology, stages, or sinks.
4. Cross-cutting used by more than one service-source base or feature.
"""

View file

@ -0,0 +1,21 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tag and Classification domain utilities."""
from metadata.domain.tags.canonicalizer import Canonical, TagCanonicalizer
from metadata.domain.tags.registry import ScopeAlreadyClearedError, TagRegistry
__all__ = [
"Canonical",
"ScopeAlreadyClearedError",
"TagCanonicalizer",
"TagRegistry",
]

View file

@ -0,0 +1,145 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TagCanonicalizer — case-corrected name resolution against OpenMetadata.
Resolves source-system Classification and Tag names to the canonical form
of any matching system-provider entity in OM (e.g., source reports
``pii.sensitive`` returns ``PII.Sensitive``). Persistent ES failures
raise after retry exhaustion.
"""
import logging
import threading
from collections.abc import Iterable
from typing import Any, NamedTuple, cast
from tenacity import (
before_sleep_log,
retry,
stop_after_attempt,
wait_random_exponential,
)
from metadata.generated.schema.entity.classification.classification import Classification
from metadata.generated.schema.entity.classification.tag import Tag
from metadata.generated.schema.type.basic import ProviderType
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.utils import fqn
from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
_es_retry = retry(
stop=stop_after_attempt(5),
wait=wait_random_exponential(multiplier=2, max=30),
reraise=True,
before_sleep=before_sleep_log(logger, logging.WARNING),
)
class Canonical(NamedTuple):
"""Canonical (name, description) pair returned from OpenMetadata."""
name: str
description: str
class TagCanonicalizer:
"""Case-corrected name resolution for system Classifications and Tags.
Persistent ES failures raise; callers should wrap in ``Either`` to
surface them to workflow status.
"""
def __init__(self, metadata: OpenMetadata) -> None:
self._metadata = metadata
self._classification_cache: dict[str, Canonical] = {}
self._tag_cache: dict[str, Canonical] = {}
self._lock = threading.RLock()
def classification(
self,
name: str,
default_description: str,
) -> Canonical:
"""Return canonical classification name + description from OM, cached.
``default_description`` is used to seed the Canonical when no
system-provider match exists in OM, and as a fallback when an
OM match has an empty description. An OM-side description wins
over the default whenever available.
"""
key = name.lower()
with self._lock:
cached = self._classification_cache.get(key)
if cached is not None:
return cached
results = self._es_search(Classification, name)
canonical = Canonical(name=name, description=default_description)
for entity in results:
if entity.provider == ProviderType.system and entity.name.root.lower() == key:
canonical = Canonical(
name=entity.name.root,
description=entity.description.root if entity.description else default_description,
)
break
with self._lock:
self._classification_cache.setdefault(key, canonical)
return canonical
def tag(
self,
classification_name: str,
tag_name: str,
default_tag_description: str,
) -> Canonical:
"""Return canonical tag name + description from OM, cached.
``classification_name`` must already be canonical (call ``classification`` first).
``default_tag_description`` is used to seed the Canonical when no
system-provider match exists in OM, and as a fallback when an
OM match has an empty description.
"""
tag_fqn = cast(
"str",
fqn.build(None, Tag, classification_name=classification_name, tag_name=tag_name),
)
key = tag_fqn.lower()
with self._lock:
cached = self._tag_cache.get(key)
if cached is not None:
return cached
results = self._es_search(Tag, tag_fqn)
canonical = Canonical(name=tag_name, description=default_tag_description)
for entity in results:
if (
entity.provider == ProviderType.system
and entity.classification.name == classification_name
and entity.name.root.lower() == tag_name.lower()
):
canonical = Canonical(
name=entity.name.root,
description=entity.description.root if entity.description else default_tag_description,
)
break
with self._lock:
self._tag_cache.setdefault(key, canonical)
return canonical
@_es_retry
def _es_search(self, entity_type: Any, search_string: str) -> Iterable[Any]:
"""Run an ES search by FQN with retries."""
return self._metadata.es_search_from_fqn(entity_type=entity_type, fqn_search_string=search_string) or []

View file

@ -0,0 +1,235 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TagRegistry — per-Source bookkeeping for Tag and Classification ingestion.
Holds two concerns:
* a queue of classification/tag create-payloads bound for the sink
(deduped by FQN, drained per scope), and
* a per-entity-FQN lookup of ``TagLabel`` instances for inheritance
reads, dropped at scope boundaries.
Dedup is case-sensitive, matching OpenMetadata's tag-identity rule.
Safe for concurrent use across the topology's parallel schema workers.
"""
import threading
from collections.abc import Iterable
from typing import NamedTuple, cast
from metadata.generated.schema.api.classification.createClassification import (
CreateClassificationRequest,
)
from metadata.generated.schema.api.classification.createTag import CreateTagRequest
from metadata.generated.schema.entity.classification.tag import Tag
from metadata.generated.schema.type.basic import (
EntityName,
FullyQualifiedEntityName,
Markdown,
)
from metadata.generated.schema.type.tagLabel import (
LabelType,
State,
TagFQN,
TagLabel,
TagSource,
)
from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.ingestion.ometa.utils import model_str
from metadata.utils import fqn
from metadata.utils.logger import ingestion_logger
logger = ingestion_logger()
class _TagLabelKey(NamedTuple):
"""Identity tuple for the TagLabel cache."""
classification_name: str
tag_name: str
label_type: LabelType
state: State
class ScopeAlreadyClearedError(RuntimeError):
"""Raised when 'attach' is called for a previously cleared scope.
Surfaces topology lifecycle bug loudly rather than silently re-creating a cleared scope.
"""
class TagRegistry:
"""Registry for Tag and Classification ingestion bookkeeping."""
def __init__(self, metadata: OpenMetadata) -> None:
self._metadata = metadata
self._known_tag_fqns: set[str] = set()
self._tag_label_cache: dict[_TagLabelKey, TagLabel] = {}
self._pending: list[OMetaTagAndClassification] = []
self._cleared_scopes: set[str] = set()
self._labels_by_entity: dict[str, list[TagLabel]] = {}
self._lock = threading.Lock()
def _intern_tag_label_locked(
self, *, classification_name: str, tag_name: str, label_type: LabelType, state: State
) -> TagLabel:
"""Return the shared ``TagLabel`` for the given key. Caller must hold ``self._lock``."""
key = _TagLabelKey(classification_name, tag_name, label_type, state)
cached = self._tag_label_cache.get(key)
if cached is not None:
return cached
tag_fqn = cast("str", fqn.build(None, Tag, classification_name=classification_name, tag_name=tag_name))
cached = TagLabel( # pyright: ignore[reportCallIssue]
tagFQN=TagFQN(tag_fqn),
labelType=label_type,
state=state,
source=TagSource.Classification,
)
self._tag_label_cache[key] = cached
return cached
def attach(
self,
*,
scope_fqn: str,
entity_fqn: str,
classification_name: str,
tag_name: str,
classification_description: str,
tag_description: str,
label_type: LabelType = LabelType.Automated,
state: State = State.Suggested,
) -> None:
"""Register a tag <-> entity association."""
if not tag_name or not tag_name.strip():
logger.debug("TagRegistry: skipping empty tag for classification %s", classification_name)
return
with self._lock:
if scope_fqn in self._cleared_scopes:
raise ScopeAlreadyClearedError(
f"Tag attach called for cleared scope '{scope_fqn!r}' for entity '{entity_fqn!r}'"
)
tag_label = self._intern_tag_label_locked(
classification_name=classification_name,
tag_name=tag_name,
label_type=label_type,
state=state,
)
self._labels_by_entity.setdefault(entity_fqn, []).append(tag_label)
tag_fqn = model_str(tag_label.tagFQN)
if tag_fqn not in self._known_tag_fqns:
self._known_tag_fqns.add(tag_fqn)
self._pending.append(
self._build_pending_record(
classification_name=classification_name,
classification_description=classification_description,
tag_name=tag_name,
tag_description=tag_description,
)
)
def labels_for(self, entity_fqn: str) -> list[TagLabel]:
"""Return tag labels attached to ``entity_fqn`` (idempotent; returns a copy)."""
with self._lock:
return list(self._labels_by_entity.get(entity_fqn, []))
def drain(self) -> Iterable[OMetaTagAndClassification]:
"""Yield all queued create payloads and clear the queue."""
with self._lock:
pending, self._pending = self._pending, []
if pending:
logger.debug("TagRegistry: drained %d pending tag payloads.", len(pending))
yield from pending
def clear_scope(self, scope_fqn: str) -> None:
"""Drop labels under ``scope_fqn`` and mark the scope cleared.
Subsequent ``attach`` calls for this scope will raise.
"""
prefix = scope_fqn + fqn.FQN_SEPARATOR
with self._lock:
self._cleared_scopes.add(scope_fqn)
kept = {k: v for k, v in self._labels_by_entity.items() if k != scope_fqn and not k.startswith(prefix)}
dropped = len(self._labels_by_entity) - len(kept)
self._labels_by_entity = kept
if dropped:
logger.debug("TagRegistry: cleared scope %s (%d entity labels dropped)", scope_fqn, dropped)
def is_known(self, tag_fqn: str) -> bool:
"""Return True if the tag FQN has been recorded (case-sensitive match)."""
with self._lock:
return tag_fqn in self._known_tag_fqns
def ensure_known(self, tag_fqn: str) -> bool:
"""Return True if the tag exists server-side, caching positive results.
Returns False (and does NOT cache) on 404 or transport error.
"""
if self.is_known(tag_fqn):
return True
logger.debug("TagRegistry: cache miss for %s; fetching from OpenMetadata.", tag_fqn)
try:
entity = self._metadata.get_by_name(entity=Tag, fqn=tag_fqn)
except Exception:
logger.exception("TagRegistry: tag lookup failed for %s.", tag_fqn)
return False
if entity is None:
logger.warning(
"TagRegistry: tag %s not found in OpenMetadata; labels referencing it will be skipped.", tag_fqn
)
return False
with self._lock:
self._known_tag_fqns.add(tag_fqn)
return True
def stats(self) -> dict[str, int]:
"""Return current state counts for instrumentation."""
with self._lock:
return {
"known_tag_fqns": len(self._known_tag_fqns),
"tag_label_cache": len(self._tag_label_cache),
"pending": len(self._pending),
"cleared_scopes": len(self._cleared_scopes),
"live_entities": len(self._labels_by_entity),
"live_labels": sum(len(v) for v in self._labels_by_entity.values()),
}
@staticmethod
def _build_pending_record(
*,
classification_name: str,
classification_description: str,
tag_name: str,
tag_description: str,
) -> OMetaTagAndClassification:
"""Compose the sink-bound create-payload for a classification + tag."""
return OMetaTagAndClassification(
fqn=None,
classification_request=CreateClassificationRequest( # pyright: ignore[reportCallIssue]
name=EntityName(classification_name),
description=Markdown(classification_description),
),
tag_request=CreateTagRequest( # pyright: ignore[reportCallIssue]
classification=FullyQualifiedEntityName(classification_name),
name=EntityName(tag_name),
description=Markdown(tag_description),
),
)

View file

@ -15,7 +15,7 @@ Defines the topology for ingesting sources
import queue
import threading
from functools import cache, singledispatchmethod
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar # noqa: UP035
from typing import Annotated, Any, Dict, Generic, List, Optional, Type, TypeVar # noqa: UP035
from pydantic import BaseModel, ConfigDict, Field, create_model
@ -111,14 +111,18 @@ class TopologyNode(BaseModel):
"Each stage accepts the producer results as an argument"
),
)
children: Optional[List[str]] = Field(None, description="Nodes to execute next") # noqa: UP006, UP045
post_process: Optional[List[str]] = Field( # noqa: UP006, UP045
None, description="Method to be run after the node has been fully processed"
)
threads: bool = Field(
False,
description="Flag that defines if a node is open to MultiThreading processing.",
)
children: Annotated[
list[str] | None,
Field(description="Nodes to execute next"),
] = None
post_process: Annotated[
list[str] | None,
Field(description="Method to be run after the node has been fully processed"),
] = None
threads: Annotated[
bool,
Field(description="Flag that defines if a node is open to MultiThreading processing."),
] = False
class ServiceTopology(BaseModel):

View file

@ -14,12 +14,13 @@ Base class for ingesting database services
import traceback
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional, Set, Tuple # noqa: UP035
from typing import Any, Iterable, List, Optional, Set, Tuple, cast # noqa: UP035
from pydantic import BaseModel, Field
from sqlalchemy.engine import Inspector
from typing_extensions import Annotated # noqa: UP035
from metadata.domain.tags import TagCanonicalizer, TagRegistry
from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest
from metadata.generated.schema.api.data.createDatabaseSchema import (
CreateDatabaseSchemaRequest,
@ -160,6 +161,7 @@ class DatabaseServiceTopology(ServiceTopology):
"mark_schemas_as_deleted",
"mark_tables_as_deleted",
"mark_stored_procedures_as_deleted",
"clear_database_tag_scope",
],
threads=True,
)
@ -186,6 +188,7 @@ class DatabaseServiceTopology(ServiceTopology):
nullable=True,
),
],
post_process=["clear_schema_tag_scope"],
)
stored_procedure: Annotated[TopologyNode, Field(description="Stored Procedure Node")] = TopologyNode(
producer="get_stored_procedures",
@ -224,6 +227,26 @@ class DatabaseServiceSource(TopologyRunnerMixin, Source, ABC): # pylint: disabl
topology = DatabaseServiceTopology()
context = TopologyContextManager(topology)
# ``vars(self).setdefault(...)`` for thread-safe lazy init.
# See: https://docs.python.org/3/library/threadsafety.html
@property
def tags_registry(self) -> TagRegistry:
"""Per-Source registry tracking tag/classification ingestion state."""
instance_dict = vars(self)
cached = instance_dict.get("tags_registry")
if cached is not None:
return cached
return instance_dict.setdefault("tags_registry", TagRegistry(metadata=self.metadata))
@property
def tag_canonicalizer(self) -> TagCanonicalizer:
"""Per-Source canonicalizer for case-corrected tag/classification names."""
instance_dict = vars(self)
cached = instance_dict.get("tag_canonicalizer")
if cached is not None:
return cached
return instance_dict.setdefault("tag_canonicalizer", TagCanonicalizer(metadata=self.metadata))
@property
def name(self) -> str:
return self.service_connection.type.name
@ -811,6 +834,39 @@ class DatabaseServiceSource(TopologyRunnerMixin, Source, ABC): # pylint: disabl
Get the life cycle data of the table
"""
def clear_schema_tag_scope(self):
"""Drop tag-registry state for the current schema scope."""
schema_name = self.context.get().database_schema # pyright: ignore[reportAttributeAccessIssue]
if schema_name:
schema_fqn = cast(
"str",
fqn.build(
self.metadata,
entity_type=DatabaseSchema,
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
schema_name=schema_name,
),
)
self.tags_registry.clear_scope(schema_fqn)
yield from ()
def clear_database_tag_scope(self):
"""Drop tag-registry state for the current database scope."""
database_name = self.context.get().database # pyright: ignore[reportAttributeAccessIssue]
if database_name:
database_fqn = cast(
"str",
fqn.build(
self.metadata,
entity_type=Database,
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
database_name=database_name,
),
)
self.tags_registry.clear_scope(database_fqn)
yield from ()
def yield_external_table_lineage(self) -> Iterable[Either[AddLineageRequest]]:
"""
Process external table lineage

View file

@ -15,7 +15,7 @@ Snowflake source module
import json # noqa: I001
import traceback
from datetime import datetime
from typing import Iterable, List, Optional, Tuple # noqa: UP035
from typing import Iterable, List, Optional, Tuple, cast # noqa: UP035
import sqlalchemy.types as sqltypes
import sqlparse
@ -37,6 +37,7 @@ from metadata.generated.schema.entity.data.storedProcedure import (
StoredProcedureType,
)
from metadata.generated.schema.entity.data.table import (
Column,
PartitionColumnDetails,
PartitionIntervalTypes,
Table,
@ -54,7 +55,6 @@ from metadata.generated.schema.metadataIngestion.workflow import (
)
from metadata.generated.schema.type.basic import (
EntityName,
FullyQualifiedEntityName,
SourceUrl,
)
from metadata.generated.schema.type.entityReferenceList import EntityReferenceList
@ -135,7 +135,6 @@ from metadata.utils.sqlalchemy_utils import (
get_all_table_ddls,
get_all_view_definitions,
)
from metadata.utils.tag_utils import get_ometa_tag_and_classification, get_tag_label
class MAP(StructuredType):
@ -548,9 +547,20 @@ class SnowflakeSource(
logger.debug(traceback.format_exc())
logger.error(f"Failed to fetch tags due to [{inner_exc}]")
schema_fqn = cast(
"str",
fqn.build(
self.metadata,
entity_type=DatabaseSchema,
service_name=self.context.get().database_service,
database_name=self.context.get().database,
schema_name=schema_name,
),
)
for res in result:
row = list(res)
fqn_elements = [name for name in row[2:] if name]
# row[0] = TAG_NAME, row[1] = TAG_VALUE
if not row[1]:
logger.warning(
@ -558,62 +568,113 @@ class SnowflakeSource(
"TAG_VALUE is empty. Snowflake tags require a value to be ingested."
)
continue
yield from get_ometa_tag_and_classification(
tag_fqn=FullyQualifiedEntityName(
fqn._build( # pylint: disable=protected-access
self.context.get().database_service, *fqn_elements
)
),
tags=[row[1]],
classification_name=row[0],
tag_description=SNOWFLAKE_TAG_DESCRIPTION,
classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
metadata=self.metadata,
system_tags=True,
)
entity_fqn = fqn._build(self.context.get().database_service, *fqn_elements) # pyright: ignore[reportAttributeAccessIssue]
try:
classification = self.tag_canonicalizer.classification(
row[0], default_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION
)
tag = self.tag_canonicalizer.tag(
classification.name, row[1], default_tag_description=SNOWFLAKE_TAG_DESCRIPTION
)
self.tags_registry.attach(
scope_fqn=schema_fqn,
entity_fqn=entity_fqn,
classification_name=classification.name,
tag_name=tag.name,
classification_description=classification.description,
tag_description=tag.description,
)
except Exception as exc:
logger.debug(traceback.format_exc())
yield Either(
left=StackTraceError(
name=f"{row[0]}.{row[1]}",
error=f"Tag canonicalization failed for {row[0]}.{row[1]}: {exc}",
stackTrace=traceback.format_exc(),
),
right=None,
)
# Yield schema-level tags
if schema_name in self.schema_tags_map:
schema_fqn = fqn.build(
self.metadata,
entity_type=DatabaseSchema,
service_name=self.context.get().database_service,
database_name=self.context.get().database,
schema_name=schema_name,
)
for tag_info in self.schema_tags_map[schema_name]:
yield from get_ometa_tag_and_classification(
tag_fqn=FullyQualifiedEntityName(schema_fqn),
tags=[tag_info["tag_value"]],
classification_name=tag_info["tag_name"],
tag_description=SNOWFLAKE_TAG_DESCRIPTION,
classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
metadata=self.metadata,
system_tags=True,
)
try:
classification = self.tag_canonicalizer.classification(
tag_info["tag_name"], default_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION
)
tag = self.tag_canonicalizer.tag(
classification.name,
tag_info["tag_value"],
default_tag_description=SNOWFLAKE_TAG_DESCRIPTION,
)
def yield_database_tag(self, database_entity: str) -> Iterable[Either[OMetaTagAndClassification]]:
self.tags_registry.attach(
scope_fqn=schema_fqn,
entity_fqn=schema_fqn,
classification_name=classification.name,
tag_name=tag.name,
classification_description=classification.description,
tag_description=tag.description,
)
except Exception as exc:
logger.debug(traceback.format_exc())
yield Either(
left=StackTraceError(
name=f"{tag_info['tag_name']}.{tag_info['tag_value']}",
error=f"Tag canonicalization failed for {tag_info['tag_name']}.{tag_info['tag_value']}: {exc}",
stackTrace=traceback.format_exc(),
),
right=None,
)
yield from (Either(left=None, right=record) for record in self.tags_registry.drain())
def yield_database_tag(self, database_name: str) -> Iterable[Either[OMetaTagAndClassification]]:
"""Yield database-level tags for the topology."""
if not self.source_config.includeTags:
return
if database_entity in self.database_tags_map:
database_fqn = fqn.build(
if database_name not in self.database_tags_map:
return
database_fqn = cast(
"str",
fqn.build(
self.metadata,
entity_type=Database,
service_name=self.context.get().database_service,
database_name=database_entity,
)
for tag_info in self.database_tags_map[database_entity]:
yield from get_ometa_tag_and_classification(
tag_fqn=FullyQualifiedEntityName(database_fqn),
tags=[tag_info["tag_value"]],
classification_name=tag_info["tag_name"],
tag_description=SNOWFLAKE_TAG_DESCRIPTION,
classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION,
metadata=self.metadata,
system_tags=True,
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
database_name=database_name,
),
)
for tag_info in self.database_tags_map[database_name]:
try:
classification = self.tag_canonicalizer.classification(
tag_info["tag_name"], default_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION
)
tag = self.tag_canonicalizer.tag(
classification.name, tag_info["tag_value"], default_tag_description=SNOWFLAKE_TAG_DESCRIPTION
)
self.tags_registry.attach(
scope_fqn=database_fqn,
entity_fqn=database_fqn,
classification_name=classification.name,
tag_name=tag.name,
classification_description=classification.description,
tag_description=tag.description,
)
except Exception as exc:
logger.debug(traceback.format_exc())
yield Either(
left=StackTraceError(
name=f"{tag_info['tag_name']}.{tag_info['tag_value']}",
error=f"Tag canonicalization failed for {tag_info['tag_name']}.{tag_info['tag_value']}: {exc}",
stackTrace=traceback.format_exc(),
),
right=None,
)
yield from (Either(left=None, right=record) for record in self.tags_registry.drain())
def _get_table_names_and_types(
self, schema_name: str, table_type: TableType = TableType.Regular
@ -1049,42 +1110,72 @@ class SnowflakeSource(
return True
return False
def get_database_tag_labels(self, database_name: str) -> Optional[List[TagLabel]]: # noqa: UP006, UP045
"""Return tags for the database entity from registry."""
database_fqn = cast(
"str",
fqn.build(
self.metadata,
entity_type=Database,
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
database_name=database_name,
),
)
return self.tags_registry.labels_for(database_fqn) or None
def get_column_tag_labels(self, table_name: str, column: dict) -> Optional[List[TagLabel]]: # noqa: UP006, UP045
"""Return tags for a column entity from the registry.
Column tags don't inherit from parent entities (table/schema/database)
those have separate semantic meaning at their own level. Direct
lookup is sufficient.
"""
col_fqn = cast(
"str",
fqn.build(
self.metadata,
entity_type=Column,
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
schema_name=self.context.get().database_schema, # pyright: ignore[reportAttributeAccessIssue]
table_name=table_name,
column_name=column["name"],
),
)
return self.tags_registry.labels_for(col_fqn) or None
def get_schema_tag_labels(self, schema_name: str) -> Optional[List[TagLabel]]: # noqa: UP006, UP045
"""
Return tags for schema entity including:
1. Snowflake schema-level tags
2. Inherited database-level tags (only if no tag with same classification exists)
"""
schema_tags = []
schema_fqn = cast(
"str",
fqn.build(
self.metadata,
entity_type=DatabaseSchema,
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
schema_name=schema_name,
),
)
database_fqn = cast(
"str",
fqn.build(
self.metadata,
entity_type=Database,
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
),
)
if schema_name in self.schema_tags_map:
for tag_info in self.schema_tags_map[schema_name]:
tag_label = get_tag_label(
metadata=self.metadata,
tag_name=tag_info["tag_value"],
classification_name=tag_info["tag_name"],
)
if tag_label:
schema_tags.append(tag_label)
schema_tags = self.tags_registry.labels_for(schema_fqn)
# Add inherited database tags (only if classification doesn't already exist)
database_name = self.context.get().database
if database_name and database_name in self.database_tags_map:
for tag_info in self.database_tags_map[database_name]:
if not self._has_classification(tag_info["tag_name"], schema_tags):
tag_label = get_tag_label(
metadata=self.metadata,
tag_name=tag_info["tag_value"],
classification_name=tag_info["tag_name"],
)
if tag_label:
schema_tags.append(tag_label)
# Include parent tags from context
parent_tags = super().get_schema_tag_labels(schema_name) or []
for tag in parent_tags:
if not self._has_classification(self._get_classification_name(tag), schema_tags):
schema_tags.append(tag)
for label in self.tags_registry.labels_for(database_fqn):
if not self._has_classification(self._get_classification_name(label), schema_tags):
schema_tags.append(label)
return schema_tags if schema_tags else None
@ -1098,32 +1189,48 @@ class SnowflakeSource(
Tag values at lower levels take precedence over inherited values.
"""
table_tags = super().get_tag_labels(table_name) or []
table_fqn = cast(
"str",
fqn.build(
self.metadata,
entity_type=Table,
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
schema_name=self.context.get().database_schema, # pyright: ignore[reportAttributeAccessIssue]
table_name=table_name,
skip_es_search=True,
),
)
schema_fqn = cast(
"str",
fqn.build(
self.metadata,
entity_type=DatabaseSchema,
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
schema_name=self.context.get().database_schema, # pyright: ignore[reportAttributeAccessIssue]
),
)
database_fqn = cast(
"str",
fqn.build(
self.metadata,
entity_type=Database,
service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue]
database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue]
),
)
table_tags = self.tags_registry.labels_for(table_fqn)
# Add inherited schema tags (only if classification doesn't already exist)
schema_name = self.context.get().database_schema
if schema_name and schema_name in self.schema_tags_map:
for tag_info in self.schema_tags_map[schema_name]:
if not self._has_classification(tag_info["tag_name"], table_tags):
tag_label = get_tag_label(
metadata=self.metadata,
tag_name=tag_info["tag_value"],
classification_name=tag_info["tag_name"],
)
if tag_label:
table_tags.append(tag_label)
for label in self.tags_registry.labels_for(schema_fqn):
if not self._has_classification(self._get_classification_name(label), table_tags):
table_tags.append(label)
# Add inherited database tags (only if classification doesn't already exist)
database_name = self.context.get().database
if database_name and database_name in self.database_tags_map:
for tag_info in self.database_tags_map[database_name]:
if not self._has_classification(tag_info["tag_name"], table_tags):
tag_label = get_tag_label(
metadata=self.metadata,
tag_name=tag_info["tag_value"],
classification_name=tag_info["tag_name"],
)
if tag_label:
table_tags.append(tag_label)
for label in self.tags_registry.labels_for(database_fqn):
if not self._has_classification(self._get_classification_name(label), table_tags):
table_tags.append(label)
return table_tags if table_tags else None

View file

View file

@ -0,0 +1,161 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for ``metadata.domain.tags.TagCanonicalizer``."""
from unittest.mock import MagicMock
import pytest
from metadata.domain.tags import Canonical, TagCanonicalizer
from metadata.generated.schema.entity.classification.classification import Classification
from metadata.generated.schema.type.basic import ProviderType
@pytest.fixture(autouse=True)
def _no_retry_sleep(monkeypatch: pytest.MonkeyPatch) -> None:
"""Skip tenacity's between-retry sleeps so retry-tests run instantly."""
monkeypatch.setattr("time.sleep", lambda *_args, **_kwargs: None)
@pytest.fixture
def mock_metadata() -> MagicMock:
return MagicMock()
@pytest.fixture
def canonicalizer(mock_metadata: MagicMock) -> TagCanonicalizer:
return TagCanonicalizer(metadata=mock_metadata)
def _system_classification(name: str, description: str = "") -> MagicMock:
m = MagicMock()
m.provider = ProviderType.system
m.name.root = name
if description:
m.description.root = description
else:
m.description = None
return m
def _system_tag(classification: str, name: str, description: str = "") -> MagicMock:
m = MagicMock()
m.provider = ProviderType.system
m.classification.name = classification
m.name.root = name
if description:
m.description.root = description
else:
m.description = None
return m
class TestClassification:
def test_no_match_returns_source_unchanged(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
mock_metadata.es_search_from_fqn.return_value = []
result = canonicalizer.classification("MyClass", "Source desc")
assert result == Canonical(name="MyClass", description="Source desc")
def test_system_match_uses_canonical_case(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
mock_metadata.es_search_from_fqn.return_value = [_system_classification("PII", "Canonical desc")]
result = canonicalizer.classification("pii", "Source desc")
assert result == Canonical(name="PII", description="Canonical desc")
def test_caches_per_case_insensitive_key(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
mock_metadata.es_search_from_fqn.return_value = [_system_classification("PII", "Canonical desc")]
canonicalizer.classification("pii", "Source desc")
canonicalizer.classification("PII", "Source desc")
canonicalizer.classification("Pii", "Source desc")
# Three case variants share the same case-insensitive cache key
assert mock_metadata.es_search_from_fqn.call_count == 1
def test_non_system_match_ignored(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
non_system = _system_classification("PII", "Canonical desc")
non_system.provider = ProviderType.user
mock_metadata.es_search_from_fqn.return_value = [non_system]
result = canonicalizer.classification("pii", "Source desc")
assert result == Canonical(name="pii", description="Source desc")
def test_classification_es_called_with_correct_args(
self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
):
mock_metadata.es_search_from_fqn.return_value = []
canonicalizer.classification("Foo", "Source desc")
mock_metadata.es_search_from_fqn.assert_called_once_with(entity_type=Classification, fqn_search_string="Foo")
class TestTag:
def test_no_match_returns_source_unchanged(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
mock_metadata.es_search_from_fqn.return_value = []
result = canonicalizer.tag("PII", "MyTag", "Source desc")
assert result == Canonical(name="MyTag", description="Source desc")
def test_system_match_uses_canonical_case(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
mock_metadata.es_search_from_fqn.return_value = [_system_tag("PII", "Sensitive", "Canonical desc")]
result = canonicalizer.tag("PII", "sensitive", "Source desc")
assert result == Canonical(name="Sensitive", description="Canonical desc")
def test_caches_per_case_insensitive_key(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
mock_metadata.es_search_from_fqn.return_value = [_system_tag("PII", "Sensitive", "")]
canonicalizer.tag("PII", "sensitive", "Source desc")
canonicalizer.tag("PII", "SENSITIVE", "Source desc")
canonicalizer.tag("PII", "Sensitive", "Source desc")
# Three case variants share the same case-insensitive cache key
assert mock_metadata.es_search_from_fqn.call_count == 1
def test_match_requires_classification_match(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
# ES returns a tag but for a different classification — no canonicalization
wrong_class_tag = _system_tag("OtherClass", "Sensitive", "Canonical desc")
mock_metadata.es_search_from_fqn.return_value = [wrong_class_tag]
result = canonicalizer.tag("PII", "sensitive", "Source desc")
assert result == Canonical(name="sensitive", description="Source desc")
def test_non_system_match_ignored(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
non_system = _system_tag("PII", "Sensitive", "Canonical desc")
non_system.provider = ProviderType.user
mock_metadata.es_search_from_fqn.return_value = [non_system]
result = canonicalizer.tag("PII", "sensitive", "Source desc")
assert result == Canonical(name="sensitive", description="Source desc")
class TestRetryAndFailure:
def test_transient_failure_recovers_within_retry_budget(
self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
):
# First two ES calls raise; third succeeds.
mock_metadata.es_search_from_fqn.side_effect = [
RuntimeError("transient 1"),
RuntimeError("transient 2"),
[_system_classification("PII", "Canonical desc")],
]
result = canonicalizer.classification("pii", "Source desc")
assert result == Canonical(name="PII", description="Canonical desc")
assert mock_metadata.es_search_from_fqn.call_count == 3
def test_persistent_failure_raises_after_retries_exhaust(
self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock
):
mock_metadata.es_search_from_fqn.side_effect = RuntimeError("persistent")
with pytest.raises(RuntimeError, match="persistent"):
canonicalizer.classification("MyClass", "Source desc")
assert mock_metadata.es_search_from_fqn.call_count == 5
def test_persistent_failure_does_not_poison_cache(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock):
# First call: ES persistently fails -> raises.
mock_metadata.es_search_from_fqn.side_effect = RuntimeError("persistent")
with pytest.raises(RuntimeError):
canonicalizer.classification("MyClass", "Source desc")
# ES recovers; subsequent call must reach ES again, not return a cached fallback.
mock_metadata.es_search_from_fqn.side_effect = None
mock_metadata.es_search_from_fqn.return_value = [_system_classification("MyClass", "Canonical desc")]
result = canonicalizer.classification("MyClass", "Source desc")
assert result == Canonical(name="MyClass", description="Canonical desc")

View file

@ -0,0 +1,375 @@
# Copyright 2025 Collate
# Licensed under the Collate Community License, Version 1.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for ``metadata.domain.tags.TagRegistry``.
Covers attach/labels_for/drain/clear_scope/ensure_known semantics plus
basic thread-safety stress scenarios. The OM client is mocked; no
network or schema validation against a real backend.
"""
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import MagicMock
import pytest
from metadata.domain.tags import ScopeAlreadyClearedError, TagRegistry
from metadata.generated.schema.type.tagLabel import LabelType, State
@pytest.fixture
def mock_metadata() -> MagicMock:
return MagicMock()
@pytest.fixture
def registry(mock_metadata: MagicMock) -> TagRegistry:
return TagRegistry(metadata=mock_metadata)
def _attach_kwargs(
scope: str,
entity: str,
classification: str = "TestClass",
tag: str = "TestTag",
) -> dict:
return {
"scope_fqn": scope,
"entity_fqn": entity,
"classification_name": classification,
"tag_name": tag,
"classification_description": "test classification",
"tag_description": "test tag",
}
class TestAttachAndLabelsFor:
def test_attach_then_labels_for_returns_one_label(self, registry: TagRegistry):
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table"))
labels = registry.labels_for("svc.db.schema.table")
assert len(labels) == 1
def test_attach_multiple_tags_same_entity_returns_all(self, registry: TagRegistry):
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table", tag="Tag1"))
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table", tag="Tag2"))
labels = registry.labels_for("svc.db.schema.table")
assert len(labels) == 2
def test_labels_for_unattached_entity_returns_empty_list(self, registry: TagRegistry):
assert registry.labels_for("svc.db.schema.unknown") == []
def test_labels_for_is_idempotent(self, registry: TagRegistry):
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table"))
first = registry.labels_for("svc.db.schema.table")
second = registry.labels_for("svc.db.schema.table")
# Read-and-leave: both reads return the same labels.
# Cleanup is the responsibility of clear_scope, not labels_for.
assert len(first) == 1
assert second == first
def test_labels_for_returns_copy_not_internal_list(self, registry: TagRegistry):
# Mutating the returned list must not affect registry state.
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table"))
first = registry.labels_for("svc.db.schema.table")
first.clear()
second = registry.labels_for("svc.db.schema.table")
assert len(second) == 1
class TestDrain:
def test_drain_yields_pending_then_clears(self, registry: TagRegistry):
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_a"))
first = list(registry.drain())
second = list(registry.drain())
assert len(first) == 1
assert second == []
def test_drain_dedupes_same_tag_across_entities(self, registry: TagRegistry):
for i in range(100):
registry.attach(**_attach_kwargs("svc.db", f"svc.db.schema.tbl_{i}"))
pending = list(registry.drain())
assert len(pending) == 1
def test_drain_yields_distinct_payloads_for_distinct_tags(self, registry: TagRegistry):
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_1", tag="TagA"))
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_2", tag="TagB"))
pending = list(registry.drain())
assert len(pending) == 2
def test_drain_does_not_dedup_across_case_variants(self, registry: TagRegistry):
# OM stores tags case-sensitively; our dedup must follow that rule.
registry.attach(**_attach_kwargs("svc.db", "svc.db.t1", tag="Sensitive"))
registry.attach(**_attach_kwargs("svc.db", "svc.db.t2", tag="sensitive"))
pending = list(registry.drain())
assert len(pending) == 2 # both must PUT — they're distinct tags server-side
def test_drain_dedupes_same_fqn_across_label_types(self, registry: TagRegistry):
# Different cache keys (label_type varies) but identical tag_fqn → ONE PUT.
# Cache key is (class, tag, label_type, state); tag_fqn is class.tag.
registry.attach(
**_attach_kwargs("svc.db", "svc.db.t1"),
label_type=LabelType.Manual,
)
registry.attach(
**_attach_kwargs("svc.db", "svc.db.t2"),
label_type=LabelType.Automated,
)
pending = list(registry.drain())
assert len(pending) == 1, "fqn-level dedup must collapse PUTs across label_type variants"
class TestClearScope:
def test_clear_scope_drops_descendant_labels(self, registry: TagRegistry):
registry.attach(**_attach_kwargs("svc.db.schema", "svc.db.schema.tbl_1"))
registry.attach(**_attach_kwargs("svc.db.schema", "svc.db.schema.tbl_2"))
registry.clear_scope("svc.db.schema")
assert registry.labels_for("svc.db.schema.tbl_1") == []
assert registry.labels_for("svc.db.schema.tbl_2") == []
def test_clear_scope_drops_scope_itself(self, registry: TagRegistry):
registry.attach(**_attach_kwargs("svc.db.schema", "svc.db.schema"))
registry.clear_scope("svc.db.schema")
assert registry.labels_for("svc.db.schema") == []
def test_clear_scope_preserves_other_scopes(self, registry: TagRegistry):
registry.attach(**_attach_kwargs("svc.db.schema_a", "svc.db.schema_a.tbl"))
registry.attach(**_attach_kwargs("svc.db.schema_b", "svc.db.schema_b.tbl"))
registry.clear_scope("svc.db.schema_a")
assert registry.labels_for("svc.db.schema_a.tbl") == []
assert len(registry.labels_for("svc.db.schema_b.tbl")) == 1
def test_clear_scope_no_false_prefix_match(self, registry: TagRegistry):
# 'schema_a' is NOT a prefix of 'schema_alpha' once the FQN
# separator is taken into account.
registry.attach(**_attach_kwargs("svc.db.schema_alpha", "svc.db.schema_alpha.tbl"))
registry.clear_scope("svc.db.schema_a")
assert len(registry.labels_for("svc.db.schema_alpha.tbl")) == 1
def test_clear_scope_idempotent_on_unattached_scope(self, registry: TagRegistry):
registry.clear_scope("svc.db.never_attached") # must not raise
def test_attach_after_clear_raises(self, registry: TagRegistry):
registry.clear_scope("svc.db.schema")
with pytest.raises(ScopeAlreadyClearedError):
registry.attach(**_attach_kwargs("svc.db.schema", "svc.db.schema.tbl"))
class TestEnsureKnown:
def test_is_known_empty_returns_false(self, registry: TagRegistry):
assert registry.is_known("Class.Tag") is False
def test_is_known_after_attach_returns_true(self, registry: TagRegistry):
registry.attach(
**_attach_kwargs(
"svc.db",
"svc.db.schema.tbl",
classification="Class",
tag="Tag",
)
)
assert registry.is_known("Class.Tag") is True
def test_is_known_is_case_sensitive(self, registry: TagRegistry):
# Reflects OM's case-sensitive identity rule.
registry.attach(
**_attach_kwargs(
"svc.db",
"svc.db.schema.tbl",
classification="Class",
tag="Tag",
)
)
assert registry.is_known("Class.Tag") is True
assert registry.is_known("class.tag") is False # different tag server-side
def test_ensure_known_cache_hit_skips_io(self, registry: TagRegistry, mock_metadata: MagicMock):
registry.attach(
**_attach_kwargs(
"svc.db",
"svc.db.schema.tbl",
classification="Class",
tag="Tag",
)
)
assert registry.ensure_known("Class.Tag") is True
mock_metadata.get_by_name.assert_not_called()
def test_ensure_known_cache_miss_calls_get_by_name_once(self, registry: TagRegistry, mock_metadata: MagicMock):
mock_metadata.get_by_name.return_value = MagicMock()
assert registry.ensure_known("Other.Tag") is True
assert registry.ensure_known("Other.Tag") is True # cached now
assert mock_metadata.get_by_name.call_count == 1
def test_ensure_known_404_returns_false_and_does_not_cache(self, registry: TagRegistry, mock_metadata: MagicMock):
mock_metadata.get_by_name.return_value = None
assert registry.ensure_known("Missing.Tag") is False
assert registry.ensure_known("Missing.Tag") is False
# Re-queries on each miss; not cached.
assert mock_metadata.get_by_name.call_count == 2
def test_ensure_known_swallows_exception(self, registry: TagRegistry, mock_metadata: MagicMock):
mock_metadata.get_by_name.side_effect = RuntimeError("network down")
assert registry.ensure_known("Crashed.Tag") is False
class TestThreadSafety:
def test_concurrent_attach_same_tag_dedupes_pending(self, registry: TagRegistry):
def worker(thread_idx: int) -> None:
for i in range(100):
registry.attach(
**_attach_kwargs(
"svc.db",
f"svc.db.schema.tbl_{thread_idx}_{i}",
)
)
with ThreadPoolExecutor(max_workers=8) as pool:
list(pool.map(worker, range(8)))
pending = list(registry.drain())
assert len(pending) == 1
def test_concurrent_disjoint_scopes_no_label_loss(self, registry: TagRegistry):
def worker(scope_idx: int) -> None:
scope = f"svc.db.schema_{scope_idx}"
for i in range(50):
registry.attach(
**_attach_kwargs(
scope,
f"{scope}.tbl_{i}",
tag=f"Tag_{scope_idx}_{i}",
)
)
with ThreadPoolExecutor(max_workers=8) as pool:
list(pool.map(worker, range(8)))
for scope_idx in range(8):
scope = f"svc.db.schema_{scope_idx}"
for i in range(50):
entity = f"{scope}.tbl_{i}"
labels = registry.labels_for(entity)
assert len(labels) == 1, f"missing label for {entity}"
class TestStats:
def test_initial_stats_all_zero(self, registry: TagRegistry):
assert registry.stats() == {
"known_tag_fqns": 0,
"tag_label_cache": 0,
"pending": 0,
"cleared_scopes": 0,
"live_entities": 0,
"live_labels": 0,
}
def test_stats_reflect_attach(self, registry: TagRegistry):
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_1"))
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_2"))
s = registry.stats()
# Both attaches share the same tag — known + pending dedup to 1
assert s["known_tag_fqns"] == 1
assert s["pending"] == 1
# Two entities, each with one label
assert s["live_entities"] == 2
assert s["live_labels"] == 2
def test_labels_for_does_not_decrease_live_state(self, registry: TagRegistry):
# labels_for is idempotent (read-and-leave); clear_scope is the
# only mechanism that reduces live state.
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl"))
registry.labels_for("svc.db.schema.tbl")
s = registry.stats()
assert s["live_entities"] == 1
assert s["live_labels"] == 1
assert s["known_tag_fqns"] == 1
assert s["pending"] == 1
def test_drain_decreases_pending_only(self, registry: TagRegistry):
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl"))
list(registry.drain())
s = registry.stats()
assert s["pending"] == 0
assert s["known_tag_fqns"] == 1 # still tracked for dedup
def test_clear_scope_zeroes_live_state_for_scope(self, registry: TagRegistry):
# Critical invariant: after clear_scope, no live_entities for that scope.
for i in range(50):
registry.attach(**_attach_kwargs("svc.db.schema", f"svc.db.schema.tbl_{i}"))
assert registry.stats()["live_entities"] == 50
registry.clear_scope("svc.db.schema")
s = registry.stats()
assert s["live_entities"] == 0
assert s["live_labels"] == 0
assert s["cleared_scopes"] == 1
class TestInterning:
"""TagLabel interning — multiple attaches with the same key share one
underlying ``TagLabel`` instance. Memory bound depends on this; the
`is`-identity assertion is the load-bearing check."""
def test_attach_interns_identical_tag_labels(self, registry: TagRegistry):
# Same (classification, tag, label_type, state) across two entities
# must return the exact same TagLabel object — not just an equal one.
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_1"))
registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_2"))
label_1 = registry.labels_for("svc.db.schema.tbl_1")[0]
label_2 = registry.labels_for("svc.db.schema.tbl_2")[0]
assert label_1 is label_2, "expected shared TagLabel instance via interning"
def test_attach_does_not_intern_across_label_types(self, registry: TagRegistry):
# Cache key includes label_type — non-default values must not collide.
registry.attach(
**_attach_kwargs("svc.db", "svc.db.schema.tbl_1"),
label_type=LabelType.Manual,
)
registry.attach(
**_attach_kwargs("svc.db", "svc.db.schema.tbl_2"),
label_type=LabelType.Automated,
)
label_manual = registry.labels_for("svc.db.schema.tbl_1")[0]
label_auto = registry.labels_for("svc.db.schema.tbl_2")[0]
assert label_manual is not label_auto
assert label_manual.labelType == LabelType.Manual
assert label_auto.labelType == LabelType.Automated
def test_attach_does_not_intern_across_states(self, registry: TagRegistry):
registry.attach(
**_attach_kwargs("svc.db", "svc.db.schema.tbl_1"),
state=State.Suggested,
)
registry.attach(
**_attach_kwargs("svc.db", "svc.db.schema.tbl_2"),
state=State.Confirmed,
)
label_suggested = registry.labels_for("svc.db.schema.tbl_1")[0]
label_confirmed = registry.labels_for("svc.db.schema.tbl_2")[0]
assert label_suggested is not label_confirmed
def test_intern_cache_survives_clear_scope(self, registry: TagRegistry):
# Cache lifetime is registry lifetime, NOT scope lifetime — next scope
# reuses the same TagLabel instance for the same (class, tag, ...) key.
registry.attach(**_attach_kwargs("svc.db.schema_1", "svc.db.schema_1.tbl"))
label_first = registry.labels_for("svc.db.schema_1.tbl")[0]
registry.clear_scope("svc.db.schema_1")
registry.attach(**_attach_kwargs("svc.db.schema_2", "svc.db.schema_2.tbl"))
label_second = registry.labels_for("svc.db.schema_2.tbl")[0]
assert label_first is label_second, "intern cache should survive clear_scope"

View file

@ -19,7 +19,9 @@ from unittest.mock import MagicMock, Mock, PropertyMock, patch
import sqlalchemy.types as sqltypes
from metadata.generated.schema.entity.data.table import TableType
from metadata.generated.schema.entity.data.database import Database
from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema
from metadata.generated.schema.entity.data.table import Table, TableType
from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipeline import (
PipelineStatus,
)
@ -27,14 +29,9 @@ from metadata.generated.schema.metadataIngestion.workflow import (
OpenMetadataWorkflowConfig,
)
from metadata.generated.schema.type.filterPattern import FilterPattern
from metadata.generated.schema.type.tagLabel import (
LabelType,
State,
TagLabel,
TagSource,
)
from metadata.ingestion.source.database.snowflake.metadata import MAP, SnowflakeSource
from metadata.ingestion.source.database.snowflake.models import SnowflakeStoredProcedure
from metadata.utils import fqn
SNOWFLAKE_CONFIGURATION = {
"source": {
@ -491,18 +488,39 @@ class SnowflakeUnitTest(TestCase):
self.assertEqual(map_type.value_type, sqltypes.VARCHAR) # default
self.assertFalse(map_type.not_null) # default
@patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_tag_labels")
@patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_schema_tag_labels")
@patch("metadata.ingestion.source.database.snowflake.metadata.get_tag_label")
def test_schema_tag_inheritance(
self,
mock_get_tag_label,
mock_parent_get_schema_tag_labels,
mock_parent_get_tag_labels,
):
"""Test schema tag inheritance"""
def _setup_tag_context(self, source, service_name="local_snowflake"):
"""Populate the topology context for schema-stage tag tests and return the FQN trio."""
source.context.get().__dict__["database_service"] = service_name
source.context.get().__dict__["database"] = "TEST_DATABASE"
source.context.get().__dict__["database_schema"] = "TEST_SCHEMA"
database_fqn = fqn.build(
source.metadata,
entity_type=Database,
service_name=service_name,
database_name="TEST_DATABASE",
)
schema_fqn = fqn.build(
source.metadata,
entity_type=DatabaseSchema,
service_name=service_name,
database_name="TEST_DATABASE",
schema_name="TEST_SCHEMA",
)
table_fqn = fqn.build(
source.metadata,
entity_type=Table,
service_name=service_name,
database_name="TEST_DATABASE",
schema_name="TEST_SCHEMA",
table_name="TEST_TABLE",
skip_es_search=True,
)
return database_fqn, schema_fqn, table_fqn
def test_schema_tag_inheritance(self):
"""Schema tags propagate to tables; classification dedup is preserved."""
for source in self.sources.values():
# Verify tags are fetched and stored
mock_schema_tags = [
Mock(SCHEMA_NAME="TEST_SCHEMA", TAG_NAME="SCHEMA_TAG", TAG_VALUE="VALUE"),
]
@ -519,48 +537,39 @@ class SnowflakeUnitTest(TestCase):
{"tag_name": "SCHEMA_TAG", "tag_value": "VALUE"},
)
# Verify schema tag labels
mock_get_tag_label.return_value = TagLabel(
tagFQN="SnowflakeTag.SCHEMA_TAG",
labelType=LabelType.Automated,
state=State.Suggested,
source=TagSource.Classification,
_, schema_fqn, table_fqn = self._setup_tag_context(source)
source.tags_registry.attach(
scope_fqn=schema_fqn,
entity_fqn=schema_fqn,
classification_name="SCHEMA_CLASSIFICATION",
tag_name="SCHEMA_TAG",
classification_description="",
tag_description="",
)
source.tags_registry.attach(
scope_fqn=schema_fqn,
entity_fqn=table_fqn,
classification_name="TABLE_CLASSIFICATION",
tag_name="TABLE_TAG",
classification_description="",
tag_description="",
)
mock_parent_get_schema_tag_labels.return_value = None
schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA")
self.assertIsNotNone(schema_labels)
self.assertEqual(len(schema_labels), 1)
# Verify tag inheritance
source.context.get().__dict__["database_schema"] = "TEST_SCHEMA"
mock_parent_get_tag_labels.return_value = [
TagLabel(
tagFQN="SnowflakeTag.TABLE_TAG",
labelType=LabelType.Automated,
state=State.Suggested,
source=TagSource.Classification,
)
]
self.assertEqual(schema_labels[0].tagFQN.root, "SCHEMA_CLASSIFICATION.SCHEMA_TAG")
table_labels = source.get_tag_labels(table_name="TEST_TABLE")
self.assertEqual(len(table_labels), 2)
tag_fqns = [tag.tagFQN.root for tag in table_labels]
self.assertIn("SnowflakeTag.SCHEMA_TAG", tag_fqns)
self.assertIn("SnowflakeTag.TABLE_TAG", tag_fqns)
self.assertIn("SCHEMA_CLASSIFICATION.SCHEMA_TAG", tag_fqns)
self.assertIn("TABLE_CLASSIFICATION.TABLE_TAG", tag_fqns)
@patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_tag_labels")
@patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_schema_tag_labels")
@patch("metadata.ingestion.source.database.snowflake.metadata.get_tag_label")
def test_database_tag_inheritance(
self,
mock_get_tag_label,
mock_parent_get_schema_tag_labels,
mock_parent_get_tag_labels,
):
"""Test database tag inheritance to schemas and tables"""
def test_database_tag_inheritance(self):
"""Database tags propagate to schemas and tables when classifications don't overlap."""
for source in self.sources.values():
# Setup mock database tags
mock_database_tags = [
Mock(
DATABASE_NAME="TEST_DATABASE",
@ -574,7 +583,6 @@ class SnowflakeUnitTest(TestCase):
source.engine.connect.return_value.__enter__ = MagicMock(return_value=mock_conn)
source.engine.connect.return_value.__exit__ = MagicMock(return_value=False)
# Test set_database_tags_map
source.set_database_tags_map("TEST_DATABASE")
self.assertEqual(len(source.database_tags_map["TEST_DATABASE"]), 1)
self.assertEqual(
@ -582,23 +590,33 @@ class SnowflakeUnitTest(TestCase):
{"tag_name": "DATABASE_TAG", "tag_value": "DB_VALUE"},
)
# Setup schema tags for combined testing
source.schema_tags_map = {"TEST_SCHEMA": [{"tag_name": "SCHEMA_TAG", "tag_value": "SCHEMA_VALUE"}]}
database_fqn, schema_fqn, table_fqn = self._setup_tag_context(source)
# Mock tag label creation
def mock_tag_label_side_effect(metadata, tag_name, classification_name):
return TagLabel(
tagFQN=f"{classification_name}.{tag_name}",
labelType=LabelType.Automated,
state=State.Suggested,
source=TagSource.Classification,
)
source.tags_registry.attach(
scope_fqn=database_fqn,
entity_fqn=database_fqn,
classification_name="DATABASE_TAG",
tag_name="DB_VALUE",
classification_description="",
tag_description="",
)
source.tags_registry.attach(
scope_fqn=schema_fqn,
entity_fqn=schema_fqn,
classification_name="SCHEMA_TAG",
tag_name="SCHEMA_VALUE",
classification_description="",
tag_description="",
)
source.tags_registry.attach(
scope_fqn=schema_fqn,
entity_fqn=table_fqn,
classification_name="TABLE_TAG",
tag_name="TABLE_VALUE",
classification_description="",
tag_description="",
)
mock_get_tag_label.side_effect = mock_tag_label_side_effect
mock_parent_get_schema_tag_labels.return_value = None
# Test schema inherits database tags
source.context.get().__dict__["database"] = "TEST_DATABASE"
schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA")
self.assertIsNotNone(schema_labels)
self.assertEqual(len(schema_labels), 2)
@ -606,17 +624,6 @@ class SnowflakeUnitTest(TestCase):
self.assertIn("SCHEMA_TAG.SCHEMA_VALUE", tag_fqns)
self.assertIn("DATABASE_TAG.DB_VALUE", tag_fqns)
# Test table inherits both schema and database tags
source.context.get().__dict__["database_schema"] = "TEST_SCHEMA"
mock_parent_get_tag_labels.return_value = [
TagLabel(
tagFQN="TABLE_TAG.TABLE_VALUE",
labelType=LabelType.Automated,
state=State.Suggested,
source=TagSource.Classification,
)
]
table_labels = source.get_tag_labels(table_name="TEST_TABLE")
self.assertEqual(len(table_labels), 3)
tag_fqns = [tag.tagFQN.root for tag in table_labels]
@ -624,59 +631,44 @@ class SnowflakeUnitTest(TestCase):
self.assertIn("SCHEMA_TAG.SCHEMA_VALUE", tag_fqns)
self.assertIn("DATABASE_TAG.DB_VALUE", tag_fqns)
@patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_tag_labels")
@patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_schema_tag_labels")
@patch("metadata.ingestion.source.database.snowflake.metadata.get_tag_label")
def test_tag_value_precedence(
self,
mock_get_tag_label,
mock_parent_get_schema_tag_labels,
mock_parent_get_tag_labels,
):
"""Test that tag values at lower levels take precedence over inherited values.
def test_tag_value_precedence(self):
"""Lower-level tags override inherited values for the same classification.
When database, schema, and table all have the same tag name (classification)
but different values, the object's own value should take precedence.
Database: ENV=dev, Schema: ENV=staging, Table: ENV=production.
Schema lookup must return only ENV.staging; table lookup only ENV.production.
"""
for source in self.sources.values():
# Setup: Database, schema, and table all have ENV tag with different values
# Database: ENV=dev
# Schema: ENV=staging
# Table: ENV=production
database_fqn, schema_fqn, table_fqn = self._setup_tag_context(source)
source.database_tags_map = {"TEST_DATABASE": [{"tag_name": "ENV", "tag_value": "dev"}]}
source.tags_registry.attach(
scope_fqn=database_fqn,
entity_fqn=database_fqn,
classification_name="ENV",
tag_name="dev",
classification_description="",
tag_description="",
)
source.tags_registry.attach(
scope_fqn=schema_fqn,
entity_fqn=schema_fqn,
classification_name="ENV",
tag_name="staging",
classification_description="",
tag_description="",
)
source.tags_registry.attach(
scope_fqn=schema_fqn,
entity_fqn=table_fqn,
classification_name="ENV",
tag_name="production",
classification_description="env classification",
tag_description="production tag",
)
source.schema_tags_map = {"TEST_SCHEMA": [{"tag_name": "ENV", "tag_value": "staging"}]}
def mock_tag_label_side_effect(metadata, tag_name, classification_name):
return TagLabel(
tagFQN=f"{classification_name}.{tag_name}",
labelType=LabelType.Automated,
state=State.Suggested,
source=TagSource.Classification,
)
mock_get_tag_label.side_effect = mock_tag_label_side_effect
mock_parent_get_schema_tag_labels.return_value = None
source.context.get().__dict__["database"] = "TEST_DATABASE"
source.context.get().__dict__["database_schema"] = "TEST_SCHEMA"
# Test schema level: schema's own value takes precedence over database
schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA")
self.assertEqual(len(schema_labels), 1)
self.assertEqual(schema_labels[0].tagFQN.root, "ENV.staging")
# Test table level: table's own value takes precedence over schema and database
mock_parent_get_tag_labels.return_value = [
TagLabel(
tagFQN="ENV.production",
labelType=LabelType.Automated,
state=State.Suggested,
source=TagSource.Classification,
)
]
table_labels = source.get_tag_labels(table_name="TEST_TABLE")
self.assertEqual(len(table_labels), 1)
self.assertEqual(table_labels[0].tagFQN.root, "ENV.production")